SinaTools 0.1.11__py2.py3-none-any.whl → 0.1.13__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {SinaTools-0.1.11.dist-info → SinaTools-0.1.13.dist-info}/METADATA +2 -3
- {SinaTools-0.1.11.dist-info → SinaTools-0.1.13.dist-info}/RECORD +47 -26
- {SinaTools-0.1.11.dist-info → SinaTools-0.1.13.dist-info}/entry_points.txt +7 -3
- sinatools/CLI/DataDownload/download_files.py +0 -10
- sinatools/CLI/ner/corpus_entity_extractor.py +9 -6
- sinatools/CLI/ner/entity_extractor.py +18 -42
- sinatools/CLI/utils/arStrip.py +8 -8
- sinatools/CLI/utils/implication.py +0 -8
- sinatools/CLI/utils/jaccard.py +5 -14
- sinatools/CLI/utils/remove_latin.py +2 -2
- sinatools/CLI/utils/text_dublication_detector.py +25 -0
- sinatools/VERSION +1 -1
- sinatools/morphology/ALMA_multi_word.py +14 -16
- sinatools/morphology/__init__.py +32 -31
- sinatools/ner/__init__.py +28 -2
- sinatools/ner/data/__init__.py +1 -0
- sinatools/ner/data/datasets.py +146 -0
- sinatools/ner/data/transforms.py +118 -0
- sinatools/ner/data.py +124 -0
- sinatools/ner/data_format.py +124 -0
- sinatools/ner/datasets.py +146 -0
- sinatools/ner/entity_extractor.py +34 -54
- sinatools/ner/helpers.py +86 -0
- sinatools/ner/metrics.py +69 -0
- sinatools/ner/nn/BaseModel.py +22 -0
- sinatools/ner/nn/BertNestedTagger.py +34 -0
- sinatools/ner/nn/BertSeqTagger.py +17 -0
- sinatools/ner/nn/__init__.py +3 -0
- sinatools/ner/trainers/BaseTrainer.py +117 -0
- sinatools/ner/trainers/BertNestedTrainer.py +203 -0
- sinatools/ner/trainers/BertTrainer.py +163 -0
- sinatools/ner/trainers/__init__.py +3 -0
- sinatools/ner/transforms.py +119 -0
- sinatools/semantic_relatedness/__init__.py +20 -0
- sinatools/semantic_relatedness/compute_relatedness.py +31 -0
- sinatools/synonyms/__init__.py +18 -0
- sinatools/synonyms/synonyms_generator.py +192 -0
- sinatools/utils/text_dublication_detector.py +110 -0
- sinatools/wsd/__init__.py +11 -0
- sinatools/{salma/views.py → wsd/disambiguator.py} +135 -94
- sinatools/{salma → wsd}/wsd.py +1 -1
- sinatools/CLI/salma/salma_tools.py +0 -68
- sinatools/salma/__init__.py +0 -12
- sinatools/utils/utils.py +0 -2
- {SinaTools-0.1.11.data → SinaTools-0.1.13.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.11.dist-info → SinaTools-0.1.13.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.11.dist-info → SinaTools-0.1.13.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.11.dist-info → SinaTools-0.1.13.dist-info}/WHEEL +0 -0
- {SinaTools-0.1.11.dist-info → SinaTools-0.1.13.dist-info}/top_level.txt +0 -0
- /sinatools/{salma → wsd}/settings.py +0 -0
sinatools/ner/entity_extractor.py
@@ -1,51 +1,51 @@
import os
from collections import namedtuple
-from sinatools.ner.
-from
-from sinatools.DataDownload import downloader
-from . import tag_vocab, train_config, tagger
+from sinatools.ner.data_format import get_dataloaders, text2segments
+from . import tagger, tag_vocab, train_config

-def
+def extract(text, batch_size=32):
    """
-    This method
+    This method processes an input text and returns named entites for each token within the text, based on the specified batch size. As follows:

    Args:
-        text (str): The
+        text (:obj:`str`): The Arabic text to be tagged.
        batch_size (int, optional): Batch size for inference. Default is 32.

    Returns:
-        list: A list of
-
+        list (:obj:`list`): A list of JSON objects, where each JSON could be contains:
+        token: The token from the original text.
+        NER tag: The label pairs for each segment.
+
    **Example:**

    .. highlight:: python
    .. code-block:: python

-
-
-
-
-
-
-
-
-
+        from sinatools.ner.entity_extractor import extract
+        extract('ذهب محمد إلى جامعة بيرزيت')
+        [{
+            "word":"ذهب",
+            "tags":"O"
+        },{
+            "word":"محمد",
+            "tags":"B-PERS"
+        },{
+            "word":"إلى",
+            "tags":"O"
+        },{
+            "word":"جامعة",
+            "tags":"B-ORG"
+        },{
+            "word":"بيرزيت",
+            "tags":"B-GPE I-ORG"
+        }]
    """
-    # Load tagger
-    # filename = 'Wj27012000.tar'
-    # path =downloader.get_appdatadir()
-    # model_path = os.path.join(path, filename)
-    # print('1',model_path)
-    # tagger, tag_vocab, train_config = load_checkpoint(model_path)

-
-    # Convert text to a tagger dataset and index the tokens in args.text
    dataset, token_vocab = text2segments(text)

    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=token_vocab, tags=tag_vocab)

-    # From the datasets generate the dataloaders
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
@@ -54,39 +54,19 @@ def ner(text, batch_size=32):
        shuffle=(False,),
    )[0]

-
+
    segments = tagger.infer(dataloader)
    segments_lists = []
-
+
    for segment in segments:
        for token in segment:
-            segments_list =
-            segments_list
-            #print([t['tag'] for t in token.pred_tag])
+            segments_list = {}
+            segments_list["word"] = token.text
            list_of_tags = [t['tag'] for t in token.pred_tag]
            list_of_tags = [i for i in list_of_tags if i not in('O',' ','')]
-            #print(list_of_tags)
            if list_of_tags == []:
-                segments_list
+                segments_list["tag"] = ' '.join(['O'])
            else:
-                segments_list
-                segments_lists.append(segments_list)
+                segments_list["tag"] = ' '.join(list_of_tags)
+            segments_lists.append(segments_list)
    return segments_lists
-
-    #Print results
-    # for segment in segments:
-    #     s = [
-    #         (token.text, token.pred_tag[0]['tag'])
-    #         for token in segment
-    #         if token.pred_tag[0]['tag'] != 'O'
-    #     ]
-    #     print(", ".join([f"({token}, {tag})" for token, tag in s]))
-
-def extract_tags(text):
-    tags = []
-    tokens = text.split()
-    for token in tokens:
-        tag = token.split("(")[-1].split(")")[0]
-        if tag != "O":
-            tags.append(tag)
-    return " ".join(tags)
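The function's docstring already carries an input/output example; as a consumer-side illustration, here is a minimal sketch (not part of the package) that keeps only the tokens labelled with something other than O. Note that the function body stores the label under the key "tag", while the docstring example shows "tags"; the sketch follows the code.

    from sinatools.ner.entity_extractor import extract

    # Minimal usage sketch: run the tagger and keep only the labelled tokens.
    result = extract('ذهب محمد إلى جامعة بيرزيت', batch_size=32)
    entities = [item for item in result if item["tag"] != "O"]
    for item in entities:
        print(item["word"], "->", item["tag"])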
sinatools/ner/helpers.py ADDED
@@ -0,0 +1,86 @@
+import os
+import sys
+import logging
+import importlib
+import shutil
+import torch
+import random
+import numpy as np
+
+
+def logging_config(log_file=None):
+    """
+    Initialize custom logger
+    :param log_file: str - path to log file, full path
+    :return: None
+    """
+    handlers = [logging.StreamHandler(sys.stdout)]
+
+    if log_file:
+        handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
+        print("Logging to {}".format(log_file))
+
+    logging.basicConfig(
+        level=logging.INFO,
+        handlers=handlers,
+        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
+        datefmt="%a, %d %b %Y %H:%M:%S",
+        force=True
+    )
+
+
+def load_object(name, kwargs):
+
+    try:
+        object_module, object_name = name.rsplit(".", 1)
+        object_module = importlib.import_module(object_module)
+        obj = getattr(object_module, object_name)
+        if callable(obj):
+            fn = obj(**kwargs)
+            return fn
+        else:
+            raise TypeError(f"{name} is not a callable object.")
+    except (ImportError, ModuleNotFoundError) as e:
+        print(f"Error importing module: {e}")
+    except AttributeError as e:
+        print(f"Attribute error: {e}")
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+    return None
+
+def make_output_dirs(path, subdirs=[], overwrite=True):
+    """
+    Create root directory and any other sub-directories
+    :param path: str - root directory
+    :param subdirs: List[str] - list of sub-directories
+    :param overwrite: boolean - to overwrite the directory or not
+    :return: None
+    """
+    if overwrite:
+        shutil.rmtree(path, ignore_errors=True)
+
+    os.makedirs(path)
+
+    for subdir in subdirs:
+        os.makedirs(os.path.join(path, subdir))
+
+
+
+def set_seed(seed):
+    """
+    Set the seed for random intialization and set
+    CUDANN parameters to ensure determmihstic results across
+    multiple runs with the same seed
+
+    :param seed: int
+    """
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.enabled = False
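helpers.py is training plumbing used by the trainers below. A minimal usage sketch, assuming SinaTools and PyTorch are installed; the log file name, output directory and the torch.nn.CrossEntropyLoss target handed to load_object are illustrative choices, not values taken from the package:

    from sinatools.ner.helpers import logging_config, load_object, make_output_dirs, set_seed

    logging_config("train.log")            # log to stdout and to train.log
    set_seed(42)                           # deterministic numpy/random/torch behaviour
    make_output_dirs("./ner_output", subdirs=["checkpoints"], overwrite=True)

    # load_object resolves a dotted name and instantiates it with the given kwargs
    loss_fn = load_object("torch.nn.CrossEntropyLoss", {"ignore_index": -100})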
sinatools/ner/metrics.py ADDED
@@ -0,0 +1,69 @@
+from seqeval.metrics import (
+    classification_report,
+    precision_score,
+    recall_score,
+    f1_score,
+    accuracy_score,
+)
+from seqeval.scheme import IOB2
+from types import SimpleNamespace
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def compute_nested_metrics(segments, vocabs):
+    """
+    Compute metrics for nested NER
+    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments
+    :return: metrics - SimpleNamespace - F1/micro/macro/weights, recall, precision, accuracy
+    """
+    y, y_hat = list(), list()
+
+    # We duplicate the dataset N times, where N is the number of entity types
+    # For each copy, we create y and y_hat
+    # Example: first copy, will create pairs of ground truth and predicted labels for entity type GPE
+    # another copy will create pairs for LOC, etc.
+    for i, vocab in enumerate(vocabs):
+        vocab_tags = [tag for tag in vocab.get_itos() if "-" in tag]
+        r = re.compile("|".join(vocab_tags))
+
+        y += [[(list(filter(r.match, token.gold_tag)) or ["O"])[0] for token in segment] for segment in segments]
+        y_hat += [[token.pred_tag[i]["tag"] for token in segment] for segment in segments]
+
+    logging.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))
+
+    metrics = {
+        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
+        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
+        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
+        "precision": precision_score(y, y_hat, scheme=IOB2),
+        "recall": recall_score(y, y_hat, scheme=IOB2),
+        "accuracy": accuracy_score(y, y_hat),
+    }
+
+    return SimpleNamespace(**metrics)
+
+
+def compute_single_label_metrics(segments):
+    """
+    Compute metrics for flat NER
+    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments
+    :return: metrics - SimpleNamespace - F1/micro/macro/weights, recall, precision, accuracy
+    """
+    y = [[token.gold_tag[0] for token in segment] for segment in segments]
+    y_hat = [[token.pred_tag[0]["tag"] for token in segment] for segment in segments]
+
+    logging.info("\n" + classification_report(y, y_hat, scheme=IOB2))
+
+    metrics = {
+        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
+        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
+        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
+        "precision": precision_score(y, y_hat, scheme=IOB2),
+        "recall": recall_score(y, y_hat, scheme=IOB2),
+        "accuracy": accuracy_score(y, y_hat),
+    }
+
+    return SimpleNamespace(**metrics)
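Both metric helpers expect segments of token objects exposing gold_tag and pred_tag. A minimal sketch for the flat case, using SimpleNamespace stand-ins in place of the package's Token class (an assumption made purely for illustration; seqeval must be installed):

    from types import SimpleNamespace
    from sinatools.ner.metrics import compute_single_label_metrics

    def tok(gold, pred):
        # stand-in for the Token class: gold_tag is a list, pred_tag a list of dicts
        return SimpleNamespace(gold_tag=[gold], pred_tag=[{"tag": pred}])

    segments = [[tok("B-PERS", "B-PERS"), tok("I-PERS", "O"), tok("O", "O")]]
    metrics = compute_single_label_metrics(segments)
    print(metrics.micro_f1, metrics.precision, metrics.recall)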
sinatools/ner/nn/BaseModel.py ADDED
@@ -0,0 +1,22 @@
+from torch import nn
+from transformers import BertModel
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BaseModel(nn.Module):
+    def __init__(self,
+                 bert_model="aubmindlab/bert-base-arabertv2",
+                 num_labels=2,
+                 dropout=0.1,
+                 num_types=0):
+        super().__init__()
+
+        self.bert_model = bert_model
+        self.num_labels = num_labels
+        self.num_types = num_types
+        self.dropout = dropout
+
+        self.bert = BertModel.from_pretrained(bert_model)
+        self.dropout = nn.Dropout(dropout)
sinatools/ner/nn/BertNestedTagger.py ADDED
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from sinatools.ner.nn import BaseModel
+
+
+class BertNestedTagger(BaseModel):
+    def __init__(self, **kwargs):
+        super(BertNestedTagger, self).__init__(**kwargs)
+
+        self.max_num_labels = max(self.num_labels)
+        classifiers = [nn.Linear(768, num_labels) for num_labels in self.num_labels]
+        self.classifiers = torch.nn.Sequential(*classifiers)
+
+    def forward(self, x):
+        y = self.bert(x)
+        y = self.dropout(y["last_hidden_state"])
+        output = list()
+
+        for i, classifier in enumerate(self.classifiers):
+            logits = classifier(y)
+
+            # Pad logits to allow Multi-GPU/DataParallel training to work
+            # We will truncate the padded dimensions when we compute the loss in the trainer
+            logits = torch.nn.ConstantPad1d((0, self.max_num_labels - logits.shape[-1]), 0)(logits)
+            output.append(logits)
+
+        # Return tensor of the shape B x T x L x C
+        # B: batch size
+        # T: sequence length
+        # L: number of tag types
+        # C: number of classes per tag type
+        output = torch.stack(output).permute((1, 2, 0, 3))
+        return output
+
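The padding step above is the non-obvious part: each per-type classifier emits a different number of classes, so its logits are right-padded with zeros to a common width before stacking. A shape-only sketch with random tensors (all sizes here are illustrative assumptions):

    import torch

    max_num_labels = 7                      # largest class count across tag types (assumed)
    hidden = torch.randn(2, 16, 768)        # B x T x hidden, standing in for BERT output
    classifier = torch.nn.Linear(768, 4)    # a tag type with only 4 classes
    logits = classifier(hidden)             # shape (2, 16, 4)
    logits = torch.nn.ConstantPad1d((0, max_num_labels - logits.shape[-1]), 0)(logits)
    print(logits.shape)                     # torch.Size([2, 16, 7])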
sinatools/ner/nn/BertSeqTagger.py ADDED
@@ -0,0 +1,17 @@
+import torch.nn as nn
+from transformers import BertModel
+
+
+class BertSeqTagger(nn.Module):
+    def __init__(self, bert_model, num_labels=2, dropout=0.1):
+        super().__init__()
+
+        self.bert = BertModel.from_pretrained(bert_model)
+        self.dropout = nn.Dropout(dropout)
+        self.linear = nn.Linear(768, num_labels)
+
+    def forward(self, x):
+        y = self.bert(x)
+        y = self.dropout(y["last_hidden_state"])
+        logits = self.linear(y)
+        return logits
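A minimal inference sketch for the flat tagger; the pretrained checkpoint name, the label count and the random subword IDs are assumptions standing in for real tokenizer output:

    import torch
    from sinatools.ner.nn.BertSeqTagger import BertSeqTagger

    model = BertSeqTagger("aubmindlab/bert-base-arabertv2", num_labels=9)
    model.eval()
    subword_ids = torch.randint(0, 1000, (1, 16))   # one segment, 16 fake subword IDs
    with torch.no_grad():
        logits = model(subword_ids)                 # shape (1, 16, 9)
    pred_ids = logits.argmax(dim=-1)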
sinatools/ner/trainers/BaseTrainer.py ADDED
@@ -0,0 +1,117 @@
+import os
+import torch
+import logging
+import natsort
+import glob
+
+logger = logging.getLogger(__name__)
+
+
+class BaseTrainer:
+    def __init__(
+        self,
+        model=None,
+        max_epochs=50,
+        optimizer=None,
+        scheduler=None,
+        loss=None,
+        train_dataloader=None,
+        val_dataloader=None,
+        test_dataloader=None,
+        log_interval=10,
+        summary_writer=None,
+        output_path=None,
+        clip=5,
+        patience=5
+    ):
+        self.model = model
+        self.max_epochs = max_epochs
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
+        self.test_dataloader = test_dataloader
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.loss = loss
+        self.log_interval = log_interval
+        self.summary_writer = summary_writer
+        self.output_path = output_path
+        self.current_timestep = 0
+        self.current_epoch = 0
+        self.clip = clip
+        self.patience = patience
+
+    def tag(self, dataloader, is_train=True):
+        """
+        Given a dataloader containing segments, predict the tags
+        :param dataloader: torch.utils.data.DataLoader
+        :param is_train: boolean - True for training model, False for evaluation
+        :return: Iterator
+                 subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
+                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
+                 tokens - List[arabiner.data.dataset.Token] - list of tokens
+                 valid_len (B x 1) - int - valiud length of each sequence
+                 logits (B x T x NUM_LABELS) - logits for each token and each tag
+        """
+        for subwords, gold_tags, tokens, valid_len in dataloader:
+            self.model.train(is_train)
+
+            if torch.cuda.is_available():
+                subwords = subwords.cuda()
+                gold_tags = gold_tags.cuda()
+
+            if is_train:
+                self.optimizer.zero_grad()
+                logits = self.model(subwords)
+            else:
+                with torch.no_grad():
+                    logits = self.model(subwords)
+
+            yield subwords, gold_tags, tokens, valid_len, logits
+
+    def segments_to_file(self, segments, filename):
+        """
+        Write segments to file
+        :param segments: [List[arabiner.data.dataset.Token]] - list of list of tokens
+        :param filename: str - output filename
+        :return: None
+        """
+        with open(filename, "w") as fh:
+            results = "\n\n".join(["\n".join([t.__str__() for t in segment]) for segment in segments])
+            fh.write("Token\tGold Tag\tPredicted Tag\n")
+            fh.write(results)
+        logging.info("Predictions written to %s", filename)
+
+    def save(self):
+        """
+        Save model checkpoint
+        :return:
+        """
+        filename = os.path.join(
+            self.output_path,
+            "checkpoints",
+            "checkpoint_{}.pt".format(self.current_epoch),
+        )
+
+        checkpoint = {
+            "model": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "epoch": self.current_epoch
+        }
+
+        logger.info("Saving checkpoint to %s", filename)
+        torch.save(checkpoint, filename)
+
+    def load(self, checkpoint_path):
+        """
+        Load model checkpoint
+        :param checkpoint_path: str - path/to/checkpoints
+        :return: None
+        """
+        checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
+        checkpoint_path = checkpoint_path[-1]
+
+        logger.info("Loading checkpoint %s", checkpoint_path)
+
+        device = None if torch.cuda.is_available() else torch.device('cpu')
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        self.model.load_state_dict(checkpoint["model"])
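BaseTrainer.load does not take a specific file: it globs the checkpoint directory and natural-sorts the names so that checkpoint_10.pt ranks after checkpoint_9.pt. A standalone sketch of that selection step (the directory name is an assumption):

    import glob
    import natsort

    paths = natsort.natsorted(glob.glob("ner_output/checkpoints/checkpoint_*.pt"))
    latest = paths[-1] if paths else None    # e.g. checkpoint_12.pt rather than checkpoint_9.pt
    print(latest)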
sinatools/ner/trainers/BertNestedTrainer.py ADDED
@@ -0,0 +1,203 @@
+import os
+import logging
+import torch
+import numpy as np
+from sinatools.ner.trainers import BaseTrainer
+from sinatools.ner.metrics import compute_nested_metrics
+
+logger = logging.getLogger(__name__)
+
+
+class BertNestedTrainer(BaseTrainer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def train(self):
+        best_val_loss, test_loss = np.inf, np.inf
+        num_train_batch = len(self.train_dataloader)
+        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
+        patience = self.patience
+
+        for epoch_index in range(self.max_epochs):
+            self.current_epoch = epoch_index
+            train_loss = 0
+
+            for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
+                self.train_dataloader, is_train=True
+            ), 1):
+                self.current_timestep += 1
+
+                # Compute loses for each output
+                # logits = B x T x L x C
+                losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                    torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                          for i, l in enumerate(num_labels)]
+
+                torch.autograd.backward(losses)
+
+                # Avoid exploding gradient by doing gradient clipping
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+                self.optimizer.step()
+                self.scheduler.step()
+                batch_loss = sum(l.item() for l in losses)
+                train_loss += batch_loss
+
+                if self.current_timestep % self.log_interval == 0:
+                    logger.info(
+                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
+                        epoch_index,
+                        batch_index,
+                        num_train_batch,
+                        self.current_timestep,
+                        self.optimizer.param_groups[0]['lr'],
+                        batch_loss
+                    )
+
+            train_loss /= num_train_batch
+
+            logger.info("** Evaluating on validation dataset **")
+            val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
+            val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
+
+            epoch_summary_loss = {
+                "train_loss": train_loss,
+                "val_loss": val_loss
+            }
+            epoch_summary_metrics = {
+                "val_micro_f1": val_metrics.micro_f1,
+                "val_precision": val_metrics.precision,
+                "val_recall": val_metrics.recall
+            }
+
+            logger.info(
+                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
+                epoch_index,
+                self.current_timestep,
+                train_loss,
+                val_loss,
+                val_metrics.micro_f1
+            )
+
+            if val_loss < best_val_loss:
+                patience = self.patience
+                best_val_loss = val_loss
+                logger.info("** Validation improved, evaluating test data **")
+                test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
+                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
+                test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
+
+                epoch_summary_loss["test_loss"] = test_loss
+                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
+                epoch_summary_metrics["test_precision"] = test_metrics.precision
+                epoch_summary_metrics["test_recall"] = test_metrics.recall
+
+                logger.info(
+                    f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
+                    epoch_index,
+                    self.current_timestep,
+                    test_loss,
+                    test_metrics.micro_f1
+                )
+
+                self.save()
+            else:
+                patience -= 1
+
+            # No improvements, terminating early
+            if patience == 0:
+                logger.info("Early termination triggered")
+                break
+
+            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
+            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
+
+    def tag(self, dataloader, is_train=True):
+        """
+        Given a dataloader containing segments, predict the tags
+        :param dataloader: torch.utils.data.DataLoader
+        :param is_train: boolean - True for training model, False for evaluation
+        :return: Iterator
+                 subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
+                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
+                 tokens - List[arabiner.data.dataset.Token] - list of tokens
+                 valid_len (B x 1) - int - valiud length of each sequence
+                 logits (B x T x NUM_LABELS) - logits for each token and each tag
+        """
+        for subwords, gold_tags, tokens, mask, valid_len in dataloader:
+            self.model.train(is_train)
+
+            if torch.cuda.is_available():
+                subwords = subwords.cuda()
+                gold_tags = gold_tags.cuda()
+
+            if is_train:
+                self.optimizer.zero_grad()
+                logits = self.model(subwords)
+            else:
+                with torch.no_grad():
+                    logits = self.model(subwords)
+
+            yield subwords, gold_tags, tokens, valid_len, logits
+
+    def eval(self, dataloader):
+        golds, preds, segments, valid_lens = list(), list(), list(), list()
+        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
+        loss = 0
+
+        for _, gold_tags, tokens, valid_len, logits in self.tag(
+            dataloader, is_train=False
+        ):
+            losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                      for i, l in enumerate(num_labels)]
+            loss += sum(losses)
+            preds += torch.argmax(logits, dim=3)
+            segments += tokens
+            valid_lens += list(valid_len)
+
+        loss /= len(dataloader)
+
+        # Update segments, attach predicted tags to each token
+        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+
+        return preds, segments, valid_lens, loss
+
+    def infer(self, dataloader):
+        golds, preds, segments, valid_lens = list(), list(), list(), list()
+
+        for _, gold_tags, tokens, valid_len, logits in self.tag(
+            dataloader, is_train=False
+        ):
+            preds += torch.argmax(logits, dim=3)
+            segments += tokens
+            valid_lens += list(valid_len)
+
+        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+        return segments
+
+    def to_segments(self, segments, preds, valid_lens, vocab):
+        if vocab is None:
+            vocab = self.vocab
+
+        tagged_segments = list()
+        tokens_stoi = vocab.tokens.get_stoi()
+        unk_id = tokens_stoi["UNK"]
+
+        for segment, pred, valid_len in zip(segments, preds, valid_lens):
+            # First, the token at 0th index [CLS] and token at nth index [SEP]
+            # Combine the tokens with their corresponding predictions
+            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
+
+            # Ignore the sub-tokens/subwords, which are identified with text being UNK
+            segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
+
+            # Attach the predicted tags to each token
+            list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
+                                                          for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
+
+            # We are only interested in the tagged tokens, we do no longer need raw model predictions
+            tagged_segment = [t for t, _ in segment_pred]
+            tagged_segments.append(tagged_segment)
+
+        return tagged_segments
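The loss inside train() and eval() slices the padded class dimension back down per tag type before calling the criterion. A self-contained sketch of that indexing with random tensors (batch size, sequence length and label counts are illustrative; reshape is used here where the trainer uses view):

    import torch

    B, T, L, C = 2, 8, 3, 5                 # batch, sequence length, tag types, padded classes
    num_labels = [5, 4, 3]                  # real class count per tag type
    logits = torch.randn(B, T, L, C)        # what BertNestedTagger returns
    gold_tags = torch.randint(0, 3, (B, L, T))
    criterion = torch.nn.CrossEntropyLoss()

    losses = [criterion(logits[:, :, i, 0:l].reshape(-1, l),
                        torch.reshape(gold_tags[:, i, :], (-1,)).long())
              for i, l in enumerate(num_labels)]
    print(sum(x.item() for x in losses))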