Finetune a Relation Classifier with transformers and LUKE

La Javaness R&D
13 min read · Apr 25, 2023

1. Introduction

In many of our projects, we deal with the problem of understanding cause-consequence, time-object or place-object relationships. Imagine that we try to understand this customer request:

Example 1:

“I have received the invoice for January & March 2022 but the February 2022 invoice is missing. Can you please send me the missing invoice?”

In this message, several months are mentioned. What’s important to identify is that the customer is asking for the invoice for February 2022 and not the other months.

Example 2:

Take an example of customer review analysis:

“I am happy with the product but the delivery takes too long time.”

Here, we want to understand the causes of satisfaction/dissatisfaction. As a result, we need to correctly associate satisfaction with the product and dissatisfaction with delivery.

Formulation of the problem in NLP

In NLP problems, we sometimes need to detect spans in a text that correspond to named entities and identify relationships between entities. For example:

Entities and their relations — Image from Annotto, an annotation tool by La Javaness

In the example above, we are interested in the entities TIME, SUBJECT, ACTION, OBJECT and possible relationships between the entity pairs, like:

  • the WHO_DO relationship from a SUBJECT to an ACTION.
  • the DO_WHAT relationship from an ACTION to an OBJECT.
  • the DO_WHEN relationship from an ACTION to a TIME expression.

Based on the specific problem, the relationships can be symmetric or non-symmetric. In our example, the relationships are not symmetric: DO_WHEN goes from an ACTION to a TIME expression, not the opposite way.
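One way to make the direction explicit in code is to store each relation as an ordered (source, destination) pair. The sketch below is illustrative only: the `Relation` class and the `is_symmetric` helper are hypothetical names, not part of LUKE or any other library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Relation:
    """A directed relation: `label` goes from `src` to `dest`."""
    label: str
    src: str
    dest: str

    def reversed(self) -> "Relation":
        """Same label, opposite direction."""
        return Relation(self.label, self.dest, self.src)


def is_symmetric(label: str, relations: set) -> bool:
    """True if every relation with this label also appears in reverse."""
    labelled = [r for r in relations if r.label == label]
    return all(r.reversed() in relations for r in labelled)


relations = {
    Relation("DO_WHEN", "ACTION", "TIME"),
    Relation("WHO_DO", "SUBJECT", "ACTION"),
}
print(is_symmetric("DO_WHEN", relations))  # False: TIME -> ACTION is not annotated
```

With non-symmetric relationships, a classifier has to be asked about (A, B) and (B, A) separately, which is why the order of the two spans matters in the rest of this article.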

So, we are dealing with two problems: detecting entities and then detecting their relationships.

The first problem, detecting the entities, can be solved with a token classification model, also called a named-entity recognition (NER) model.

The second problem, detecting relationships between entities, can be solved with a relation classification model. One of the most widely used models for this task is LUKE, which can be used for both named-entity recognition and relation classification.

This article describes how we will use LUKE to finetune a relation classification model (a.k.a. relation classifier) for our custom datasets.

The plan for the remainder of this article:

  • Section 2: Describe the dataset, entities and relationships.
  • Section 3: Describe how to finetune LUKE for our relation classification task.
  • Section 4: Briefly present some use cases where we have applied, or can apply, this method in our company's projects.

You can follow the code of this tutorial in a Jupyter notebook session. Let's begin by installing and importing the required packages:

pip install transformers datasets scikit-learn tensorboard torch tables
import warnings

2. Dataset

2.1 The raw dataset

The raw dataset is stored in annotations.jsonlines. It is a JSON Lines file where each row looks like this (after pretty-printing it as a Python dictionary):

{'item': {'type': 'text',
'uuid': '114',
'data': {'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.'},
'metadata': {},
'tags': ['SUBJ: Le projet de construire les observatoires',
'VERB: naît',
'TIME: dans le milieu des années quatre-vingts.']},
'itemMetadata': {'createdAt': 1672105534798,
'updated': '2022-12-31T13:40:32.878Z',
'seenAt': '2022-12-31T13:40:21.681Z'},
'tags': [],
'comments': [],
'metadata': {},
'annotation': {'classifications': None,
'ner': {'NER': {'entities': [{'value': 'TIME',
'start_char': 47,
'end_char': 86,
'ent_id': 0},
{'value': 'SUBJ', 'start_char': 0, 'end_char': 41, 'ent_id': 1},
{'value': 'VERB', 'start_char': 42, 'end_char': 46, 'ent_id': 2}]},
'relations': [{'label': 'WHO_DO', 'src': 1, 'dest': 2},
{'label': 'DO_WHEN', 'src': 2, 'dest': 0}]}},
'historicAnnotations': []}

What interests us in this example are the following:

  • The plain text: item.data.text.
  • The labelled entities: annotation.ner.NER.entities. It lists all the spans (text segments) corresponding to an entity (like TIME, SUBJ, VERB, which stand for time expression, subject and action, respectively). We can also find the start and end positions of the entity, as well as the entity id (TIME: 0, SUBJ: 1, VERB: 2).
  • The labelled relations: annotation.ner.relations. It lists all pairs (source span, destination span) that have a relationship between them. In our example, there is one WHO_DO relationship, which points from a SUBJ (id=1) to a VERB (id=2), and one DO_WHEN relationship from a VERB (id=2) to a TIME (id=0).
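As a quick check of how these three fields fit together, the sketch below rebuilds a trimmed copy of the record above (only the fields we need) and prints each relation with the text of its two spans. It assumes that end_char is exclusive, as the offsets in the example suggest.

```python
record = {
    "item": {"data": {"text": "Le projet de construire les observatoires naît "
                              "dans le milieu des années quatre-vingts."}},
    "annotation": {
        "ner": {
            "NER": {"entities": [
                {"value": "TIME", "start_char": 47, "end_char": 86, "ent_id": 0},
                {"value": "SUBJ", "start_char": 0, "end_char": 41, "ent_id": 1},
                {"value": "VERB", "start_char": 42, "end_char": 46, "ent_id": 2},
            ]},
            "relations": [
                {"label": "WHO_DO", "src": 1, "dest": 2},
                {"label": "DO_WHEN", "src": 2, "dest": 0},
            ],
        }
    },
}

text = record["item"]["data"]["text"]
ner = record["annotation"]["ner"]

# Map each entity id to the text it covers (end_char treated as exclusive).
spans = {
    ent["ent_id"]: text[ent["start_char"]:ent["end_char"]]
    for ent in ner["NER"]["entities"]
}

# Resolve the src/dest ids of every relation to the corresponding text.
readable = [(rel["label"], spans[rel["src"]], spans[rel["dest"]])
            for rel in ner["relations"]]
for label, src, dest in readable:
    print(f"{label}: {src!r} -> {dest!r}")
```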

To load the dataset and split it into 80% train / 20% test parts:

from datasets import Dataset, DatasetDict

INPUT = "annotations.jsonlines"

raw_datasets = Dataset.from_json(INPUT).filter(
    # Keep only the records that contain at least one labelled relation
    lambda x: len((x.get("annotation", {}).get("ner", {}) or {}).get("relations", [])) > 0
).train_test_split(test_size=0.2, seed=42)

Now we have two datasets: raw_datasets["train"] and raw_datasets["test"]. Let's look at an example from raw_datasets["train"]; it has the same format as described above.

# Output
{'item': {'type': 'text',
'uuid': '114',
'data': {'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.'},
'metadata': {'time_expressions': 'ABS DIR - MID DC1980'},
'tags': ['SUBJ: Le projet de construire les observatoires',
'VERB: naît',
'TIME: dans le milieu des années quatre-vingts.']},
'itemMetadata': {'createdAt': 1672105534798,
'updated': '2022-12-31T13:40:32.878Z',
'seenAt': '2022-12-31T13:40:21.681Z'},
'tags': ['SUBJ: Le projet de construire les observatoires',
'VERB: naît',
'TIME: dans le milieu des années quatre-vingts.'],
'comments': [],
'metadata': {'time_expressions': 'ABS DIR - MID DC1980'},
'annotationMetadata': {'annotatedBy': '',
'annotatedAt': '2022-12-27T13:03:41.819Z',
'createdAt': '2022-12-27T13:03:41.814Z'},
'annotation': {'classifications': None,
'ner': {'NER': {'entities': [{'value': 'TIME',
'start_char': 47,
'end_char': 86,
'ent_id': 0},
{'value': 'SUBJ', 'start_char': 0, 'end_char': 41, 'ent_id': 1},
{'value': 'VERB', 'start_char': 42, 'end_char': 46, 'ent_id': 2}]},
'relations': [{'label': 'WHO_DO', 'src': 1, 'dest': 2},
{'label': 'DO_WHEN', 'src': 2, 'dest': 0}]}},
'historicAnnotations': []}

2.2 Labels: Entities and Relations

In this dataset, the entities that interest us are:

  • SUBJ: The subject
  • VERB: The verb group if the sentence talks about an action; or the "to be" verb + adjective if the sentence is descriptive.
  • OBJ: Objects or complements (anything that is not a subject, verb group, time or place) in the sentence.
  • TIME: Time expression
  • PLACE: Places

As this is a toy dataset, please do not pay too much attention to the linguistic rigour of the annotation (e.g. why an adjective is annotated as "VERB").

The relationships that interest us are:

  • WHO_DO: The SUBJ performs a VERB
  • DO_WHAT: The VERB acts on an OBJ, or acts in an OBJ way.
  • DO_WHEN: The VERB happens at TIME.
  • DO_WHERE: The VERB happens at PLACE.
  • IN_COMP_WITH_WHEN: The OBJ or the SUBJ is being compared to TIME.
  • TIME_OF_WHAT: The TIME is the time of OBJ (not the time of VERB)
  • SAME_AS: The SUBJ or the OBJ refers to another SUBJ or OBJ.

Later, we will introduce a relationship:

  • NO_REL: There is no relationship from one span to another span. This acts as a negative class (i.e. no relationship at all) for the relation classification problem.

The meaning of those relationships is straightforward, but it is worth distinguishing DO_WHEN from TIME_OF_WHAT. For example, in the sentence "I read the document of 2002 today.":

  • "today" is the time of "read": there should be a DO_WHEN relationship from "read" to "today".
  • "2002" is the time of "the document": there should be a TIME_OF_WHAT from "2002" to "the document".
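The two relations in this sentence can be written down with their character offsets, as in the sketch below. The `span_of` helper is a hypothetical convenience for this example (it finds the first occurrence of a phrase); in the real dataset the offsets come from the annotation tool.

```python
sentence = "I read the document of 2002 today."


def span_of(phrase: str) -> tuple:
    """Return the (start, end) character offsets of the first occurrence of `phrase`."""
    start = sentence.index(phrase)
    return (start, start + len(phrase))


# (label, source span, destination span): the direction matters.
relations = [
    ("DO_WHEN", span_of("read"), span_of("today")),
    ("TIME_OF_WHAT", span_of("2002"), span_of("the document")),
]
for label, src, dest in relations:
    print(label, ":", sentence[src[0]:src[1]], "->", sentence[dest[0]:dest[1]])
```

Note that the time expression is the destination of DO_WHEN but the source of TIME_OF_WHAT.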

Let's put these labels into code so we can use them later.

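id2label = {
    0: "NO_REL",
    1: "SAME_AS",
    2: "WHO_DO",
    3: "DO_WHAT",
    4: "DO_WHEN",
    5: "DO_WHERE",
    # Assumption: ids 6 and 7 are not visible in the listing; their order is
    # inferred from the label list above and the 8-class confusion matrix in section 3.3.
    6: "IN_COMP_WITH_WHEN",
    7: "TIME_OF_WHAT",
}
label2id = {v: k for k, v in id2label.items()}
# Output
{'NO_REL': 0,
 'SAME_AS': 1,
 'WHO_DO': 2,
 'DO_WHAT': 3,
 'DO_WHEN': 4,
 'DO_WHERE': 5,
 'IN_COMP_WITH_WHEN': 6,
 'TIME_OF_WHAT': 7}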

2.3 Transform the datasets for training

To use LUKE, we would like a much simpler format like this:

{'id': '<any-thing>',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'WHO_DO',
'tag_spans': [[0, 41], [42, 46]]}

where for each record, we specify

  • text: the text,
  • rel_tag: one relationship only between a pair of entities
  • and tag_spans: the positions [start, end] of the source and destination entities in order. It must be a list of 2 lists of 2 integers.

The field names (id, text, rel_tag, tag_spans) are yours to choose. You will declare them later in the training configuration.
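Because every later step relies on this format, it can help to check records before training. The function below is a minimal sketch of such a check; `validate_record` is a hypothetical helper, and the label set and sample records are made up for the example.

```python
def validate_record(record: dict, labels: set) -> list:
    """Return a list of problems found in one flattened record (empty if valid)."""
    problems = []
    text = record.get("text", "")
    if record.get("rel_tag") not in labels:
        problems.append(f"unknown rel_tag: {record.get('rel_tag')!r}")
    spans = record.get("tag_spans", [])
    if len(spans) != 2:
        problems.append("tag_spans must contain exactly 2 spans")
    for span in spans:
        if len(span) != 2 or not all(isinstance(i, int) for i in span):
            problems.append(f"span {span!r} must be 2 integers")
        elif not 0 <= span[0] < span[1] <= len(text):
            problems.append(f"span {span!r} is outside the text")
    return problems


labels = {"NO_REL", "WHO_DO", "DO_WHEN"}  # shortened label set for the example
good = {"id": "x-0", "text": "Le projet naît.", "rel_tag": "WHO_DO",
        "tag_spans": [[0, 9], [10, 14]]}
bad = {"id": "x-1", "text": "Le projet naît.", "rel_tag": "UNKNOWN",
       "tag_spans": [[0, 500]]}

print(validate_record(good, labels))  # no problems
print(validate_record(bad, labels))   # three problems
```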

The following function helps us transform the original format to this simple one.

def split_example(example):
    """Split an example into multiple examples of the final format."""
    uuid = example["item"]["uuid"]
    text = example["item"]["data"]["text"]

    # Construct a dictionary to map the span id to its content
    ent_dict = {}
    if (entities := (example["annotation"].get("ner", {}) or {}).get("NER", {}).get("entities", [])) != []:
        for ent in entities:
            ent_dict[ent["ent_id"]] = ent

    # Construct the list of relationships
    rel_list = []
    positive_rel_set = set()

    # Put the labelled relations into the list
    for rel in (example.get("annotation", {}).get("ner", {}) or {}).get("relations", []):
        rel_list.append({
            "src": (ent_dict[rel["src"]]["start_char"], ent_dict[rel["src"]]["end_char"]),
            "dest": (ent_dict[rel["dest"]]["start_char"], ent_dict[rel["dest"]]["end_char"]),
            "rel_tag": rel["label"],
        })
        positive_rel_set.add((rel["src"], rel["dest"]))

    # For any ordered pair of entities that has no relation, put "NO_REL" between them
    for src in ent_dict.keys():
        for dest in ent_dict.keys():
            if src != dest and (src, dest) not in positive_rel_set:
                rel_list.append({
                    "src": (ent_dict[src]["start_char"], ent_dict[src]["end_char"]),
                    "dest": (ent_dict[dest]["start_char"], ent_dict[dest]["end_char"]),
                    "rel_tag": "NO_REL",
                })

    # Return all the relationships, including the "NO_REL" ones
    outs = [{
        "id": f"{uuid}-{i}",  # A convention for generating ids: concatenate the original id with the index
        "text": text,
        "rel_tag": rel["rel_tag"],
        "tag_spans": [rel["src"], rel["dest"]],
    } for i, rel in enumerate(rel_list)]

    return outs

Let's test split_example with an example:

split_example({'item': {'type': 'text',
'uuid': '114',
'data': {'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.'},
'metadata': {},
'tags': ['SUBJ: Le projet de construire les observatoires',
'VERB: naît',
'TIME: dans le milieu des années quatre-vingts.']},
'itemMetadata': {'createdAt': 1672105534798,
'updated': '2022-12-31T13:40:32.878Z',
'seenAt': '2022-12-31T13:40:21.681Z'},
'tags': [],
'comments': [],
'metadata': {},
'annotation': {'classifications': None,
'ner': {'NER': {'entities': [{'value': 'TIME',
'start_char': 47,
'end_char': 86,
'ent_id': 0},
{'value': 'SUBJ', 'start_char': 0, 'end_char': 41, 'ent_id': 1},
{'value': 'VERB', 'start_char': 42, 'end_char': 46, 'ent_id': 2}]},
'relations': [{'label': 'WHO_DO', 'src': 1, 'dest': 2},
{'label': 'DO_WHEN', 'src': 2, 'dest': 0}]}},
'historicAnnotations': []})
# Output
[{'id': '114-0',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'WHO_DO',
'tag_spans': [(0, 41), (42, 46)]},
{'id': '114-1',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'DO_WHEN',
'tag_spans': [(42, 46), (47, 86)]},
{'id': '114-2',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'NO_REL',
'tag_spans': [(47, 86), (0, 41)]},
{'id': '114-3',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'NO_REL',
'tag_spans': [(47, 86), (42, 46)]},
{'id': '114-4',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'NO_REL',
'tag_spans': [(0, 41), (47, 86)]},
{'id': '114-5',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'NO_REL',
'tag_spans': [(42, 46), (0, 41)]}]

We can see that from one original text, now 6 records are returned with the expected format (having id, text, rel_tag and tag_spans). Pairs of entities that are not labelled are also listed with NO_REL relationships.
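The count of 6 is not a coincidence: a text with n entities yields n × (n − 1) ordered pairs, and every pair without an annotation becomes a NO_REL record. The short sketch below reproduces the count for the example above using `itertools.permutations`; the entity ids and the two positive pairs are taken from that example.

```python
from itertools import permutations

entity_ids = [0, 1, 2]  # TIME, SUBJ, VERB in the example above
positive = {(1, 2): "WHO_DO", (2, 0): "DO_WHEN"}

# Every ordered (src, dest) pair with src != dest becomes one record.
pairs = list(permutations(entity_ids, 2))
labels = [positive.get(pair, "NO_REL") for pair in pairs]

print(len(pairs))              # 6 ordered pairs for 3 entities
print(labels.count("NO_REL"))  # 4 of them are negatives
```

Since the number of pairs grows quadratically with the number of entities while the number of annotated relations usually does not, NO_REL tends to dominate the resulting dataset.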

Let's apply the transformation for the whole training and test datasets.

from datasets import Dataset, DatasetDict

# Init the train/test splits as dictionaries of columns
flatten_example_dict = {
    "train": {"id": [], "text": [], "rel_tag": [], "tag_spans": []},
    "test": {"id": [], "text": [], "rel_tag": [], "tag_spans": []},
}

# Then apply the function to each record of the original datasets and append the results to the dict
for split in ["train", "test"]:
    for example in raw_datasets[split]:
        flatten_examples = split_example(example)
        for f_example in flatten_examples:
            for key, value in f_example.items():
                flatten_example_dict[split][key].append(value)

# Transform the dict into datasets again
ds = DatasetDict({
    split: Dataset.from_dict(flatten_example_dict[split])
    for split in ["train", "test"]
})

Now we can look at an example to ensure everything is going OK.

# Output
{'id': '114-0',
'text': 'Le projet de construire les observatoires naît dans le milieu des années quatre-vingts.',
'rel_tag': 'WHO_DO',
'tag_spans': [[0, 41], [42, 46]]}

2.4 Tokenization

Let's load the Luke tokenizer.

from transformers import LukeTokenizer
import torch
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")

The preprocess_function below transforms each record into the numeric form the trainer expects: a pair of tensors (X, y), where X is the input and y is the target.

  • X (the input_ids, entity_ids, entity_position_ids, attention_mask and entity_attention_mask fields of the transformed record) is obtained with the Luke tokenizer. The tokenizer expects entity_spans as a list of 2 tuples of 2 integers, so we convert tag_spans accordingly. We specify padding="max_length" and truncation=True so every tensor in the batch has the same shape of 256 (the default length) tokens. return_tensors="pt" means to return PyTorch tensors.
  • y (the label field) is simply rel_tag converted to its label id and wrapped in a torch tensor.

def preprocess_function(examples):
    """Tokenize the examples"""
    model_inputs = tokenizer(
        examples["text"],
        entity_spans=[[tuple(item) for item in ex] for ex in examples["tag_spans"]],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    model_inputs["label"] = torch.tensor(
        [label2id[ex] for ex in examples["rel_tag"]]
    )
    return model_inputs

Applying to the datasets:

tokenized_datasets = ds.map(preprocess_function, batched=True, remove_columns=ds["train"].column_names)

We now verify that the tokenized datasets have the expected shape.

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'label'],
        num_rows: 6375
    })
    test: Dataset({
        features: ['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'label'],
        num_rows: 2256
    })
})

Dataset preprocessing is now complete.

3. Modelling — Finetune a relation classifier with LUKE

We will begin with the metrics (section 3.1), then configure the trainer (section 3.2), launch the finetuning (3.3) and predict new data (3.4).

3.1 Prepare the metrics

The model predicts the class of the relation between the two given entities. We track the accuracy, f1_micro and f1_macro scores during training. We count NO_REL as a class like any other (for your own case, you can exclude it from the scores). Since every example has exactly one label and every class is scored, the f1_micro and accuracy scores should be the same.

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    preds = logits.argmax(-1)
    # Rows are the true classes, columns are the predicted classes
    print(confusion_matrix(labels, preds))
    result = {
        "accuracy": accuracy_score(labels, preds),
        "f1_micro": f1_score(labels, preds, average="micro"),
        "f1_macro": f1_score(labels, preds, average="macro"),
    }
    return result
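To see why accuracy and micro-F1 coincide while macro-F1 can be much lower, here is a small pure-Python sketch that computes both averages by hand. It is not the scikit-learn implementation, and the labels and predictions are toy values, not taken from our dataset.

```python
from collections import Counter


def micro_f1(labels, preds):
    """Micro-averaged F1 over all classes for single-label predictions."""
    tp = sum(1 for y, p in zip(labels, preds) if y == p)
    # Each mistake counts once as a false positive (for the predicted class)
    # and once as a false negative (for the true class).
    fp = fn = len(labels) - tp
    return 2 * tp / (2 * tp + fp + fn)


def macro_f1(labels, preds):
    """Unweighted mean of the per-class F1 scores."""
    classes = sorted(set(labels) | set(preds))
    tp, fp, fn = Counter(), Counter(), Counter()
    for y, p in zip(labels, preds):
        if y == p:
            tp[y] += 1
        else:
            fp[p] += 1
            fn[y] += 1
    scores = []
    for c in classes:
        denom = 2 * tp[c] + fp[c] + fn[c]
        scores.append(2 * tp[c] / denom if denom else 0.0)
    return sum(scores) / len(scores)


labels = [0, 0, 0, 0, 1, 2]  # toy data: class 0 dominates, like NO_REL
preds = [0, 0, 0, 0, 0, 2]   # the single example of class 1 is missed
accuracy = sum(y == p for y, p in zip(labels, preds)) / len(labels)

print(accuracy, micro_f1(labels, preds), macro_f1(labels, preds))
```

Micro-F1 pools all decisions, so a dominant class such as NO_REL drives the score. Macro-F1 gives every class the same weight, so a rare class that is never predicted correctly pulls it down sharply.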

3.2 Training arguments

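from transformers import LukeForEntityPairClassification, TrainingArguments, Trainer


def model_init():
    # For reproducibility: the Trainer sets its seed before calling model_init
    return LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-base", label2id=label2id, id2label=id2label)

args = TrainingArguments(
    "../../models/tmp/sample-relation-clf",  # Replace it with where you store the model
    evaluation_strategy="steps",
    # Assumption: the rest of the argument list is not visible here. The values
    # below are read from the training log in section 3.3.
    num_train_epochs=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    fp16=True,  # the log reports the cuda_amp half precision backend
)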

3.3 Training

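from transformers import Trainer

trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,
)
trainer.train()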
# Output
loading configuration file config.json from cache at /home/
Model config LukeConfig {
"_name_or_path": "studio-ousia/luke-base",
"architectures": [

# The logs are long. This is its last line.
Using cuda_amp half precision backend
loading weights file pytorch_model.bin from cache at /home/
Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification:
***** Running training *****
Num examples = 6375
Num Epochs = 5
Instantaneous batch size per device = 32
Total train batch size (w. parallel, distributed & accumulation) = 32
Gradient Accumulation steps = 1
Total optimization steps = 1000
Number of trainable parameters = 274514432


***** Running Evaluation *****
Num examples = 2256
Batch size = 32

Step   Training Loss   Validation Loss   Accuracy   F1 Micro   F1 Macro
100    0.632700        0.442254          0.870189   0.870189   0.252814
200    0.331100        0.259747          0.929434   0.929434   0.398293
300    0.252900        0.199682          0.939623   0.939623   0.440257
400    0.152900        0.186350          0.943774   0.943774   0.510882
500    0.178700        0.141737          0.961509   0.961509   0.572071
600    0.157200        0.151913          0.958113   0.958113   0.527251
700    0.125600        0.163769          0.956981   0.956981   0.539761
800    0.108800        0.144700          0.963396   0.963396   0.599255
900    0.106900        0.156587          0.955472   0.955472   0.586468
1000   0.124500        0.135266          0.966792   0.966792   0.688834
1100   0.083100        0.138630          0.963396   0.963396   0.659829
1200   0.086700        0.141869          0.965660   0.965660   0.687968
1300   0.066100        0.144285          0.968302   0.968302   0.662152
1400   0.068300        0.147904          0.968302   0.968302   0.660367
1500   0.070700        0.142543          0.968679   0.968679   0.670545
1600   0.054100        0.151641          0.967170   0.967170   0.680851
1700   0.045200        0.151504          0.968302   0.968302   0.700430

The last confusion matrix printed in the logs:
[[2240    2    5    7   10    0    0   17]
 [   4    3    0    0    0    0    0    0]
 [   2    0  113    0    0    0    0    0]
 [   7    0    0   87    1    2    0    0]
 [  12    0    1    2   82    0    0    1]
 [   3    0    0    0    0    4    0    0]
 [   5    0    0    0    0    0    0    0]
 [   3    0    0    0    0    0    0   37]]

Saving model checkpoint to ../../models/tmp/sample-relation-clf/checkpoint-1000
Loading best model from ../../models/tmp/sample-relation-clf/checkpoint-1000 (score: 0.5488432954354046).
TrainOutput(global_step=1000, training_loss=0.18856227922439575, metrics={'train_runtime': 700.0075, 'train_samples_per_second': 45.535, 'train_steps_per_second': 1.429, 'total_flos': 1.048902746112e+16, 'train_loss': 0.18856227922439575, 'epoch': 5.0})

We observe the metrics (accuracy = f1 micro, f1 macro) improving over time. The confusion matrices are also printed at the end of each group of 200 steps; the logs above show only the last one. The model seems to perform well on the classes WHO_DO, DO_WHAT and DO_WHERE, a little worse on TIME_OF_WHAT, and very poorly on SAME_AS and IN_COMP_WITH_WHEN because these classes have too few samples.
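The per-class picture can be read directly from the last confusion matrix. The sketch below computes precision and recall for each class, with rows as true classes and columns as predicted classes (the scikit-learn convention). The class names for the last two rows are an assumption: the order IN_COMP_WITH_WHEN, TIME_OF_WHAT is inferred from the label list in section 2.2, not confirmed by the logs.

```python
confusion = [
    [2240, 2, 5, 7, 10, 0, 0, 17],
    [4, 3, 0, 0, 0, 0, 0, 0],
    [2, 0, 113, 0, 0, 0, 0, 0],
    [7, 0, 0, 87, 1, 2, 0, 0],
    [12, 0, 1, 2, 82, 0, 0, 1],
    [3, 0, 0, 0, 0, 4, 0, 0],
    [5, 0, 0, 0, 0, 0, 0, 0],
    [3, 0, 0, 0, 0, 0, 0, 37],
]
names = ["NO_REL", "SAME_AS", "WHO_DO", "DO_WHAT", "DO_WHEN", "DO_WHERE",
         "IN_COMP_WITH_WHEN", "TIME_OF_WHAT"]  # last two names are assumed

report = {}
for i, name in enumerate(names):
    support = sum(confusion[i])                   # examples whose true class is i
    predicted = sum(row[i] for row in confusion)  # examples predicted as class i
    correct = confusion[i][i]
    report[name] = {
        "support": support,
        "precision": correct / predicted if predicted else 0.0,
        "recall": correct / support if support else 0.0,
    }

for name, r in report.items():
    print(f"{name:<18} support={r['support']:>4} "
          f"precision={r['precision']:.2f} recall={r['recall']:.2f}")
```

The supports make the imbalance visible: the smallest classes have fewer than ten test examples each, so their scores are unstable and weigh heavily on the macro F1.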

3.4 Predict new texts

Let's try a new example and see if the model detects the relations well. First, we want to detect the relationship between "je" and "voudrais".

import torch.nn.functional as F

def predict(text, entity_spans, model=trainer.model, tokenizer=tokenizer, device="cuda"):
    """Predict on new example"""
    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt").to(device)
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return id2label[predicted_class_idx]

text = "Je voudrais une attestation de paiement pour 2022."
entity_spans = [(0, 2), (3, 11)] # "Je", "voudrais"
predict(text, entity_spans)
# Output
'WHO_DO' #Correct

Next, let's try "2022". We expect to see that it is the time of "une attestation de paiement", and not of the verb "voudrais".

entity_spans = [(45, 49), (3, 11)] # "2022", "voudrais"
predict(text, entity_spans)
# Output

entity_spans = [(3, 11), (45, 49)] # "voudrais", "2022"
predict(text, entity_spans)
# Output

entity_spans = [(45, 49), (12, 39)] # "2022", "une attestation de paiement"
predict(text, entity_spans)
# Output
'TIME_OF_WHAT' #It's the correct tag we expect

It works as expected, although the model was trained on only a few examples. It can be improved further, especially for minority classes like SAME_AS, DO_WHERE etc., by adding more examples to the training set.
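The predict function above returns only the top label. If a score per label is useful, the logits can be turned into probabilities with a softmax. The sketch below uses plain Python and made-up logits so that it runs without a trained model; `rank_labels` and the shortened label map are hypothetical. With the real model, the same idea would apply to outputs.logits, for example through torch.nn.functional.softmax.

```python
import math

# Shortened, hypothetical label map for the example
id2label = {0: "NO_REL", 1: "SAME_AS", 2: "WHO_DO"}


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(logits):
    """Return (label, probability) pairs, most probable first."""
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return [(id2label[i], p) for i, p in ranked]


ranked = rank_labels([0.1, -1.0, 3.0])  # made-up logits
for label, prob in ranked:
    print(f"{label}: {prob:.3f}")
```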

4. Some applications

The relation classification approach has been, is, or will be used in several of our projects. Some of the use cases are:

  1. To detect if a time expression mentioned in an email refers to the correct object. For example:
  • Scenario: Client A asks agent B to send a payment certificate for a period (say, January 2023). Client A mentions a lot of periods (time expressions) in his email. We want to detect which date should be the period associated with the payment certificate request.
  • Approach: We use a NER model to detect if a span referring to "payment certificate" appears in the email. Then, for each of the time expressions mentioned, we detect if there is a relationship between it and the "payment certificate" span.

  2. In sentiment analysis, to detect if a specific emotion is caused by an event/a situation/a person.
  • Approach: Similarly to the previous use case, we use a NER model to detect events/situations/people/organizations and emotions. (You can find more details of this approach for emotions in my colleague Al Houceine's article.) Then we use the relation classifier to detect if there is a relation "CAUSE_BY" between the emotion and the detected entities or if it is instead "NO_REL" (no relation).
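The first use case can be sketched as a loop over the detected time expressions. The code below is only an outline under stated assumptions: `times_linked_to` and `span_of` are hypothetical helpers, the sentence is made up, and `fake_classify` is a stand-in for the fine-tuned model so that the sketch runs on its own. In practice the classifier would be a function like predict from section 3.4, and the spans would come from the NER model.

```python
from typing import Callable, List, Tuple

Span = Tuple[int, int]


def times_linked_to(
    text: str,
    target: Span,
    time_spans: List[Span],
    classify: Callable[[str, List[Span]], str],
    wanted: str = "TIME_OF_WHAT",
) -> List[str]:
    """Return the time expressions that `classify` links to the target span."""
    linked = []
    for span in time_spans:
        # Direction matters: TIME_OF_WHAT goes from the time expression to the object.
        if classify(text, [span, target]) == wanted:
            linked.append(text[span[0]:span[1]])
    return linked


def span_of(text: str, phrase: str) -> Span:
    """Character offsets of the first occurrence of `phrase` (illustration only)."""
    start = text.index(phrase)
    return (start, start + len(phrase))


def fake_classify(text: str, entity_spans: List[Span]) -> str:
    """Stand-in for the fine-tuned model: hard-coded answer for this example."""
    (s0, s1), _ = entity_spans
    return "TIME_OF_WHAT" if text[s0:s1] == "February 2022" else "NO_REL"


text = ("I received the January 2022 invoice. "
        "Please send me a payment certificate for February 2022.")
target = span_of(text, "a payment certificate")
time_spans = [span_of(text, "January 2022"), span_of(text, "February 2022")]

result = times_linked_to(text, target, time_spans, fake_classify)
print(result)
```

Passing the classifier as an argument keeps the loop independent of the model, which makes it easy to swap the stand-in for the real predict function.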

The relation classification approach has many applications; the two above are only examples.

If you have reached this paragraph, which is the end of the article, we hope you enjoyed the tutorial and the approach using transformers + LUKE for relation classification. Don't hesitate to share with us other exciting approaches to the same problem in the comments.


Thanks to our colleagues Alexander DO, Lu WANG and Edouard LAVAUD for the article review.


Nhut DOAN NGUYEN is a lead data scientist and data engineer at La Javaness.


