Multiclass and Multilabel Text Classification in One BERT Model

1. Introduction

Multiclass and multilabel classification are two tasks familiar to most data scientists. The best-known approach is to model each of them as an optimization problem with its own objective function: cross-entropy (on softmax outputs) for multiclass classification, and binary cross-entropy (on sigmoid outputs) for multilabel classification.

1.1 Problem

In some situations, we face both kinds of classification in the same project, and constraints oblige us to use a single model to perform several multilabel and multiclass classification tasks on different groups of categories. Such constraints or situations may be related to:

  • Platform resources: a lack of GPU or memory means we cannot load many models simultaneously.
  • Response time requirements: we must predict everything within one second, so we call the models as few times as possible.
  • Low priority of the problem: we only use the models to suggest something and don’t need them to be very accurate; most false positives and false negatives are tolerable.

1.2 Use-case Examples

Here are examples of several classification tasks performed on the same input:

  1. Given an image of an animal, predict its species (multiclass) and its health situation (possibly multilabel, from a list of predefined situations like “dead”, “alive”, “unhealthy”, “critical”).
  2. Given a response to a survey, predict the emotions (multilabel), a general feeling score (multiclass, one class from 1 star to 5 stars), and the “causes” of those emotions (multilabel).

1.3 Objectives

One approach to such examples is to model everything as a single multilabel classification problem by optimizing one binary cross-entropy loss function. Then, at prediction time, for the classes belonging to the multiclass group, one can assign the text to the class with the highest probability computed by the sigmoid function. This approach can work, but after some experiments, we prefer a combined loss in which softmax is applied to the multiclass classes and sigmoid to the multilabel classes.

In this article, we illustrate how to combine multiclass and multilabel tasks in one model with example 2, whose inputs are texts. It is also an opportunity to revisit how to fine-tune a BERT model with Hugging Face’s popular transformers library.

The rest of this article is organized as follows:

  • Section 2 — Description of the dataset and the goal of this example
  • Section 3 — How to fine-tune the model. We will focus on the combined loss and the metrics.
  • Section 4 — Some remarks.

2. Example Description

2.1 Dataset

We reuse the same dataset as in our previous articles (Détection d’expressions de tonalités-sentiments dans un document avec un algorithme de NER …; Regression with Text Input Using BERT and Transformers). The dataset was collected from public online comments about public services (Google Maps reviews, Trustpilot, etc.). The data is split into train and validation parts.

Here is a row of the dataset:

It contains an identifier, the text, and the following fields:

  • global_score for the global feeling of the commenter. Its value is one string among S1 (very positive), S2 (positive), S3 (neutral), S4 (negative) and S5 (very negative).
  • emotions for all of the commenter’s emotions. Its value is a list of strings from the 10 classes E1 to E11 (E9 is not used). You can find their meaning below.
  • causes for all the “causes” (motifs) mentioned in the comment. Its value is a list of strings from the 12 classes C1 to C12. You can find their meaning below.

Here is a Python dictionary containing the meaning of the classes. (Note: all the code in this article can be run in a Python 3 Jupyter notebook.)

2.2 Distribution of the classes

For global feeling:

Figure 1 — Distribution of the global feelings

For emotions (similar code):

Figure 2 — Distribution of the emotions

For causes (similar code):

Figure 3 — Distribution of the causes
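The plotting code is not reproduced here. As a minimal sketch of the counting step behind Figures 1 to 3, the snippet below uses `collections.Counter` on a few hypothetical rows (invented for illustration, not taken from the dataset). The only real difference between the fields is that `emotions` and `causes` hold lists, so they must be flattened before counting.

```python
from collections import Counter

# Hypothetical toy rows with the same fields as the dataset (NOT real data).
rows = [
    {"global_score": "S1", "emotions": ["E1"], "causes": ["C3", "C7"]},
    {"global_score": "S4", "emotions": [], "causes": ["C3"]},
    {"global_score": "S4", "emotions": ["E2", "E5"], "causes": []},
]

# Multiclass field: exactly one label per row, so count it directly.
global_counts = Counter(row["global_score"] for row in rows)

# Multilabel fields: flatten the per-row lists before counting.
emotion_counts = Counter(label for row in rows for label in row["emotions"])
cause_counts = Counter(label for row in rows for label in row["causes"])

print(global_counts.most_common())  # [('S4', 2), ('S1', 1)]
```

Because a row can carry several emotions or causes, or none at all, the emotion and cause counts do not sum to the number of rows, unlike the global-score counts.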

2.3 Goal

We will fine-tune one BERT model with Hugging Face’s transformers that performs three tasks simultaneously on the input text:

  • Predict its global_score - a single class among S1 - S5, as an output of the multiclass classification problem on 5 classes.
  • Predict its emotions - a list (can be empty) of classes among E1 - E11 (except E9), as an output of the multilabel classification problem on 10 classes.
  • Predict its causes - a list (can be empty) of classes among C1 - C12, as an output of the multilabel classification problem on 12 classes.

Altogether, we have 5 + 10 + 12 = 27 classes.

3. Fine-tuning BERT Model

3.1 Setup

Let’s start by importing the necessary modules and defining some constants for hyperparameters like the base model, learning rate, batch size, number of epochs for training and max length of text.

We load the well-known French “camembert-base” backbone model and its tokenizer.

3.2 Preprocess the dataset

We load the dataset:

We tokenize the text (input) and represent the label (output) as a 27-dimensional vector.
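A minimal sketch of the label encoding, assuming the column order described in Section 3.4 (5 global-score columns, then 10 emotion columns, then 12 cause columns) and assuming classes are ordered numerically within each group. The helper name `encode_labels` is hypothetical, and tokenization is left out.

```python
# Assumed column order: S1..S5, then E1..E11 without E9, then C1..C12.
GLOBAL_SCORES = [f"S{i}" for i in range(1, 6)]
EMOTIONS = [f"E{i}" for i in range(1, 12) if i != 9]  # E9 is not used
CAUSES = [f"C{i}" for i in range(1, 13)]
ALL_CLASSES = GLOBAL_SCORES + EMOTIONS + CAUSES
CLASS_TO_INDEX = {name: idx for idx, name in enumerate(ALL_CLASSES)}


def encode_labels(global_score, emotions, causes):
    """Hypothetical helper: map one row's labels to a 27-dimensional 0/1 vector."""
    vector = [0.0] * len(ALL_CLASSES)  # floats, as binary cross-entropy expects
    vector[CLASS_TO_INDEX[global_score]] = 1.0  # exactly one global score
    for label in list(emotions) + list(causes):  # zero or more of each
        vector[CLASS_TO_INDEX[label]] = 1.0
    return vector


labels = encode_labels("S2", ["E1", "E10"], ["C4"])
print([i for i, v in enumerate(labels) if v == 1.0])  # [1, 5, 13, 18]
```

Because E9 is skipped, E10 lands at column 13 rather than 14. In a preprocessing function, this vector would typically be stored under the `labels` key next to the tokenizer outputs, which is the key the Trainer looks for by default.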


It looks good: the emotions, global feeling and causes are encoded as 0–1 vectors. Let’s apply the preprocessing function to the whole dataset.

3.3 Prepare the metrics

During training, the model computes the logits of each input. At each evaluation, the Hugging Face Trainer passes these logits, together with the labels, to a compute_metrics function that we define. This function deduces the predicted vectors from the logits, compares them with the labels, and returns metrics such as accuracy, precision, recall and F1-score on the whole validation set.

As a reminder,

  • For multiclass classification, the predicted vector is obtained by putting 1 for the class with the highest logit and 0 for all the other classes. (Equivalently, if we apply softmax to these classes, the class with the highest logit gets the highest probability.)
  • For multilabel classification, we choose a threshold of 0 on the logits (with the sigmoid function, this corresponds to a probability of 0.5). The predicted vector is obtained by putting 1 for every non-negative logit and 0 for every negative logit, which means the input is assigned to a class when its probability is >= 0.5.
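Both equivalences stated above can be checked numerically. The NumPy snippet below, with made-up logits, verifies that softmax keeps the position of the largest logit and that a logit threshold of 0 selects the same classes as a sigmoid-probability threshold of 0.5.

```python
import numpy as np


def softmax(x):
    # Subtracting the max improves numerical stability without changing the result.
    e = np.exp(x - x.max())
    return e / e.sum()


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


logits = np.array([-1.5, 0.0, 0.3, 2.0, -0.2])  # made-up values

# Multiclass: softmax preserves the ordering of the logits, so the argmax is unchanged.
assert softmax(logits).argmax() == logits.argmax()

# Multilabel: sigmoid(0) = 0.5 and sigmoid is increasing,
# so "logit >= 0" and "probability >= 0.5" select the same classes.
assert np.array_equal(logits >= 0, sigmoid(logits) >= 0.5)
```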

Let’s apply the logic above. We first split the column indexes (0 to 26, as we have 27 columns) into groups dedicated to global feelings, emotions and causes.

Now we define a function that applies the two rules: one to the multiclass columns and one to the multilabel columns.
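A minimal NumPy sketch of such a function follows. The index constants reuse the names `GLOBAL_SCORE_INDICES`, `EMOTION_INDICES` and `CAUSE_INDICES` mentioned in Section 3.4, but their exact definitions and the function name `logits_to_predictions` are our assumptions. The last lines reproduce the random-vector check described next, with entries drawn uniformly between -2 and 2.

```python
import numpy as np

# Column groups (assumed layout: 5 + 10 + 12 = 27 columns).
GLOBAL_SCORE_INDICES = list(range(0, 5))  # multiclass
EMOTION_INDICES = list(range(5, 15))      # multilabel
CAUSE_INDICES = list(range(15, 27))       # multilabel
MULTILABEL_INDICES = EMOTION_INDICES + CAUSE_INDICES


def logits_to_predictions(logits):
    """Turn a (n_samples, 27) logit array into a 0/1 prediction array."""
    logits = np.atleast_2d(np.asarray(logits, dtype=float))
    preds = np.zeros_like(logits)
    # Multiclass group: 1 for the highest logit in each row, 0 elsewhere.
    best = logits[:, GLOBAL_SCORE_INDICES].argmax(axis=1)
    preds[np.arange(len(logits)), np.array(GLOBAL_SCORE_INDICES)[best]] = 1.0
    # Multilabel groups: 1 wherever the logit is non-negative (probability >= 0.5).
    preds[:, MULTILABEL_INDICES] = (logits[:, MULTILABEL_INDICES] >= 0).astype(float)
    return preds


rng = np.random.default_rng(0)
random_logits = rng.uniform(-2, 2, size=(1, 27))
predicted = logits_to_predictions(random_logits)
```

The function works on a whole batch at once, which is the shape in which the Trainer hands logits to the metrics function.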

Let’s look at an example by generating a random 27-dimensional vector whose entries lie between -2 and 2.

Now we compute the predicted vector.

As expected, there is exactly one “1” among the first 5 columns (corresponding to the multiclass task). In the remaining columns (indices 5 to 26), which correspond to the multilabel tasks, a “1” appears exactly where the input is non-negative.

Integration: putting the predicted vector into our custom metrics function

Once we have the predicted vectors (in 0–1 encoded form), we can compare them with the labels and define any metrics we wish. For example, in our compute_metrics function, we compute f1_micro and f1_macro for each group (global feeling, emotions and causes), as well as overall f1_micro and f1_macro across all the classes. We also print the classification report, which gives metrics for every single class. This way, during training, we can follow the model’s performance on each of our 3 problems and on each individual class.
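Below is a reduced sketch of such a `compute_metrics` function, written with NumPy only so that the formulas are visible: per-class F1 is 2TP / (2TP + FP + FN), micro-F1 pools the counts over all classes, and macro-F1 averages the per-class scores (in this sketch, a class with no positives and no predictions counts as 0). It accepts a `(logits, labels)` pair, the form in which the Trainer’s evaluation output is usually unpacked, and it omits the classification report. The group names and column layout are assumptions, and `logits_to_predictions` is repeated so the block runs on its own.

```python
import numpy as np

GLOBAL_SCORE_INDICES = list(range(0, 5))
EMOTION_INDICES = list(range(5, 15))
CAUSE_INDICES = list(range(15, 27))
MULTILABEL_INDICES = EMOTION_INDICES + CAUSE_INDICES
GROUPS = {"global": GLOBAL_SCORE_INDICES, "emotions": EMOTION_INDICES, "causes": CAUSE_INDICES}


def logits_to_predictions(logits):
    """Argmax on the multiclass columns, threshold 0 on the multilabel columns."""
    logits = np.asarray(logits, dtype=float)
    preds = np.zeros_like(logits)
    best = logits[:, GLOBAL_SCORE_INDICES].argmax(axis=1)
    preds[np.arange(len(logits)), np.array(GLOBAL_SCORE_INDICES)[best]] = 1.0
    preds[:, MULTILABEL_INDICES] = (logits[:, MULTILABEL_INDICES] >= 0).astype(float)
    return preds


def f1_micro_macro(y_true, y_pred):
    """Micro and macro F1 for 0/1 float arrays of shape (n_samples, n_classes)."""
    tp = (y_true * y_pred).sum(axis=0)
    fp = ((1 - y_true) * y_pred).sum(axis=0)
    fn = (y_true * (1 - y_pred)).sum(axis=0)
    pooled = 2 * tp.sum() + fp.sum() + fn.sum()
    f1_micro = 2 * tp.sum() / pooled if pooled > 0 else 0.0
    per_class = 2 * tp + fp + fn
    f1_per_class = np.divide(2 * tp, per_class, out=np.zeros_like(tp), where=per_class > 0)
    return float(f1_micro), float(f1_per_class.mean())


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits_to_predictions(logits)
    labels = np.asarray(labels, dtype=float)
    metrics = {}
    for name, idx in GROUPS.items():  # one pair of scores per task
        micro, macro = f1_micro_macro(labels[:, idx], preds[:, idx])
        metrics[f"{name}_f1_micro"], metrics[f"{name}_f1_macro"] = micro, macro
    # Overall scores on all 27 columns.
    metrics["f1_micro"], metrics["f1_macro"] = f1_micro_macro(labels, preds)
    return metrics
```

Such a function would be handed to the Trainer through its `compute_metrics` argument. Note that on the global-score group, where every sample has exactly one true and one predicted class, micro-F1 coincides with accuracy.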

3.4 Prepare the Loss Function

The combined loss function is a weighted sum of 3 components: the cross-entropy loss on the first 5 columns (the GLOBAL_SCORE_INDICES columns, for global feelings), the binary cross-entropy loss on the next 10 columns (the EMOTION_INDICES columns) and the binary cross-entropy loss on the last 12 columns (the CAUSE_INDICES columns). These loss functions are implemented with torch.nn.functional.
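A minimal sketch of this combined loss with `torch.nn.functional`, assuming the 27-column layout above and one-hot global-score labels. `F.cross_entropy` receives class indices recovered with `argmax`, and `F.binary_cross_entropy_with_logits` applies the sigmoid internally, so both take raw logits. The function name `combined_loss` and its default weights are ours, not the author’s code.

```python
import torch
import torch.nn.functional as F

GLOBAL_SCORE_INDICES = list(range(0, 5))
EMOTION_INDICES = list(range(5, 15))
CAUSE_INDICES = list(range(15, 27))


def combined_loss(logits, labels, group_weights=(1.0, 1.0, 1.0)):
    """Weighted sum of CE (global scores) and two BCE terms (emotions, causes)."""
    labels = labels.float()
    # Multiclass: softmax cross-entropy; targets are the positions of the one-hot 1s.
    global_targets = labels[:, GLOBAL_SCORE_INDICES].argmax(dim=1)
    loss_global = F.cross_entropy(logits[:, GLOBAL_SCORE_INDICES], global_targets)
    # Multilabel: sigmoid + binary cross-entropy, averaged over columns and samples.
    loss_emotions = F.binary_cross_entropy_with_logits(
        logits[:, EMOTION_INDICES], labels[:, EMOTION_INDICES])
    loss_causes = F.binary_cross_entropy_with_logits(
        logits[:, CAUSE_INDICES], labels[:, CAUSE_INDICES])
    w_global, w_emotions, w_causes = group_weights
    return w_global * loss_global + w_emotions * loss_emotions + w_causes * loss_causes


# With all-zero logits: CE = ln 5 and each BCE term = ln 2, whatever the labels.
labels = torch.zeros(4, 27)
labels[:, 2] = 1.0
loss = combined_loss(torch.zeros(4, 27), labels)
print(round(loss.item(), 4))  # 2.9957 (= ln 5 + 2 ln 2)
```

In a Trainer subclass, an overridden `compute_loss(self, model, inputs, return_outputs=False)` would typically pop `labels` from `inputs`, call `model(**inputs)`, and return this loss computed on `outputs.logits` (together with `outputs` when `return_outputs` is true). Some recent transformers versions pass an extra `num_items_in_batch` argument, so accepting `**kwargs` in the override is prudent.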

Since the loss is a weighted sum, a natural question is which weight (coefficient) to assign to each component. We answered this question experimentally:

  • Training on the first group only, the loss reaches a minimum of around 1.2–1.36.
  • Training on the second (or third) group only, the loss reaches a minimum of around 0.25.

So we suggest weights close to the inverse of these minima, (1/1.3, 1/0.25, 1/0.25) ≈ (0.77, 4, 4); in practice we use (0.7, 4, 4). To implement the custom loss, we subclass the Trainer class of Hugging Face’s transformers, redefining __init__, compute_loss and possibly other methods. We introduce an attribute group_weights so that the weights can be changed if needed.
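The heuristic amounts to dividing each component by the loss value it reaches when trained alone, so that every weighted term is roughly 1 near convergence. A quick check of the arithmetic with the minima reported above (taking 1.3 for the first group):

```python
# Approximate minimum loss of each group when trained alone (from the experiments above).
single_task_minima = {"global": 1.3, "emotions": 0.25, "causes": 0.25}

# Inverse of each minimum, so each weighted component is about 1 at its optimum.
weights = {group: 1.0 / loss for group, loss in single_task_minima.items()}
print({group: round(w, 2) for group, w in weights.items()})
# {'global': 0.77, 'emotions': 4.0, 'causes': 4.0}
```

The exact values matter less than keeping the three weighted terms of comparable magnitude, as discussed in Section 4.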

3.5 Put everything together

We are almost ready for fine-tuning. Let’s define a small callback that prints the epoch number at each evaluation step, so that we can tell which epoch each classification report (see 3.3) belongs to.

Now, we put all the training arguments together

and train:

Logs (we display only the beginning and the end of the logs)

3.6 Observations during training

  • The single model can perform 2 kinds of tasks simultaneously. The classification reports confirm that the first group behaves as a multiclass task (micro precision = micro recall = micro F1 = accuracy), while the second and third groups behave as multilabel tasks.
  • Some tasks “converge” more quickly than others. In our case, the first group (global feeling) reaches its best performance (macro F1: 0.6) after 10–15 epochs, while the emotion and cause groups converge more slowly.
  • The model performs poorly on emotion and cause classes with few samples (emotions “Fear” and “Sadness”; cause “Complexity/Simplicity”).

Let’s evaluate again on the validation set.

3.7 Prediction on Some Examples

Let’s use our model to predict a few new, simple examples.
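For each sentence, the model outputs a row of 27 logits. Below is a minimal sketch of turning such a row back into readable labels, assuming the same column order as before and the decision rules of Section 3.3; the logits are made up and `decode_row` is a hypothetical helper, not the code that produced the original outputs.

```python
import numpy as np

GLOBAL_SCORES = [f"S{i}" for i in range(1, 6)]
EMOTIONS = [f"E{i}" for i in range(1, 12) if i != 9]  # E9 is not used
CAUSES = [f"C{i}" for i in range(1, 13)]


def decode_row(logits):
    """Map 27 logits to {global_score, emotions, causes} using the rules of 3.3."""
    logits = np.asarray(logits, dtype=float)
    global_part, emotion_part, cause_part = logits[:5], logits[5:15], logits[15:]
    return {
        "global_score": GLOBAL_SCORES[int(global_part.argmax())],  # argmax: one class
        "emotions": [e for e, z in zip(EMOTIONS, emotion_part) if z >= 0],  # threshold 0
        "causes": [c for c, z in zip(CAUSES, cause_part) if z >= 0],  # threshold 0
    }


fake_logits = np.full(27, -3.0)  # made-up logits: everything "off" ...
fake_logits[3] = 2.5             # ... except S4,
fake_logits[5 + 8] = 0.4         # E10 (position 8 among the emotions),
fake_logits[15 + 0] = 1.1        # and C1
print(decode_row(fake_logits))
# {'global_score': 'S4', 'emotions': ['E10'], 'causes': ['C1']}
```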


4. Some Remarks

After doing experiments on several projects, we also notice that:

  • With this approach, the best performance on each group in the combined model is slightly lower than the best performance obtained when training on that task alone with the same BERT backbone (by 0.01 to 0.04 points in our example). The gap can be reduced by adding more data or by selecting hyperparameters more carefully. So the combined approach is fully acceptable and can be used in production, especially when we have constraints on server resources or response time.
  • If we do not use group weights when summing the loss functions, or if the weights are poorly chosen, the component losses may not be of the same magnitude. In that case:
    – If the component losses are still comparable (for example, with a coefficient of 2 for the global feeling group and 1 for the cause and emotion groups), the model converges more slowly on the less dominant tasks but can end up with good performance (good metrics) on all tasks.
    – If one component loss becomes too dominant (for example, with a coefficient of 100 for the global feeling group and 1 for the cause and emotion groups), the model performs well on the corresponding group (global feeling) but cannot learn anything for the others (causes and emotions).
  • The approach can be generalized to regression as well: we simply replace the cross-entropy and binary cross-entropy losses with the mean squared error loss and define a weighted sum of the three losses. We can also define custom metrics to evaluate the model’s performance during training.
  • The poor performance of the model on some classes is due to a lack of data, not to the combination approach itself. In fact, the "camembert-base" backbone has enough capacity to handle 27 columns at the same time in our example above, and in our real projects the method has also been applied to hundreds of columns.
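For the regression remark above, here is a minimal sketch of the loss side, assuming (hypothetically) that the target vector is split into three groups of continuous columns with the same 5 / 10 / 12 layout. Each group gets a mean squared error term, and the terms are combined with weights exactly as in the classification case; the names are illustrative only.

```python
import torch
import torch.nn.functional as F

# Hypothetical regression layout reusing the 5 / 10 / 12 column split.
GROUP_SLICES = [slice(0, 5), slice(5, 15), slice(15, 27)]


def combined_regression_loss(outputs, targets, group_weights=(1.0, 1.0, 1.0)):
    """Weighted sum of one mean squared error term per column group."""
    losses = [F.mse_loss(outputs[:, s], targets[:, s].float()) for s in GROUP_SLICES]
    return sum(w * loss for w, loss in zip(group_weights, losses))


outputs = torch.zeros(2, 27)
targets = torch.ones(2, 27)
print(combined_regression_loss(outputs, targets, (0.5, 1.0, 2.0)).item())  # 3.5
```

The same weighting caveat applies: the weights should bring the three MSE terms to comparable magnitudes, which depends on the scale of each group’s targets.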

We hope this article has clarified the idea of using one model for several classification tasks on text.

5. References

  1. Détection d’expressions de tonalités-sentiments dans un document avec un algorithme de NER …
  2. Regression with Text Input Using BERT and Transformers
  3. https://towardsdatascience.com/cross-entropy-loss-function-f38c4ec8643e
  4. Binary Cross Entropy/Log Loss for Binary Classification

6. Acknowledgements

We thank our colleagues Huong Nguyen and Jean-Baptiste Barin for reviewing this article.

7. About the Author

Nhut DOAN NGUYEN has been a data scientist at La Javaness since 2021.

La Javaness R&D


We help organizations to succeed in the new paradigm of “AI@scale”, by using machine intelligence responsibly and efficiently : www.lajavaness.com