Finetune a Span Categorizer with BERT and Transformers

La Javaness R&D
13 min read · Sep 29, 2022


1. Token Classification vs. Span Categorization

Token Classification

In the NLP world, we are familiar with token classification tasks like Named Entity Recognition (NER) or Part-of-Speech tagging. The idea of token classification is to assign each token its associated class (or entity, or tag). As an example for NER, let’s take an arbitrary sentence:

“Charles André Joseph Marie de Gaulle was born on 22 November 1890 in Lille in the Nord department.”

and tokenize it into a sequence of tokens like:

["Charles", "André", "Joseph", "Marie", "de", "G", "##aul", "##le", "was", "born", "on", "22", "November", "1890", "in", "Lille", "in", "the", "Nord", "department", "."]

(The ## prefix means that the token is not the start of a word: it continues the word begun by the preceding tokens.)

With token classification, the expected output looks like this:

Figure: expected output of a token classification model

where each token is associated with at most one entity (PER for person, DAT for date, LOC for location).

In practice, for a given tag like “PER”, we distinguish two labels, “B-PER” and “I-PER”, to represent the beginning token of a detected span and the “inner” (non-beginning) tokens that follow. This mechanism is called IOB and is one of the options that help us delimit two adjacent spans of the same tag.

As a consequence, in the output of a token classification task, segments (spans) must not overlap, even when they carry different types (like PER and ORG): each token gets at most one label.

Span Categorization

In reality, sometimes we need to tag several segments that may overlap, as shown in the following example.

Figure: expected output of a span categorization model

As shown above, suppose that, for some reason, we want to detect people, locations and dates, but also numbers, the subject(s) of the sentence and the associated event, all simultaneously in a text. The output tagged segments can overlap, for example:

  • the DAT segment “22 November 1890” and the NUMBER segment “22” overlap on “22”.
  • the span “Charles André Joseph Marie de Gaulle” is tagged both PER and SUBJ.

The example above illustrates span categorization problems. Models that solve them are called span categorizers. The spans output by a span categorizer can overlap, while those output by a token classifier (a named entity recognizer or a POS tagger) cannot.

Objective

Span categorizers were introduced with spaCy 3.1. At the time of writing, span categorization is not yet supported by transformers. However, if you are interested in BERT models for span categorization, this article (tutorial) presents a simple way to convert a RoBERTa token classification model into a span categorizer. It is based on the simple fact that span categorization can be modelled as multilabel token classification.

Figure: the span categorizer was introduced with spaCy 3.1. Image source: prodi.gy

Indeed, if we break things down to the token level, token classification is based on multiclass classification at the token level, whereas span categorization is based on multilabel classification at the token level (for a reminder of the difference, see here). To convert a multiclass classifier into a multilabel one, it suffices to change:

  • The loss function: Cross-entropy loss -> Binary cross-entropy loss.
  • The activation at the output layer of BERT model: Softmax -> Sigmoid.

During training (fine-tuning) and model evaluation, we also adapt some metrics, and that’s all.

Let’s practice.

2. Dataset

The training and validation parts of the dataset can be found here and here. They are part of the WikiNER dataset.

Tags and labels

In the original dataset, the tags are PER, ORG, LOC and MISC. We annotated about 400 items in this dataset with some new tags:

  • NCHUNK (noun chunk): a group of words surrounding a noun to represent an entity or concept, e.g. ce mouvement de la gauche radicale
  • TIME: a group of words representing a time period, e.g. Dans les années 30 du XXe siècle
  • PLACE: a group of words representing a location or place, prepositions included, e.g. Dans le sud de l’Espagne

So there are 7 tags in total. Let’s write them clearly in the code.
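A minimal sketch of what this can look like (the variable names id2tag and tag2id are our own):

tags = ["PER", "ORG", "LOC", "MISC", "NCHUNK", "TIME", "PLACE"]

# Map ids to tags, starting from 1 so that 0 stays free for the "O" label.
id2tag = {id + 1: tag for id, tag in enumerate(tags)}
tag2id = {tag: id for id, tag in id2tag.items()}
print(id2tag)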

Output

{1: 'PER', 2: 'ORG', 3: 'LOC', 4: 'MISC', 5: 'NCHUNK', 6: 'TIME', 7: 'PLACE'}

Just as in the IOB approach to NER problems, for each tag above (e.g. “PER”) we define two labels starting with “B-” and “I-” to mark the beginning and inner tokens of a span of that tag. Let “O” denote the label for tokens belonging to no tag. Let’s write it in code:
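A possible sketch, consistent with the output below:

# 15 labels in total: "O" plus a "B-"/"I-" pair for each of the 7 tags.
id2label = {0: "O"}
id2label.update({2 * id - 1: f"B-{tag}" for id, tag in id2tag.items()})
id2label.update({2 * id: f"I-{tag}" for id, tag in id2tag.items()})
label2id = {label: id for id, label in id2label.items()}
print(id2label)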

Output

{0: 'O',
1: 'B-PER',
3: 'B-ORG',
5: 'B-LOC',
7: 'B-MISC',
9: 'B-NCHUNK',
11: 'B-TIME',
13: 'B-PLACE',
2: 'I-PER',
4: 'I-ORG',
6: 'I-LOC',
8: 'I-MISC',
10: 'I-NCHUNK',
12: 'I-TIME',
14: 'I-PLACE'}

From now on, by convention, we use tags for the original 7 classes in id2tag and labels for the 15 classes in id2label.

Data format

Let’s load the datasets and look at some lines of the training set.
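For instance, with the datasets library (the file names are placeholders for wherever you saved the two JSON files):

from datasets import load_dataset

ds = load_dataset("json", data_files={"train": "train.json", "val": "val.json"})
train_ds, val_ds = ds["train"], ds["val"]
print(train_ds[0])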

Output

{'tags': [{'end': 32, 'start': 19, 'tag': 'PER'},
{'end': 32, 'start': 6, 'tag': 'NCHUNK'},
{'end': 152, 'start': 143, 'tag': 'NCHUNK'},
{'end': 225, 'start': 211, 'tag': 'NCHUNK'},
{'end': 79, 'start': 45, 'tag': 'NCHUNK'}],
'id': 'train_253',
'text': "Selon l'ethnologue Maurice Duval, « dire que ce mouvement de la gauche radicale est « une secte », ce n'est pas argumenter légitimement contre ses idées, mais c'est suggérer qu'il est malfaisant, malsain et que sa disparition serait souhaitable »."}

The “tags” field of each item is a list of spans together with their offsets (start and end positions in the sentence) and tags (classes). To see that these spans can overlap, let’s print some examples:
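A loop along these lines reproduces the output below:

for sample in train_ds.select(range(3)):
    print(sample["text"])
    for span in sample["tags"]:
        print(f"{span['tag']} - {sample['text'][span['start']:span['end']]}")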

Output

Selon l'ethnologue Maurice Duval, « dire que ce mouvement de la gauche radicale est « une secte », ce n'est pas argumenter légitimement contre ses idées, mais c'est suggérer qu'il est malfaisant, malsain et que sa disparition serait souhaitable ».
PER - Maurice Duval
NCHUNK - l'ethnologue Maurice Duval
NCHUNK - ses idées
NCHUNK - sa disparition
NCHUNK - ce mouvement de la gauche radicale
Adolescent, il joue de la basse dans un groupe de surf music, commence à composer et s'intéresse aux œuvres de musique contemporaine de compositeurs comme Charles Ives, Karlheinz Stockhausen, Mauricio Kagel, ou encore John Cage.
PER - Charles Ives
PER - Karlheinz Stockhausen
PER - Mauricio Kagel
PER - John Cage
NCHUNK - un groupe de surf music
NCHUNK - œuvres de musique contemporaine de compositeurs comme Charles Ives, Karlheinz Stockhausen, Mauricio Kagel, ou encore John Cage
Metacritic ", qui détermine une moyenne pondérée entre 0 et 100 basée sur les critiques populaires, a donné un score moyen de 50 % pour le film, basé sur 40 critiques.
MISC - Metacritic
NCHUNK - une moyenne pondérée entre 0 et 100
NCHUNK - les critiques populaires
NCHUNK - un score moyen de 50 %
NCHUNK - 40 critiques

We see in the first example that the PER span “Maurice Duval” and NCHUNK span “l’ethnologue Maurice Duval” overlap.

3. Data Processing — Tokenization

The input format, called “offset” format, gives us the start and end positions of each span and its associated tag. Since we model the task as multilabel classification at the token level, we need to convert this format into the IOB format, which assigns each token its tags (or labels). We will use a French tokenizer such as FlauBERT’s or CamemBERT’s.

In this tutorial, our datasets are in “offset” format, so we prefer CamemBERT, as it provides the fast version CamembertTokenizerFast, which can return the offset mapping of every token in the sentence. To see the difference between a slow tokenizer like FlauBERT’s and a fast tokenizer like CamemBERT’s, you can try in a terminal:
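For example, a quick check of our own (the flaubert/flaubert_base_cased checkpoint is an assumption):

from transformers import AutoTokenizer

# CamemBERT ships a fast tokenizer that supports offset mappings.
tokenizer = AutoTokenizer.from_pretrained("camembert-base")
out = tokenizer("Selon l'ethnologue Maurice Duval", return_offsets_mapping=True)
print(out["offset_mapping"])  # (start, end) character positions of each token

# FlauBERT only has a slow (pure Python) tokenizer: asking it for
# return_offsets_mapping=True raises a NotImplementedError.
slow_tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")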

So, as return_offsets_mapping is a very useful feature for mapping a span’s label onto token labels, let’s use CamemBERT and tokenize the dataset:

(The utility function get_token_role_in_span below is used by the main converter function tokenize_and_adjust_labels.)
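Here is a sketch of what these two functions can look like; the fixed MAX_LENGTH and the exact signatures are our assumptions:

def get_token_role_in_span(token_start, token_end, span_start, span_end):
    """Decide whether a token is at the (B)eginning of, (I)nside,
    or outside ("N") a span, given character offsets."""
    if token_end <= token_start:
        return "N"  # special tokens such as <s> have an empty offset
    if token_start < span_start or token_end > span_end:
        return "N"
    if token_start > span_start:
        return "I"
    return "B"

MAX_LENGTH = 256

def tokenize_and_adjust_labels(sample):
    # Tokenize with offset mappings so that character-level spans can be
    # aligned with the resulting tokens.
    tokenized = tokenizer(sample["text"],
                          return_offsets_mapping=True,
                          padding="max_length",
                          max_length=MAX_LENGTH,
                          truncation=True)
    # One multi-hot vector of size len(id2label) per token.
    labels = [[0] * len(id2label) for _ in range(MAX_LENGTH)]
    for (token_start, token_end), token_labels in zip(tokenized["offset_mapping"], labels):
        for span in sample["tags"]:
            role = get_token_role_in_span(token_start, token_end, span["start"], span["end"])
            if role == "B":
                token_labels[label2id[f"B-{span['tag']}"]] = 1
            elif role == "I":
                token_labels[label2id[f"I-{span['tag']}"]] = 1
    return {**tokenized, "labels": labels}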

Now apply the mapping
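With datasets, this is a simple map over both splits (removing the original columns is our choice, to keep only the model inputs):

tokenized_train_ds = train_ds.map(tokenize_and_adjust_labels, remove_columns=train_ds.column_names)
tokenized_val_ds = val_ds.map(tokenize_and_adjust_labels, remove_columns=val_ds.column_names)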

Look at an example

If we look at an example like tokenized_train_ds[0], we see that a field labels is present, containing the 0–1 encoding of the tags associated with each token.

sample = tokenized_train_ds[0]
sample["labels"]

Output

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
...
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

The 0–1 format above is hard to follow, so let’s decode it a bit to check that everything went well.
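A small decoding loop in this spirit produces the table below:

tokens = tokenizer.batch_decode(sample["input_ids"])
print(f"{'-' * 8}Token{'-' * 9}|{'-' * 8}Labels{'-' * 10}")
for token, token_labels in zip(tokens, sample["labels"]):
    if token == "<pad>":
        break  # stop at the padding
    labels = [id2label[id] for id, value in enumerate(token_labels) if value == 1]
    print(f"{token:22} | {labels}")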

Output

--------Token---------|--------Labels----------
<s> | []
Selon | []
l | ['B-NCHUNK']
' | ['I-NCHUNK']
ethno | ['I-NCHUNK']
logue | ['I-NCHUNK']
Maurice | ['B-PER', 'I-NCHUNK']
Duval | ['I-PER', 'I-NCHUNK']
, | []
« | []
dire | []
que | []
ce | ['B-NCHUNK']
mouvement | ['I-NCHUNK']
de | ['I-NCHUNK']
la | ['I-NCHUNK']
gauche | ['I-NCHUNK']
radicale | ['I-NCHUNK']
est | []
« | []
une | []
secte | []
», | []
ce | []
n | []
' | []
est | []
pas | []
argument | []
er | []
légitime | []
ment | []
contre | []
ses | ['B-NCHUNK']
idées | ['I-NCHUNK']
, | []
mais | []
c | []
' | []
est | []
suggérer | []
qu | []
' | []
il | []
est | []
mal | []
faisant | []
, | []
malsain | []
et | []
que | []
sa | ['B-NCHUNK']
disparition | ['I-NCHUNK']
serait | []
souhaitable | []
». | []
</s> | []

We see that each token is correctly associated with zero, one or several labels. Everything seems fine so far.

DataCollator

We build a data collator, to be used for fine-tuning in batch mode in the next step.
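Since tokenize_and_adjust_labels already pads inputs and labels to the fixed MAX_LENGTH, the default collator, which simply stacks the features into batch tensors, is enough here:

from transformers import default_data_collator

data_collator = default_data_collator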

4. Modeling

Prepare the Metrics
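A sketch of compute_metrics consistent with the five steps described next (the helper divide and the exact metric names are our own):

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

n_labels = len(id2label)

def divide(a, b):
    return a / b if b > 0 else 0  # define x/0 = 0 to avoid zero division

def compute_metrics(p):
    # (1) Retrieve the logits and the multi-hot ground-truth labels.
    predictions, true_labels = p
    # (2) Threshold the logits at 0 (i.e. a sigmoid probability of 0.5).
    predicted_labels = np.where(predictions > 0, 1, 0)
    metrics = {}
    # (3) One confusion matrix per label.
    cm = multilabel_confusion_matrix(true_labels.reshape(-1, n_labels),
                                     predicted_labels.reshape(-1, n_labels))
    # (4) Token-based precision/recall/f1 for the 14 labels other than "O".
    for label_id, matrix in enumerate(cm):
        if label_id == 0:
            continue  # skip "O"
        tp, fp, fn = matrix[1, 1], matrix[0, 1], matrix[1, 0]
        precision = divide(tp, tp + fp)
        recall = divide(tp, tp + fn)
        metrics[f"f1_{id2label[label_id]}"] = divide(2 * precision * recall, precision + recall)
    # (5) Macro f1 over the 14 labels.
    metrics["macro_f1"] = sum(metrics.values()) / (n_labels - 1)
    return metrics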

What did we do in compute_metrics?

  • (1): The input is a tuple; we retrieve the predictions and the true labels. The ground-truth array true_labels has shape (dataset size, number of tokens, number of labels), and each item of the array is 0 or 1. The predictions array has the same shape, and each item is a logit returned by the model.
  • (2): We apply a threshold of 0 to the logits and assign 1 to any position in the array where the logit exceeds this threshold. (Equivalently, at a logit of 0, the sigmoid function returns 0.5 as the probability that the item belongs to the corresponding label.)
  • (3): For each label in id2label, we compute the confusion matrix. The output of this line will be an array of n_labels = 15 confusion matrices.
  • (4): We compute the token-based precision, recall and f1 score for the 14 labels (all except “O”) and store the “f1” values in the dict metrics. If you want to evaluate with entity-based metrics, use seqeval for example. Note that we return 0 in case of zero division.
  • (5): We compute the macro f1 score over the 14 labels and append it to metrics.

Prepare the Loss — Build a Custom Model Class

That was the metrics; now let’s adapt the loss and the output layer of the architecture. As we used CamemBERT’s tokenizer, let’s take advantage of a checkpoint like camembert-base, which is already pre-trained on a sufficiently large French corpus. To modify the loss, we need to understand CamemBERT’s architecture, so let’s load the model in interactive mode and look at all the classes it inherits from.
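One way to do this:

from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("camembert-base", num_labels=len(id2label))
type(model).__mro__  # the method resolution order lists every inherited class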

Investigating all those classes from left to right, we find that transformers.models.camembert.modeling_camembert.CamembertModel inherits from transformers.models.roberta.modeling_roberta.RobertaModel. Looking at the code of this class, we see that the architecture for token classification is implemented in RobertaForTokenClassification.

In particular, the loss we need to change is found here:
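Abridged from the forward method of RobertaForTokenClassification in transformers’ modeling_roberta.py:

loss = None
if labels is not None:
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))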

Indeed, it suffices to replace it with

loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels.float())

It is worth remarking here that if you use FlauBERT instead of CamemBERT, you need to change the loss implemented in the class XLMForTokenClassification (here).

Now, let’s build a class RobertaForSpanCategorization that looks exactly like RobertaForTokenClassification, changing only the computation of the loss. By using BCEWithLogitsLoss, we tell torch to apply a sigmoid to the logits instead of a softmax, and to compute the binary cross-entropy loss instead of the cross-entropy loss.
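A minimal sketch of such a class (the less common forward arguments are omitted for brevity):

import torch.nn as nn
from transformers import RobertaForTokenClassification
from transformers.modeling_outputs import TokenClassifierOutput

class RobertaForSpanCategorization(RobertaForTokenClassification):
    """Same architecture as RobertaForTokenClassification; only the
    loss changes, from cross-entropy to binary cross-entropy."""

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.roberta(input_ids, attention_mask=attention_mask, **kwargs)
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss applies the sigmoid itself, so the raw
            # logits are passed in unchanged.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )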

5. Fine-tuning

Now it’s time for finetuning. Let’s define the training arguments.
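For instance (the hyperparameter values below are illustrative, not necessarily the ones used for the reported scores):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./models/fine_tune_bert_output_span_cat",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=40,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",
    seed=42,
)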

To make the training reproducible, instead of passing model="camembert-base", we will use model_init=…, as advised by Hugging Face in https://huggingface.co/transformers/v4.6.0/main_classes/trainer.html#trainingarguments (see the paragraph on the seed argument).
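A sketch of the corresponding Trainer setup:

from transformers import Trainer

def model_init():
    # The model is re-created from the checkpoint with the same seed at
    # the start of training, which makes runs reproducible.
    return RobertaForSpanCategorization.from_pretrained(
        "camembert-base", id2label=id2label, label2id=label2id
    )

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_val_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()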

Output (per-epoch evaluation log with per-label f1 scores, not reproduced here)

Observations

  • The model prefers major classes (i.e. labels with more samples) to minor classes (labels with fewer samples), in the sense that it reaches high scores on major classes sooner. For the same tag, the “I-” labels attain high scores (or even non-zero scores) sooner than the “B-” labels. This behaviour can be changed by applying weights in the loss: if some classes have far fewer samples, we can put a higher weight on them so that they are not neglected for too long.
  • The metrics can also be computed in a way that merges the “B-” and “I-” labels of a tag into a single one, in which case the global score will be higher.

The global macro f1 score of around 0.75 is not bad for fine-tuning on only 400 items. We save the model:

trainer.model.save_pretrained("./models/fine_tune_bert_output_span_cat")

6. Inference

To use the fine-tuned model on new data, we load the model and the tokenizer (note that the tokenizer has not changed, so we can use the original “camembert-base”).
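For example:

from transformers import AutoTokenizer

model = RobertaForSpanCategorization.from_pretrained("./models/fine_tune_bert_output_span_cat")
tokenizer = AutoTokenizer.from_pretrained("camembert-base")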

We write a function to retrieve the offsets and tags of all tokens in a text.
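A sketch of such a function (its name and signature are our own):

import torch

def get_offsets_and_predicted_tags(example, model, tokenizer, threshold=0):
    # Tokenize with offsets so predictions can be mapped back to the text.
    raw_encoded_example = tokenizer(example, return_offsets_mapping=True)
    encoded_example = tokenizer(example, return_tensors="pt")

    # Call the model; the logits have shape (1, n_tokens, n_labels).
    with torch.no_grad():
        logits = model(**encoded_example)["logits"][0]

    # A label is predicted for a token when its logit exceeds the
    # threshold (0 corresponds to a sigmoid probability of 0.5).
    predicted_tags = [[i for i, logit in enumerate(token_logits) if logit > threshold]
                      for token_logits in logits]

    return [
        {"token": token, "tags": tags, "offset": offset}
        for token, tags, offset in zip(
            tokenizer.batch_decode(raw_encoded_example["input_ids"]),
            predicted_tags,
            raw_encoded_example["offset_mapping"],
        )
    ]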

The function retrieves the list of decoded tokens, their tags and their offset mappings (i.e. their start and end positions in the text). We can verify it with an example like this text from lefigaro.fr:
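For instance:

test_text = ("Du coup, la menace des feux de forêt est permanente, après les "
             "incendies dévastateurs de juillet dans le sud-ouest de la France, "
             "en Espagne, au Portugal ou en Grèce. Un important feu de forêt a "
             "éclaté le 24 juillet dans le parc national de la Suisse de Bohême, "
             "à la frontière entre la République tchèque et l'Allemagne, où des "
             "records de chaleur ont été battus (36,4C). Un millier d'hectares "
             "ont déjà été touchés. Lundi, les pompiers espéraient que "
             "l'incendie pourrait être maîtrisé en quelques jours.")

for item in get_offsets_and_predicted_tags(test_text, model, tokenizer):
    print(f"{item['token']:15} - {item['tags']}")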

Output

<s>             - []
Du - []
coup - []
, - []
la - [9]
menace - [10]
des - [10]
feux - [10]
de - [10]
forêt - [10]
est - []
permanente - []
, - []
après - []
les - [9]
incendie - [10]
s - [10]
dévastateur - [10]
s - [10]
de - [12]
juillet - [12]
dans - [13]
le - [14]
sud - [14]
- - [14]
ouest - [14]
de - [14]
la - [14]
France - [5, 14]
...
en - []
quelques - [12]
jours - [12]
. - []
</s> - []

Group the Entities together

If the format above is not convenient for verifying the output, we can link the tokens together and return the spans in the same format as our input.
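One possible implementation, as a sketch (a stray “I-” label with no preceding “B-” is handled by opening a new span):

def get_tagged_groups(example, model, tokenizer):
    offsets_and_tags = get_offsets_and_predicted_tags(example, model, tokenizer)
    spans = []       # finished and in-progress spans
    open_spans = {}  # tag -> index in `spans` of the span being extended
    for item in offsets_and_tags:
        start, end = item["offset"]
        still_open = {}
        for label_id in item["tags"]:
            label = id2label[label_id]
            if label == "O" or start == end:
                continue  # skip "O" predictions and special tokens
            tag = label[2:]
            if label.startswith("B-") or tag not in open_spans:
                # A "B-" label (or a stray "I-") opens a new span.
                spans.append({"start": start, "end": end, "tag": tag})
                still_open[tag] = len(spans) - 1
            else:
                # An "I-" label extends the span currently open for its tag.
                index = open_spans[tag]
                spans[index]["end"] = end
                still_open[tag] = index
        open_spans = still_open
    for span in spans:
        span["text"] = example[span["start"]:span["end"]]
    return sorted(spans, key=lambda span: span["start"])

print(test_text)
get_tagged_groups(test_text, model, tokenizer)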

Output

Du coup, la menace des feux de forêt est permanente, après les incendies dévastateurs de juillet dans le sud-ouest de la France, en Espagne, au Portugal ou en Grèce. Un important feu de forêt a éclaté le 24 juillet dans le parc national de la Suisse de Bohême, à la frontière entre la République tchèque et l'Allemagne, où des records de chaleur ont été battus (36,4C). Un millier d'hectares ont déjà été touchés. Lundi, les pompiers espéraient que l'incendie pourrait être maîtrisé en quelques jours.

[{'start': 9,
'end': 36,
'tag': 'NCHUNK',
'text': 'la menace des feux de forêt'},
{'start': 59,
'end': 85,
'tag': 'NCHUNK',
'text': 'les incendies dévastateurs'},
{'start': 86, 'end': 96, 'tag': 'TIME', 'text': 'de juillet'},
{'start': 97,
'end': 127,
'tag': 'PLACE',
'text': 'dans le sud-ouest de la France'},
{'start': 121, 'end': 127, 'tag': 'LOC', 'text': 'France'},
{'start': 129, 'end': 139, 'tag': 'PLACE', 'text': 'en Espagne'},
{'start': 132, 'end': 139, 'tag': 'LOC', 'text': 'Espagne'},
{'start': 141, 'end': 152, 'tag': 'PLACE', 'text': 'au Portugal'},
{'start': 144, 'end': 152, 'tag': 'LOC', 'text': 'Portugal'},
{'start': 156, 'end': 164, 'tag': 'PLACE', 'text': 'en Grèce'},
{'start': 159, 'end': 164, 'tag': 'LOC', 'text': 'Grèce'},
{'start': 166,
'end': 191,
'tag': 'NCHUNK',
'text': 'Un important feu de forêt'},
{'start': 201, 'end': 214, 'tag': 'TIME', 'text': 'le 24 juillet'},
{'start': 220,
'end': 259,
'tag': 'NCHUNK',
'text': 'le parc national de la Suisse de Bohême'},
{'start': 243, 'end': 259, 'tag': 'LOC', 'text': 'Suisse de Bohême'},
{'start': 263,
'end': 318,
'tag': 'NCHUNK',
'text': "la frontière entre la République tchèque et l'Allemagne"},
{'start': 285, 'end': 303, 'tag': 'LOC', 'text': 'République tchèque'},
{'start': 309, 'end': 318, 'tag': 'LOC', 'text': 'Allemagne'},
{'start': 323, 'end': 345, 'tag': 'NCHUNK', 'text': 'des records de chaleur'},
{'start': 370, 'end': 391, 'tag': 'NCHUNK', 'text': "Un millier d'hectares"},
{'start': 414, 'end': 419, 'tag': 'TIME', 'text': 'Lundi'},
{'start': 486, 'end': 500, 'tag': 'TIME', 'text': 'quelques jours'}]

We confirm that the model can return overlapping spans, like the NCHUNK span “le parc national de la Suisse de Bohême“ and the LOC span “Suisse de Bohême“.

7. Recap

That’s it. In this article, we introduced a way to adapt a very small amount of transformers token classification code to obtain a span categorizer, based once again on a simple idea: change the loss so that a multiclass classifier at the token level becomes a multilabel classifier.

The approach can be improved in many ways, for example by putting class weights into the binary cross-entropy loss at the output layer of the BERT model to accelerate convergence on minor classes. Feel free to test and improve the models!

References

  1. https://spacy.io/api/spancategorizer
  2. https://explosion.ai/blog/spancat
  3. https://lajavaness.medium.com/d%C3%A9tection-dexpressions-de-tonalit%C3%A9s-sentiments-dans-un-document-avec-du-ner-partie-ii-f44c7ee43c14
  4. https://scikit-learn.org/stable/modules/multiclass.html

Acknowledgement

Thanks to our colleagues Al Houceine KILANI and Kostia PEREBASKINE for the article review.

About

Nhut DOAN NGUYEN has been a data scientist at La Javaness since March 2021.
