SinaTools 0.1.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SinaTools-0.1.1.data/data/nlptools/environment.yml +227 -0
- SinaTools-0.1.1.dist-info/AUTHORS.rst +13 -0
- SinaTools-0.1.1.dist-info/LICENSE +22 -0
- SinaTools-0.1.1.dist-info/METADATA +72 -0
- SinaTools-0.1.1.dist-info/RECORD +122 -0
- SinaTools-0.1.1.dist-info/WHEEL +6 -0
- SinaTools-0.1.1.dist-info/entry_points.txt +18 -0
- SinaTools-0.1.1.dist-info/top_level.txt +1 -0
- nlptools/CLI/DataDownload/download_files.py +71 -0
- nlptools/CLI/arabiner/bin/infer.py +117 -0
- nlptools/CLI/arabiner/bin/infer2.py +81 -0
- nlptools/CLI/morphology/ALMA_multi_word.py +75 -0
- nlptools/CLI/morphology/morph_analyzer.py +91 -0
- nlptools/CLI/salma/salma_tools.py +68 -0
- nlptools/CLI/utils/__init__.py +0 -0
- nlptools/CLI/utils/arStrip.py +99 -0
- nlptools/CLI/utils/corpus_tokenizer.py +74 -0
- nlptools/CLI/utils/implication.py +92 -0
- nlptools/CLI/utils/jaccard.py +96 -0
- nlptools/CLI/utils/latin_remove.py +51 -0
- nlptools/CLI/utils/remove_Punc.py +53 -0
- nlptools/CLI/utils/sentence_tokenizer.py +90 -0
- nlptools/CLI/utils/text_transliteration.py +77 -0
- nlptools/DataDownload/__init__.py +0 -0
- nlptools/DataDownload/downloader.py +185 -0
- nlptools/VERSION +1 -0
- nlptools/__init__.py +5 -0
- nlptools/arabert/__init__.py +1 -0
- nlptools/arabert/arabert/__init__.py +14 -0
- nlptools/arabert/arabert/create_classification_data.py +260 -0
- nlptools/arabert/arabert/create_pretraining_data.py +534 -0
- nlptools/arabert/arabert/extract_features.py +444 -0
- nlptools/arabert/arabert/lamb_optimizer.py +158 -0
- nlptools/arabert/arabert/modeling.py +1027 -0
- nlptools/arabert/arabert/optimization.py +202 -0
- nlptools/arabert/arabert/run_classifier.py +1078 -0
- nlptools/arabert/arabert/run_pretraining.py +593 -0
- nlptools/arabert/arabert/run_squad.py +1440 -0
- nlptools/arabert/arabert/tokenization.py +414 -0
- nlptools/arabert/araelectra/__init__.py +1 -0
- nlptools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +103 -0
- nlptools/arabert/araelectra/build_pretraining_dataset.py +230 -0
- nlptools/arabert/araelectra/build_pretraining_dataset_single_file.py +90 -0
- nlptools/arabert/araelectra/configure_finetuning.py +172 -0
- nlptools/arabert/araelectra/configure_pretraining.py +143 -0
- nlptools/arabert/araelectra/finetune/__init__.py +14 -0
- nlptools/arabert/araelectra/finetune/feature_spec.py +56 -0
- nlptools/arabert/araelectra/finetune/preprocessing.py +173 -0
- nlptools/arabert/araelectra/finetune/scorer.py +54 -0
- nlptools/arabert/araelectra/finetune/task.py +74 -0
- nlptools/arabert/araelectra/finetune/task_builder.py +70 -0
- nlptools/arabert/araelectra/flops_computation.py +215 -0
- nlptools/arabert/araelectra/model/__init__.py +14 -0
- nlptools/arabert/araelectra/model/modeling.py +1029 -0
- nlptools/arabert/araelectra/model/optimization.py +193 -0
- nlptools/arabert/araelectra/model/tokenization.py +355 -0
- nlptools/arabert/araelectra/pretrain/__init__.py +14 -0
- nlptools/arabert/araelectra/pretrain/pretrain_data.py +160 -0
- nlptools/arabert/araelectra/pretrain/pretrain_helpers.py +229 -0
- nlptools/arabert/araelectra/run_finetuning.py +323 -0
- nlptools/arabert/araelectra/run_pretraining.py +469 -0
- nlptools/arabert/araelectra/util/__init__.py +14 -0
- nlptools/arabert/araelectra/util/training_utils.py +112 -0
- nlptools/arabert/araelectra/util/utils.py +109 -0
- nlptools/arabert/aragpt2/__init__.py +2 -0
- nlptools/arabert/aragpt2/create_pretraining_data.py +95 -0
- nlptools/arabert/aragpt2/gpt2/__init__.py +2 -0
- nlptools/arabert/aragpt2/gpt2/lamb_optimizer.py +158 -0
- nlptools/arabert/aragpt2/gpt2/optimization.py +225 -0
- nlptools/arabert/aragpt2/gpt2/run_pretraining.py +397 -0
- nlptools/arabert/aragpt2/grover/__init__.py +0 -0
- nlptools/arabert/aragpt2/grover/dataloader.py +161 -0
- nlptools/arabert/aragpt2/grover/modeling.py +803 -0
- nlptools/arabert/aragpt2/grover/modeling_gpt2.py +1196 -0
- nlptools/arabert/aragpt2/grover/optimization_adafactor.py +234 -0
- nlptools/arabert/aragpt2/grover/train_tpu.py +187 -0
- nlptools/arabert/aragpt2/grover/utils.py +234 -0
- nlptools/arabert/aragpt2/train_bpe_tokenizer.py +59 -0
- nlptools/arabert/preprocess.py +818 -0
- nlptools/arabiner/__init__.py +0 -0
- nlptools/arabiner/bin/__init__.py +14 -0
- nlptools/arabiner/bin/eval.py +87 -0
- nlptools/arabiner/bin/infer.py +91 -0
- nlptools/arabiner/bin/process.py +140 -0
- nlptools/arabiner/bin/train.py +221 -0
- nlptools/arabiner/data/__init__.py +1 -0
- nlptools/arabiner/data/datasets.py +146 -0
- nlptools/arabiner/data/transforms.py +118 -0
- nlptools/arabiner/nn/BaseModel.py +22 -0
- nlptools/arabiner/nn/BertNestedTagger.py +34 -0
- nlptools/arabiner/nn/BertSeqTagger.py +17 -0
- nlptools/arabiner/nn/__init__.py +3 -0
- nlptools/arabiner/trainers/BaseTrainer.py +117 -0
- nlptools/arabiner/trainers/BertNestedTrainer.py +203 -0
- nlptools/arabiner/trainers/BertTrainer.py +163 -0
- nlptools/arabiner/trainers/__init__.py +3 -0
- nlptools/arabiner/utils/__init__.py +0 -0
- nlptools/arabiner/utils/data.py +124 -0
- nlptools/arabiner/utils/helpers.py +151 -0
- nlptools/arabiner/utils/metrics.py +69 -0
- nlptools/environment.yml +227 -0
- nlptools/install_env.py +13 -0
- nlptools/morphology/ALMA_multi_word.py +34 -0
- nlptools/morphology/__init__.py +52 -0
- nlptools/morphology/charsets.py +60 -0
- nlptools/morphology/morph_analyzer.py +170 -0
- nlptools/morphology/settings.py +8 -0
- nlptools/morphology/tokenizers_words.py +19 -0
- nlptools/nlptools.py +1 -0
- nlptools/salma/__init__.py +12 -0
- nlptools/salma/settings.py +31 -0
- nlptools/salma/views.py +459 -0
- nlptools/salma/wsd.py +126 -0
- nlptools/utils/__init__.py +0 -0
- nlptools/utils/corpus_tokenizer.py +73 -0
- nlptools/utils/implication.py +662 -0
- nlptools/utils/jaccard.py +247 -0
- nlptools/utils/parser.py +147 -0
- nlptools/utils/readfile.py +3 -0
- nlptools/utils/sentence_tokenizer.py +53 -0
- nlptools/utils/text_transliteration.py +232 -0
- nlptools/utils/utils.py +2 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
import os
|
2
|
+
import logging
|
3
|
+
import torch
|
4
|
+
import numpy as np
|
5
|
+
from nlptools.arabiner.trainers import BaseTrainer
|
6
|
+
from nlptools.arabiner.utils.metrics import compute_single_label_metrics
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
class BertTrainer(BaseTrainer):
    """
    Trainer for a flat (single-label) BERT tagger.

    Provides the training loop with early stopping on validation loss, plus
    evaluation and inference helpers that attach predicted tags to tokens.
    The dataloaders, optimizer, scheduler, loss, ``patience``, ``clip``,
    ``summary_writer`` etc. are read from ``self``; they are presumably set
    up by ``BaseTrainer`` from ``kwargs`` (not visible in this module).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self):
        """
        Train for at most ``self.max_epochs`` epochs.

        After each epoch the validation set is evaluated. When the validation
        loss improves, the test set is evaluated too, its predictions are
        written to ``predictions.txt`` under ``self.output_path`` and
        ``self.save()`` is called. Otherwise the patience counter is
        decremented and training stops once it reaches zero.

        :return: None
        """
        best_val_loss = np.inf
        num_train_batch = len(self.train_dataloader)
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            # NOTE(review): gradients are never zeroed in this loop; presumably
            # self.tag() does that when is_train=True - confirm in BaseTrainer.
            batches = self.tag(self.train_dataloader, is_train=True)
            for batch_index, (_, gold_tags, _, _, logits) in enumerate(batches, 1):
                self.current_timestep += 1
                # Flatten to (tokens, classes) vs (tokens,) before computing the loss
                batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
                batch_loss.backward()

                # Avoid exploding gradient by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                train_loss += batch_loss.item()

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss.item()
                    )

            # Mean training loss over the batches of this epoch
            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            _, segments, _, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_single_label_metrics(segments)

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                # Improvement: reset patience and checkpoint the model
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                _, segments, _, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_single_label_metrics(segments)

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvements, terminating early.
            # The summaries of the terminating epoch are not written because
            # the break happens before the add_scalars calls below.
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def eval(self, dataloader):
        """
        Run the model over ``dataloader`` and compute the mean batch loss.

        :param dataloader: dataloader whose dataset exposes a ``vocab`` attribute
        :return: tuple(preds, segments, valid_lens, loss) - argmax tag ids per
                 position, segments with predicted tags attached, the valid
                 lengths, and the loss averaged over ``len(dataloader)`` batches
        """
        preds, segments, valid_lens = list(), list(), list()
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(dataloader, is_train=False):
            loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss.item()

    def infer(self, dataloader):
        """
        Tag every segment served by ``dataloader`` without computing a loss.

        :param dataloader: dataloader whose dataset exposes a ``vocab`` attribute
        :return: list of segments, each a list of tokens carrying ``pred_tag``
        """
        preds, segments, valid_lens = list(), list(), list()

        for _, _, tokens, valid_len, logits in self.tag(dataloader, is_train=False):
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        """
        Attach the predicted tag to every token and drop the filtered positions.

        :param segments: list of token sequences, one per segment
        :param preds: list of predicted tag-id sequences, aligned with ``segments``
        :param valid_lens: number of valid positions of each segment
        :param vocab: object with ``tokens`` and ``tags`` vocabularies; falls
                      back to ``self.vocab`` when None
        :return: list of segments; each kept token gets
                 ``pred_tag = [{"tag": <tag string>}]``
        """
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        tags_itos = vocab.tags[0].get_itos()
        unk_id = tokens_stoi["UNK"]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            tagged_segment = list()

            # Skip the token at 0th index [CLS] and the token at nth index [SEP],
            # and walk the remaining tokens with their corresponding predictions
            for token, tag_id in zip(segment[1:valid_len-1], pred[1:valid_len-1]):
                # Ignore the sub-tokens/subwords, which are identified with text being UNK
                if tokens_stoi[token.text] == unk_id:
                    continue

                # Attach the predicted tag; only tagged tokens are returned,
                # the raw model predictions are no longer needed
                token.pred_tag = [{"tag": tags_itos[tag_id]}]
                tagged_segment.append(token)

            tagged_segments.append(tagged_segment)

        return tagged_segments
|
File without changes
|
@@ -0,0 +1,124 @@
|
|
1
|
+
from torch.utils.data import DataLoader
|
2
|
+
from torchtext.vocab import vocab
|
3
|
+
from collections import Counter, namedtuple
|
4
|
+
import logging
|
5
|
+
import re
|
6
|
+
import itertools
|
7
|
+
from nlptools.arabiner.utils.helpers import load_object
|
8
|
+
from nlptools.arabiner.data.datasets import Token
|
9
|
+
from nlptools.morphology.tokenizers_words import simple_word_tokenize
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
def conll_to_segments(filename):
    """
    Convert a CoNLL file to segments. Segments are separated by blank lines;
    every other line is split on whitespace, the first field becoming the
    token text and all remaining fields its gold tags.

    The segment being built is also appended once at end of file, so a file
    that ends with a blank line (or contains consecutive blank lines) yields
    empty segments.

    :param filename: Path
    :return: list[list[Token]] - one list of Token(text, gold_tag) per segment,
             where gold_tag is a list of tag strings
    """
    segments, segment = list(), list()

    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's default encoding.
    with open(filename, "r", encoding="utf-8") as fh:
        for line in fh.read().splitlines():
            if not line.strip():
                # Blank line closes the current segment
                segments.append(segment)
                segment = list()
            else:
                parts = line.split()
                segment.append(Token(text=parts[0], gold_tag=parts[1:]))

    segments.append(segment)

    return segments
|
36
|
+
|
37
|
+
|
38
|
+
def parse_conll_files(data_paths):
    """
    Parse CoNLL formatted files and index the tokens and tags found across
    all of them.

    :param data_paths: tuple(Path) - tuple of filenames
    :return: tuple(datasets, vocabs)
             datasets - tuple with one list of segments per entry of
                        data_paths, in the same order
             vocabs   - namedtuple "Vocab" with fields
                        tokens: vocabulary over all token texts ("UNK" special)
                        tags:   list of tag vocabularies; index 0 covers every
                                tag, followed by the per-type vocabularies
                                returned by tag_vocab_by_type
    """
    vocab_record = namedtuple("Vocab", ["tags", "tokens"])
    datasets = [conll_to_segments(data_path) for data_path in data_paths]

    # Walk the tokens in file order; vocabulary indices follow first occurrence
    every_token = [token for dataset in datasets for segment in dataset for token in segment]
    token_texts = [token.text for token in every_token]

    # Each token carries a list of gold tags, flatten them into one list
    flat_tags = list(itertools.chain.from_iterable(token.gold_tag for token in every_token))

    # Generate vocabs for tags and tokens
    per_type_vocabs = tag_vocab_by_type(flat_tags)
    tag_vocabs = [vocab(Counter(flat_tags)), *per_type_vocabs]
    token_vocab = vocab(Counter(token_texts), specials=["UNK"])

    return tuple(datasets), vocab_record(tokens=token_vocab, tags=tag_vocabs)
|
66
|
+
|
67
|
+
|
68
|
+
def tag_vocab_by_type(tags):
    """
    Build one tag vocabulary per entity type.

    The entity type of a tag is whatever follows its first "-" (for example
    "B-PER" -> "PER"); tags without a "-" have no type. Each vocabulary holds
    the tags of one type plus "O", with "<pad>" as special symbol.

    :param tags: list[str] - flat list of tags, duplicates included (they
                 are counted as frequencies)
    :return: list of vocabularies, ordered by sorted entity type name
    """
    # Group tags by their exact type in a single pass. Comparing the type
    # exactly (rather than with a start-anchored regex built from the type
    # name) keeps a type that is a prefix of another one, e.g. "ORG" and
    # "ORG_FAC", from collecting the other type's tags, and avoids treating
    # characters of the type name as regex syntax.
    tags_by_type = dict()
    for tag in tags:
        _, separator, tag_type = tag.partition("-")
        if separator:
            tags_by_type.setdefault(tag_type, list()).append(tag)

    return [
        vocab(Counter(tags_by_type[tag_type] + ["O"]), specials=["<pad>"])
        for tag_type in sorted(tags_by_type)
    ]
|
80
|
+
|
81
|
+
|
82
|
+
def text2segments(text):
    """
    Convert raw text to a dataset holding a single segment and index its tokens.

    The text is split with simple_word_tokenize and every resulting token is
    given the placeholder gold tag ["O"].

    :param text: str - raw input text
    :return: tuple(dataset, segment_vocab) - dataset is a list with one
             segment (list of Token); segment_vocab is the vocabulary of the
             token texts with "UNK" as special symbol
    """
    words = simple_word_tokenize(text)
    segment = [Token(text=word, gold_tag=["O"]) for word in words]

    # Generate vocabs for the tokens
    segment_vocab = vocab(Counter(token.text for token in segment), specials=["UNK"])
    return [segment], segment_vocab
|
94
|
+
|
95
|
+
|
96
|
+
def get_dataloaders(
    datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
):
    """
    Build one DataLoader per dataset.

    For every dataset, the dataset object named by data_config["fn"]
    (prefixed with "nlptools.") is created through load_object. Note that
    data_config["kwargs"] is updated in place with the "examples" and "vocab"
    entries, so the caller's dict is modified.

    :param datasets: list - list of the datasets, list of list of segments and tokens
    :param vocab: vocabulary object handed to every dataset
    :param data_config: dict - with keys "fn" (dotted object name) and "kwargs" (dict)
    :param batch_size: int
    :param num_workers: int
    :param shuffle: sequence of booleans - shuffle[i] tells whether datasets[i] is shuffled
    :return: List[torch.utils.data.DataLoader]
    """
    loaders = []

    for position, examples in enumerate(datasets):
        dataset_kwargs = data_config["kwargs"]
        dataset_kwargs.update(examples=examples, vocab=vocab)
        dataset = load_object("nlptools." + data_config["fn"], dataset_kwargs)

        loader = DataLoader(
            dataset=dataset,
            shuffle=shuffle[position],
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
        )

        logger.info("%s batches found", len(loader))
        loaders.append(loader)

    return loaders
|
@@ -0,0 +1,151 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
import logging
|
4
|
+
import importlib
|
5
|
+
import shutil
|
6
|
+
import torch
|
7
|
+
import pickle
|
8
|
+
import json
|
9
|
+
import random
|
10
|
+
import numpy as np
|
11
|
+
from argparse import Namespace
|
12
|
+
|
13
|
+
|
14
|
+
def logging_config(log_file=None):
    """
    Configure the root logger: INFO level, tab separated records written to
    stdout and, optionally, to a file. Any handlers already installed on the
    root logger are replaced (force=True).

    :param log_file: str - full path to the log file; the file is truncated
                     and written as UTF-8. No file is used when falsy.
    :return: None
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    file_handlers = []

    if log_file:
        file_handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
        print("Logging to {}".format(log_file))

    logging.basicConfig(
        force=True,
        level=logging.INFO,
        handlers=[stdout_handler, *file_handlers],
        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
    )
|
33
|
+
|
34
|
+
|
35
|
+
#def load_object(name, kwargs):
|
36
|
+
# """
|
37
|
+
# Load objects dynamically given the object name and its arguments
|
38
|
+
# :param name: str - object name, class name or function name
|
39
|
+
# :param kwargs: dict - keyword arguments
|
40
|
+
# :return: object
|
41
|
+
# """
|
42
|
+
# object_module, object_name = name.rsplit(".", 1)
|
43
|
+
# object_module = importlib.import_module(object_module)
|
44
|
+
# fn = getattr(object_module, object_name)(**kwargs)
|
45
|
+
# return fn
|
46
|
+
|
47
|
+
def load_object(name, kwargs):
    """
    Load an object dynamically given its dotted name and call it.

    The part of ``name`` before the last "." is imported as a module, the
    part after it is looked up on that module, and the result is called with
    ``kwargs``. This function never raises for ordinary errors: any
    ``Exception`` (bad name, import failure, missing attribute, non-callable
    target, or an error raised by the call itself) is printed to stdout and
    ``None`` is returned instead.

    :param name: str - dotted object name, class name or function name
    :param kwargs: dict - keyword arguments passed to the object
    :return: whatever the object returns, or None on any error
    """
    try:
        module_path, attribute = name.rsplit(".", 1)
        target = getattr(importlib.import_module(module_path), attribute)

        # Only classes and functions can be instantiated/called
        if not callable(target):
            raise TypeError(f"{name} is not a callable object.")

        return target(**kwargs)
    except ImportError as e:
        # Covers ModuleNotFoundError too, which is a subclass of ImportError
        print(f"Error importing module: {e}")
    except AttributeError as e:
        # E.g. object not found in module
        print(f"Attribute error: {e}")
    except Exception as e:
        # Anything else, including the TypeError raised above
        print(f"An error occurred: {e}")

    # Reached only after one of the handlers above ran
    return None
|
76
|
+
|
77
|
+
def make_output_dirs(path, subdirs=(), overwrite=True):
    """
    Create root directory and any other sub-directories

    :param path: str - root directory
    :param subdirs: iterable of str - sub-directories to create under path
                    (the default is an immutable empty tuple rather than a
                    shared mutable list)
    :param overwrite: boolean - when True, an existing directory tree at
                      path is deleted first; when False and path already
                      exists, os.makedirs raises FileExistsError
    :return: None
    """
    if overwrite:
        # ignore_errors=True also covers the case where path does not exist yet
        shutil.rmtree(path, ignore_errors=True)

    os.makedirs(path)

    for subdir in subdirs:
        os.makedirs(os.path.join(path, subdir))
|
92
|
+
|
93
|
+
|
94
|
+
def load_checkpoint(path):
    """
    Restore a trained tagger from a checkpoint directory.

    Reads "tag_vocab.pkl" and "args.json" from ``path``, rebuilds the loss,
    the network (wrapped in torch.nn.DataParallel and moved to GPU when CUDA
    is available) and the trainer through load_object, then asks the trainer
    to load the weights from the "checkpoints" sub-directory.

    :param path: str - path to the checkpoint directory
    :return: tagger - the trainer object built from trainer_config
                      (presumably an arabiner.trainers.BaseTrainer subclass)
             tag_vocab - the unpickled content of tag_vocab.pkl (indexed tags)
             train_config - argparse.Namespace - training configurations; its
                      trainer_config["kwargs"] now also holds the model and loss
    """
    # NOTE: pickle can execute arbitrary code while loading; only open
    # checkpoints that come from a trusted source.
    with open(os.path.join(path, "tag_vocab.pkl"), "rb") as vocab_fh:
        tag_vocab = pickle.load(vocab_fh)

    # Load train configurations from checkpoint
    train_config = Namespace()
    with open(os.path.join(path, "args.json"), "r") as args_fh:
        train_config.__dict__ = json.load(args_fh)

    # Initialize the loss function, not used for inference, but evaluation
    loss_config = train_config.loss
    loss = load_object(loss_config["fn"], loss_config["kwargs"])

    # Load BERT tagger
    network_config = train_config.network_config
    model = torch.nn.DataParallel(
        load_object("nlptools." + network_config["fn"], network_config["kwargs"])
    )
    if torch.cuda.is_available():
        model = model.cuda()

    # Update arguments for the tagger
    # Attach the model, loss (used for evaluations cases)
    trainer_config = train_config.trainer_config
    trainer_kwargs = trainer_config["kwargs"]
    trainer_kwargs["model"] = model
    trainer_kwargs["loss"] = loss

    tagger = load_object("nlptools." + trainer_config["fn"], trainer_kwargs)
    tagger.load(os.path.join(path, "checkpoints"))

    return tagger, tag_vocab, train_config
|
133
|
+
|
134
|
+
|
135
|
+
def set_seed(seed):
    """
    Set the seed for random initialization and set the cuDNN parameters to
    ensure deterministic results across multiple runs with the same seed.

    Seeds NumPy, Python's ``random`` and PyTorch (CPU and all CUDA devices),
    then switches cuDNN to deterministic mode with benchmarking off and
    disables cuDNN.

    :param seed: int
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False
|
@@ -0,0 +1,69 @@
|
|
1
|
+
from seqeval.metrics import (
|
2
|
+
classification_report,
|
3
|
+
precision_score,
|
4
|
+
recall_score,
|
5
|
+
f1_score,
|
6
|
+
accuracy_score,
|
7
|
+
)
|
8
|
+
from seqeval.scheme import IOB2
|
9
|
+
from types import SimpleNamespace
|
10
|
+
import logging
|
11
|
+
import re
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
def compute_nested_metrics(segments, vocabs):
    """
    Compute metrics for nested NER

    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments;
                     every token needs ``gold_tag`` (list of tags) and
                     ``pred_tag`` (list of {"tag": ...} dicts, one per vocab)
    :param vocabs: list of tag vocabularies, one per entity type; position i
                   must line up with ``token.pred_tag[i]``
    :return: metrics - SimpleNamespace - F1/micro/macro/weights, recall, precision, accuracy
    """
    y, y_hat = list(), list()

    # We duplicate the dataset N times, where N is the number of entity types
    # For each copy, we create y and y_hat
    # Example: first copy, will create pairs of ground truth and predicted labels for entity type GPE
    # another copy will create pairs for LOC, etc.
    for i, vocab in enumerate(vocabs):
        # Exact membership instead of a start-anchored regex built from the
        # tags: a gold tag such as "B-ORG_FAC" must not be taken for "B-ORG",
        # tag characters must not be read as regex syntax, and a vocabulary
        # without any hyphenated tag must match nothing.
        vocab_tags = {tag for tag in vocab.get_itos() if "-" in tag}

        # Gold label for this entity type: the first gold tag that belongs to
        # the vocabulary, or "O" when the token has none
        y += [
            [next((tag for tag in token.gold_tag if tag in vocab_tags), "O") for token in segment]
            for segment in segments
        ]
        y_hat += [[token.pred_tag[i]["tag"] for token in segment] for segment in segments]

    # NOTE(review): seqeval appears to honour `scheme` only together with
    # mode="strict" - confirm whether strict IOB2 evaluation was intended.
    logging.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))

    metrics = {
        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
        "precision": precision_score(y, y_hat, scheme=IOB2),
        "recall": recall_score(y, y_hat, scheme=IOB2),
        "accuracy": accuracy_score(y, y_hat),
    }

    return SimpleNamespace(**metrics)
|
47
|
+
|
48
|
+
|
49
|
+
def compute_single_label_metrics(segments):
    """
    Compute metrics for flat NER

    The first gold tag of every token is compared with the tag stored in the
    first entry of its ``pred_tag``. The classification report is logged
    through the root logger.

    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments
    :return: metrics - SimpleNamespace - F1/micro/macro/weights, recall, precision, accuracy
    """
    y, y_hat = [], []
    for segment in segments:
        y.append([token.gold_tag[0] for token in segment])
        y_hat.append([token.pred_tag[0]["tag"] for token in segment])

    report = classification_report(y, y_hat, scheme=IOB2)
    logging.info("\n" + report)

    return SimpleNamespace(
        micro_f1=f1_score(y, y_hat, average="micro", scheme=IOB2),
        macro_f1=f1_score(y, y_hat, average="macro", scheme=IOB2),
        weights_f1=f1_score(y, y_hat, average="weighted", scheme=IOB2),
        precision=precision_score(y, y_hat, scheme=IOB2),
        recall=recall_score(y, y_hat, scheme=IOB2),
        accuracy=accuracy_score(y, y_hat),
    )
|