SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
- SinaTools-1.0.1.dist-info/RECORD +73 -0
- sinatools/VERSION +1 -1
- sinatools/ner/__init__.py +5 -7
- sinatools/ner/trainers/BertNestedTrainer.py +203 -203
- sinatools/ner/trainers/BertTrainer.py +163 -163
- sinatools/ner/trainers/__init__.py +2 -2
- SinaTools-0.1.40.dist-info/RECORD +0 -123
- sinatools/arabert/arabert/__init__.py +0 -14
- sinatools/arabert/arabert/create_classification_data.py +0 -260
- sinatools/arabert/arabert/create_pretraining_data.py +0 -534
- sinatools/arabert/arabert/extract_features.py +0 -444
- sinatools/arabert/arabert/lamb_optimizer.py +0 -158
- sinatools/arabert/arabert/modeling.py +0 -1027
- sinatools/arabert/arabert/optimization.py +0 -202
- sinatools/arabert/arabert/run_classifier.py +0 -1078
- sinatools/arabert/arabert/run_pretraining.py +0 -593
- sinatools/arabert/arabert/run_squad.py +0 -1440
- sinatools/arabert/arabert/tokenization.py +0 -414
- sinatools/arabert/araelectra/__init__.py +0 -1
- sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
- sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
- sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
- sinatools/arabert/araelectra/configure_finetuning.py +0 -172
- sinatools/arabert/araelectra/configure_pretraining.py +0 -143
- sinatools/arabert/araelectra/finetune/__init__.py +0 -14
- sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
- sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
- sinatools/arabert/araelectra/finetune/scorer.py +0 -54
- sinatools/arabert/araelectra/finetune/task.py +0 -74
- sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
- sinatools/arabert/araelectra/flops_computation.py +0 -215
- sinatools/arabert/araelectra/model/__init__.py +0 -14
- sinatools/arabert/araelectra/model/modeling.py +0 -1029
- sinatools/arabert/araelectra/model/optimization.py +0 -193
- sinatools/arabert/araelectra/model/tokenization.py +0 -355
- sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
- sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
- sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
- sinatools/arabert/araelectra/run_finetuning.py +0 -323
- sinatools/arabert/araelectra/run_pretraining.py +0 -469
- sinatools/arabert/araelectra/util/__init__.py +0 -14
- sinatools/arabert/araelectra/util/training_utils.py +0 -112
- sinatools/arabert/araelectra/util/utils.py +0 -109
- sinatools/arabert/aragpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
- sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
- sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
- sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
- sinatools/arabert/aragpt2/grover/__init__.py +0 -0
- sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
- sinatools/arabert/aragpt2/grover/modeling.py +0 -803
- sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
- sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
- sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
- sinatools/arabert/aragpt2/grover/utils.py +0 -234
- sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
- {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: SinaTools
|
3
|
-
Version: 0.1.40
|
3
|
+
Version: 1.0.1
|
4
4
|
Summary: Open-source Python toolkit for Arabic Natural Understanding, allowing people to integrate it in their system workflow.
|
5
5
|
Home-page: https://github.com/SinaLab/sinatools
|
6
6
|
License: MIT license
|
@@ -0,0 +1,73 @@
|
|
1
|
+
SinaTools-1.0.1.data/data/sinatools/environment.yml,sha256=i0UFZc-vwU9ZwnI8hBdz7vi-x22vG-HR8ojWBUAOkno,5422
|
2
|
+
sinatools/VERSION,sha256=1R5uyUBYVUqEVYpbQC7m71_fVFXjXJAv7aYc2odSlDo,5
|
3
|
+
sinatools/__init__.py,sha256=bEosTU1o-FSpyytS6iVP_82BXHF2yHnzpJxPLYRbeII,135
|
4
|
+
sinatools/environment.yml,sha256=i0UFZc-vwU9ZwnI8hBdz7vi-x22vG-HR8ojWBUAOkno,5422
|
5
|
+
sinatools/install_env.py,sha256=EODeeE0ZzfM_rz33_JSIruX03Nc4ghyVOM5BHVhsZaQ,404
|
6
|
+
sinatools/sinatools.py,sha256=vR5AaF0iel21LvsdcqwheoBz0SIj9K9I_Ub8M8oA98Y,20
|
7
|
+
sinatools/CLI/DataDownload/download_files.py,sha256=EezvbukR3pZ8s6mGZnzTcjsbo3CBDlC0g6KhJWlYp1w,2686
|
8
|
+
sinatools/CLI/morphology/ALMA_multi_word.py,sha256=rmpa72twwIJHme_kpQ1lu3_7y_Jorj70QTvOnQMJRuI,1274
|
9
|
+
sinatools/CLI/morphology/morph_analyzer.py,sha256=HPamEKos_JRYCJv_2q6c12N--da58_JXTno9haww5Ao,3497
|
10
|
+
sinatools/CLI/ner/corpus_entity_extractor.py,sha256=DdvigsDQzko5nJBjzUXlIDqoBMBTVzktjSo7JfEXTIA,4778
|
11
|
+
sinatools/CLI/ner/entity_extractor.py,sha256=G9j-t0WKm2CRORhqARJM-pI-KArQ2IXIvnBK_NHxlHs,2885
|
12
|
+
sinatools/CLI/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
+
sinatools/CLI/utils/arStrip.py,sha256=NLyp8vOu2xv80tL9jiKRvyptmbkRZVg-wcAr-9YyvNY,3264
|
14
|
+
sinatools/CLI/utils/corpus_tokenizer.py,sha256=nH0T4h6urr_0Qy6-wN3PquOtnwybj0REde5Ts_OE4U8,1650
|
15
|
+
sinatools/CLI/utils/implication.py,sha256=AojpkCwUQJiQjxhyEUWKRHmBnIt1tVqr485cAF7Thq0,2857
|
16
|
+
sinatools/CLI/utils/jaccard.py,sha256=w56N_cNEFJ0A7WtunmY_xtms4srFagKBzrW_0YhH2DE,4216
|
17
|
+
sinatools/CLI/utils/remove_latin.py,sha256=NOaTm2RHxt5IQrV98ySTmD8rTXTmcqSmfbPAwTyaXqU,848
|
18
|
+
sinatools/CLI/utils/remove_punctuation.py,sha256=vJAZlEn7WGftZAFVFYnddkRrxdJ_rMmKB9vFZkY-jN4,1097
|
19
|
+
sinatools/CLI/utils/sentence_tokenizer.py,sha256=Wli8eiDbWSd_Z8UKpu_JkaS8jImowa1vnRL0oYCSfqw,2823
|
20
|
+
sinatools/CLI/utils/text_dublication_detector.py,sha256=dW70O5O20GxeUDDF6zVYn52wWLmJF-HBZgvqIeVL2rQ,1661
|
21
|
+
sinatools/CLI/utils/text_transliteration.py,sha256=vz-3kxWf8pNYVCqNAtBAiA6u_efrS5NtWT-ofN1NX6I,2014
|
22
|
+
sinatools/DataDownload/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
+
sinatools/DataDownload/downloader.py,sha256=VdUNgSqMKz1J-DuQD_eS1U2KWqEpy94WlSJ0pPODLig,7833
|
24
|
+
sinatools/arabert/__init__.py,sha256=ely2PttjgSv7vKdzskuD1rtK_l_UOpmxJSz8isrveD0,16
|
25
|
+
sinatools/arabert/preprocess.py,sha256=qI0FsuMTOzdRlYGCtLrjpXgikNElUZPv9bnjaKDZKJ4,33024
|
26
|
+
sinatools/morphology/ALMA_multi_word.py,sha256=hj_-8ojrYYHnfCGk8WKtJdUR8mauzQdma4WUm-okDps,1346
|
27
|
+
sinatools/morphology/__init__.py,sha256=I4wVBh8BhyNl-CySVdiI_nUSn6gj1j-gmLKP300RpE0,1216
|
28
|
+
sinatools/morphology/morph_analyzer.py,sha256=JOH2UWKNQWo5UzpWNzP9R1D3B3qLSogIiMp8n0N_56o,7177
|
29
|
+
sinatools/ner/__init__.py,sha256=59kLMX6UQhF6JpE10RhaDYC3a2_jiWOIVPuejsoflFE,1050
|
30
|
+
sinatools/ner/data_format.py,sha256=VmFshZbEPOsWxsb4tgSkwvbM1k7yCce4kmtPkCiWgwM,4513
|
31
|
+
sinatools/ner/datasets.py,sha256=mG1iwqSm3lXCFHLqE-b4wNi176cpuzNBz8tKaBU6z6M,5059
|
32
|
+
sinatools/ner/entity_extractor.py,sha256=O2epRwRFUUcQs3SnFIYHVBI4zVhr8hRcj0XJYeby4ts,3588
|
33
|
+
sinatools/ner/helpers.py,sha256=sX6ezVbuVQxk_xJqZwhUzJVFVuVmFGmei_kd6r3sPHE,3652
|
34
|
+
sinatools/ner/metrics.py,sha256=Irz6SsIvpOzGIA2lWxrEV86xnTnm0TzKm9SUVT4SXUU,2734
|
35
|
+
sinatools/ner/transforms.py,sha256=vti3mDdi-IRP8i0aTQ37QqpPlP9hdMmJ6_bAMa0uL-s,4871
|
36
|
+
sinatools/ner/data/__init__.py,sha256=W0C1ge_XxTfmdEGz0hkclz57aLI5VFS5t6BjByCfkFk,57
|
37
|
+
sinatools/ner/data/datasets.py,sha256=_uUlvBAhnTtPwKLj0wIbmB04VCBidfwffxKorLGHq_g,5134
|
38
|
+
sinatools/ner/data/transforms.py,sha256=URMz1dHzkHjgUGAkDOenCWvQThO1ha8XeQVjoLL9RXM,4874
|
39
|
+
sinatools/ner/nn/BaseModel.py,sha256=3GmujQasTZZunOBuFXpY2p1W8W256iI_Uu4hxhOY2Z0,608
|
40
|
+
sinatools/ner/nn/BertNestedTagger.py,sha256=_fwAn1kiKmXe6m5y16Ipty3kvXIEFEmiUq74Ad1818U,1219
|
41
|
+
sinatools/ner/nn/BertSeqTagger.py,sha256=dFcBBiMw2QCWsyy7aQDe_PS3aRuNn4DOxKIHgTblFvc,504
|
42
|
+
sinatools/ner/nn/__init__.py,sha256=UgQD_XLNzQGBNSYc_Bw1aRJZjq4PJsnMT1iZwnJemqE,170
|
43
|
+
sinatools/ner/trainers/BaseTrainer.py,sha256=Uar8HxtgBXCVhKa85sEN622d9P7JiFBcWfs46uRG4aA,4068
|
44
|
+
sinatools/ner/trainers/BertNestedTrainer.py,sha256=Pb4O2WeBmTvV3hHMT6DXjxrTzgtuh3OrKQZnogYy8RQ,8429
|
45
|
+
sinatools/ner/trainers/BertTrainer.py,sha256=B_uVtUwfv_eFwMMPsKQvZgW_ZNLy6XEsX5ePR0s8d-k,6433
|
46
|
+
sinatools/ner/trainers/__init__.py,sha256=UDok8pDDpYOpwRBBKVLKaOgSUlmqqb-zHZI1p0xPxzI,188
|
47
|
+
sinatools/relations/__init__.py,sha256=cYjsP2mlTYvAwVIEFtgA6i9gLUSkGVOuDggMs7TvG5k,272
|
48
|
+
sinatools/relations/relation_extractor.py,sha256=UuDlaaR0ch9BFv4sBF1tr7P-P9xq8oRZF41tAze6_ok,9751
|
49
|
+
sinatools/semantic_relatedness/__init__.py,sha256=S0xrmqtl72L02N56nbNMudPoebnYQgsaIyyX-587DsU,830
|
50
|
+
sinatools/semantic_relatedness/compute_relatedness.py,sha256=_9HFPs3nQBLklHFfkc9o3gEjEI6Bd34Ha4E1Kvv1RIg,2256
|
51
|
+
sinatools/synonyms/__init__.py,sha256=yMuphNZrm5XLOR2T0weOHcUysJm-JKHUmVLoLQO8390,548
|
52
|
+
sinatools/synonyms/synonyms_generator.py,sha256=jRd0D3_kn-jYBaZzqY-7oOy0SFjSJ-mjM7JhsySzX58,9037
|
53
|
+
sinatools/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
54
|
+
sinatools/utils/charsets.py,sha256=rs82oZJqRqosZdTKXfFAJfJ5t4PxjMM_oAPsiWSWuwU,2817
|
55
|
+
sinatools/utils/parser.py,sha256=qvHdln5R5CAv_0UOJWe0mcp8JCsGqgazoeIIkoALH88,6259
|
56
|
+
sinatools/utils/readfile.py,sha256=xE4LEaCqXJIk9v37QUSSmWb-aY3UnCFUNb7uVdx3cpM,133
|
57
|
+
sinatools/utils/similarity.py,sha256=HAK6OmyVnfjPm0GWL3z9s4ZoUwpZHVKxt3CeSMfqLIQ,11990
|
58
|
+
sinatools/utils/text_dublication_detector.py,sha256=FeSkbfWGMQluz23H4CBHXION-walZPgjueX6AL8u_Q0,5660
|
59
|
+
sinatools/utils/text_transliteration.py,sha256=F3smhr2AEJtySE6wGQsiXXOslTvSDzLivTYu0btgc10,8769
|
60
|
+
sinatools/utils/tokenizer.py,sha256=nyk6lh5-p38wrU62hvh4wg7ni9ammkdqqIgcjbbBxxo,6965
|
61
|
+
sinatools/utils/tokenizers_words.py,sha256=efNfOil9qDNVJ9yynk_8sqf65PsL-xtsHG7y2SZCkjQ,656
|
62
|
+
sinatools/utils/word_compare.py,sha256=rS2Z74sf7R-7MTXyrFj5miRi2TnSG9OdTDp_qQYuo2Y,28200
|
63
|
+
sinatools/wsd/__init__.py,sha256=mwmCUurOV42rsNRpIUP3luG0oEzeTfEx3oeDl93Oif8,306
|
64
|
+
sinatools/wsd/disambiguator.py,sha256=h-3idc5rPPbMDSE_QVJAsEVkDHwzYY3L2SEPNXIdOcc,20104
|
65
|
+
sinatools/wsd/settings.py,sha256=6XflVTFKD8SVySX9Wj7zYQtV26WDTcQ2-uW8-gDNHKE,747
|
66
|
+
sinatools/wsd/wsd.py,sha256=gHIBUFXegoY1z3rRnIlK6TduhYq2BTa_dHakOjOlT4k,4434
|
67
|
+
SinaTools-1.0.1.dist-info/AUTHORS.rst,sha256=aTWeWlIdfLi56iLJfIUAwIrmqDcgxXKLji75_Fjzjyg,174
|
68
|
+
SinaTools-1.0.1.dist-info/LICENSE,sha256=uwsKYG4TayHXNANWdpfMN2lVW4dimxQjA_7vuCVhD70,1088
|
69
|
+
SinaTools-1.0.1.dist-info/METADATA,sha256=8EnFO3dSqtJ8JJ4r_-ji5tX_h04_vNTnPvfubqceaQ4,3409
|
70
|
+
SinaTools-1.0.1.dist-info/WHEEL,sha256=9Hm2OB-j1QcCUq9Jguht7ayGIIZBRTdOXD1qg9cCgPM,109
|
71
|
+
SinaTools-1.0.1.dist-info/entry_points.txt,sha256=_CsRKM_tSCWV5hefBNUsWf9_6DrJnzFlxeAo1wm5XqY,1302
|
72
|
+
SinaTools-1.0.1.dist-info/top_level.txt,sha256=8tNdPTeJKw3TQCaua8IJIx6N6WpgZZmVekf1OdBNJpE,10
|
73
|
+
SinaTools-1.0.1.dist-info/RECORD,,
|
sinatools/VERSION
CHANGED
@@ -1 +1 @@
|
|
1
|
-
0.1.40
|
1
|
+
1.0.1
|
sinatools/ner/__init__.py
CHANGED
@@ -11,7 +11,7 @@ from argparse import Namespace
|
|
11
11
|
tagger = None
|
12
12
|
tag_vocab = None
|
13
13
|
train_config = None
|
14
|
-
|
14
|
+
|
15
15
|
filename = 'Wj27012000.tar'
|
16
16
|
path =downloader.get_appdatadir()
|
17
17
|
model_path = os.path.join(path, filename)
|
@@ -20,21 +20,19 @@ _path = os.path.join(model_path, "tag_vocab.pkl")
|
|
20
20
|
|
21
21
|
with open(_path, "rb") as fh:
|
22
22
|
tag_vocab = pickle.load(fh)
|
23
|
-
print("tag_vocab loaded")
|
24
23
|
|
25
24
|
train_config = Namespace()
|
26
25
|
args_path = os.path.join(model_path, "args.json")
|
27
|
-
|
26
|
+
|
28
27
|
with open(args_path, "r") as fh:
|
29
28
|
train_config.__dict__ = json.load(fh)
|
30
|
-
|
29
|
+
|
31
30
|
model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
|
32
31
|
model = torch.nn.DataParallel(model)
|
33
|
-
|
32
|
+
|
34
33
|
if torch.cuda.is_available():
|
35
34
|
model = model.cuda()
|
36
|
-
|
35
|
+
|
37
36
|
train_config.trainer_config["kwargs"]["model"] = model
|
38
37
|
tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
|
39
38
|
tagger.load(os.path.join(model_path,"checkpoints"))
|
40
|
-
print("steps 4")
|
@@ -1,203 +1,203 @@
|
|
1
|
-
import os
|
2
|
-
import logging
|
3
|
-
import torch
|
4
|
-
import numpy as np
|
5
|
-
from sinatools.ner.trainers import BaseTrainer
|
6
|
-
from sinatools.ner.metrics import compute_nested_metrics
|
7
|
-
|
8
|
-
logger = logging.getLogger(__name__)
|
9
|
-
|
10
|
-
|
11
|
-
class BertNestedTrainer(BaseTrainer):
|
12
|
-
def __init__(self, **kwargs):
|
13
|
-
super().__init__(**kwargs)
|
14
|
-
|
15
|
-
def train(self):
|
16
|
-
best_val_loss, test_loss = np.inf, np.inf
|
17
|
-
num_train_batch = len(self.train_dataloader)
|
18
|
-
num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
|
19
|
-
patience = self.patience
|
20
|
-
|
21
|
-
for epoch_index in range(self.max_epochs):
|
22
|
-
self.current_epoch = epoch_index
|
23
|
-
train_loss = 0
|
24
|
-
|
25
|
-
for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
|
26
|
-
self.train_dataloader, is_train=True
|
27
|
-
), 1):
|
28
|
-
self.current_timestep += 1
|
29
|
-
|
30
|
-
# Compute loses for each output
|
31
|
-
# logits = B x T x L x C
|
32
|
-
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
33
|
-
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
34
|
-
for i, l in enumerate(num_labels)]
|
35
|
-
|
36
|
-
torch.autograd.backward(losses)
|
37
|
-
|
38
|
-
# Avoid exploding gradient by doing gradient clipping
|
39
|
-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
|
40
|
-
|
41
|
-
self.optimizer.step()
|
42
|
-
self.scheduler.step()
|
43
|
-
batch_loss = sum(l.item() for l in losses)
|
44
|
-
train_loss += batch_loss
|
45
|
-
|
46
|
-
if self.current_timestep % self.log_interval == 0:
|
47
|
-
logger.info(
|
48
|
-
"Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
|
49
|
-
epoch_index,
|
50
|
-
batch_index,
|
51
|
-
num_train_batch,
|
52
|
-
self.current_timestep,
|
53
|
-
self.optimizer.param_groups[0]['lr'],
|
54
|
-
batch_loss
|
55
|
-
)
|
56
|
-
|
57
|
-
train_loss /= num_train_batch
|
58
|
-
|
59
|
-
logger.info("** Evaluating on validation dataset **")
|
60
|
-
val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
|
61
|
-
val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
|
62
|
-
|
63
|
-
epoch_summary_loss = {
|
64
|
-
"train_loss": train_loss,
|
65
|
-
"val_loss": val_loss
|
66
|
-
}
|
67
|
-
epoch_summary_metrics = {
|
68
|
-
"val_micro_f1": val_metrics.micro_f1,
|
69
|
-
"val_precision": val_metrics.precision,
|
70
|
-
"val_recall": val_metrics.recall
|
71
|
-
}
|
72
|
-
|
73
|
-
logger.info(
|
74
|
-
"Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
|
75
|
-
epoch_index,
|
76
|
-
self.current_timestep,
|
77
|
-
train_loss,
|
78
|
-
val_loss,
|
79
|
-
val_metrics.micro_f1
|
80
|
-
)
|
81
|
-
|
82
|
-
if val_loss < best_val_loss:
|
83
|
-
patience = self.patience
|
84
|
-
best_val_loss = val_loss
|
85
|
-
logger.info("** Validation improved, evaluating test data **")
|
86
|
-
test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
|
87
|
-
self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
|
88
|
-
test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
|
89
|
-
|
90
|
-
epoch_summary_loss["test_loss"] = test_loss
|
91
|
-
epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
|
92
|
-
epoch_summary_metrics["test_precision"] = test_metrics.precision
|
93
|
-
epoch_summary_metrics["test_recall"] = test_metrics.recall
|
94
|
-
|
95
|
-
logger.info(
|
96
|
-
f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
|
97
|
-
epoch_index,
|
98
|
-
self.current_timestep,
|
99
|
-
test_loss,
|
100
|
-
test_metrics.micro_f1
|
101
|
-
)
|
102
|
-
|
103
|
-
self.save()
|
104
|
-
else:
|
105
|
-
patience -= 1
|
106
|
-
|
107
|
-
# No improvements, terminating early
|
108
|
-
if patience == 0:
|
109
|
-
logger.info("Early termination triggered")
|
110
|
-
break
|
111
|
-
|
112
|
-
self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
|
113
|
-
self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
|
114
|
-
|
115
|
-
def tag(self, dataloader, is_train=True):
|
116
|
-
"""
|
117
|
-
Given a dataloader containing segments, predict the tags
|
118
|
-
:param dataloader: torch.utils.data.DataLoader
|
119
|
-
:param is_train: boolean - True for training model, False for evaluation
|
120
|
-
:return: Iterator
|
121
|
-
subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
|
122
|
-
gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
|
123
|
-
tokens - List[arabiner.data.dataset.Token] - list of tokens
|
124
|
-
valid_len (B x 1) - int - valiud length of each sequence
|
125
|
-
logits (B x T x NUM_LABELS) - logits for each token and each tag
|
126
|
-
"""
|
127
|
-
for subwords, gold_tags, tokens, mask, valid_len in dataloader:
|
128
|
-
self.model.train(is_train)
|
129
|
-
|
130
|
-
if torch.cuda.is_available():
|
131
|
-
subwords = subwords.cuda()
|
132
|
-
gold_tags = gold_tags.cuda()
|
133
|
-
|
134
|
-
if is_train:
|
135
|
-
self.optimizer.zero_grad()
|
136
|
-
logits = self.model(subwords)
|
137
|
-
else:
|
138
|
-
with torch.no_grad():
|
139
|
-
logits = self.model(subwords)
|
140
|
-
|
141
|
-
yield subwords, gold_tags, tokens, valid_len, logits
|
142
|
-
|
143
|
-
def eval(self, dataloader):
|
144
|
-
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
145
|
-
num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
|
146
|
-
loss = 0
|
147
|
-
|
148
|
-
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
149
|
-
dataloader, is_train=False
|
150
|
-
):
|
151
|
-
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
152
|
-
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
153
|
-
for i, l in enumerate(num_labels)]
|
154
|
-
loss += sum(losses)
|
155
|
-
preds += torch.argmax(logits, dim=3)
|
156
|
-
segments += tokens
|
157
|
-
valid_lens += list(valid_len)
|
158
|
-
|
159
|
-
loss /= len(dataloader)
|
160
|
-
|
161
|
-
# Update segments, attach predicted tags to each token
|
162
|
-
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
163
|
-
|
164
|
-
return preds, segments, valid_lens, loss
|
165
|
-
|
166
|
-
def infer(self, dataloader):
|
167
|
-
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
168
|
-
|
169
|
-
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
170
|
-
dataloader, is_train=False
|
171
|
-
):
|
172
|
-
preds += torch.argmax(logits, dim=3)
|
173
|
-
segments += tokens
|
174
|
-
valid_lens += list(valid_len)
|
175
|
-
|
176
|
-
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
177
|
-
return segments
|
178
|
-
|
179
|
-
def to_segments(self, segments, preds, valid_lens, vocab):
|
180
|
-
if vocab is None:
|
181
|
-
vocab = self.vocab
|
182
|
-
|
183
|
-
tagged_segments = list()
|
184
|
-
tokens_stoi = vocab.tokens.get_stoi()
|
185
|
-
unk_id = tokens_stoi["UNK"]
|
186
|
-
|
187
|
-
for segment, pred, valid_len in zip(segments, preds, valid_lens):
|
188
|
-
# First, the token at 0th index [CLS] and token at nth index [SEP]
|
189
|
-
# Combine the tokens with their corresponding predictions
|
190
|
-
segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
|
191
|
-
|
192
|
-
# Ignore the sub-tokens/subwords, which are identified with text being UNK
|
193
|
-
segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
|
194
|
-
|
195
|
-
# Attach the predicted tags to each token
|
196
|
-
list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
|
197
|
-
for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
|
198
|
-
|
199
|
-
# We are only interested in the tagged tokens, we do no longer need raw model predictions
|
200
|
-
tagged_segment = [t for t, _ in segment_pred]
|
201
|
-
tagged_segments.append(tagged_segment)
|
202
|
-
|
203
|
-
return tagged_segments
|
1
|
+
import os
|
2
|
+
import logging
|
3
|
+
import torch
|
4
|
+
import numpy as np
|
5
|
+
from sinatools.ner.trainers import BaseTrainer
|
6
|
+
from sinatools.ner.metrics import compute_nested_metrics
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
class BertNestedTrainer(BaseTrainer):
|
12
|
+
def __init__(self, **kwargs):
|
13
|
+
super().__init__(**kwargs)
|
14
|
+
|
15
|
+
def train(self):
|
16
|
+
best_val_loss, test_loss = np.inf, np.inf
|
17
|
+
num_train_batch = len(self.train_dataloader)
|
18
|
+
num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
|
19
|
+
patience = self.patience
|
20
|
+
|
21
|
+
for epoch_index in range(self.max_epochs):
|
22
|
+
self.current_epoch = epoch_index
|
23
|
+
train_loss = 0
|
24
|
+
|
25
|
+
for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
|
26
|
+
self.train_dataloader, is_train=True
|
27
|
+
), 1):
|
28
|
+
self.current_timestep += 1
|
29
|
+
|
30
|
+
# Compute loses for each output
|
31
|
+
# logits = B x T x L x C
|
32
|
+
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
33
|
+
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
34
|
+
for i, l in enumerate(num_labels)]
|
35
|
+
|
36
|
+
torch.autograd.backward(losses)
|
37
|
+
|
38
|
+
# Avoid exploding gradient by doing gradient clipping
|
39
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
|
40
|
+
|
41
|
+
self.optimizer.step()
|
42
|
+
self.scheduler.step()
|
43
|
+
batch_loss = sum(l.item() for l in losses)
|
44
|
+
train_loss += batch_loss
|
45
|
+
|
46
|
+
if self.current_timestep % self.log_interval == 0:
|
47
|
+
logger.info(
|
48
|
+
"Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
|
49
|
+
epoch_index,
|
50
|
+
batch_index,
|
51
|
+
num_train_batch,
|
52
|
+
self.current_timestep,
|
53
|
+
self.optimizer.param_groups[0]['lr'],
|
54
|
+
batch_loss
|
55
|
+
)
|
56
|
+
|
57
|
+
train_loss /= num_train_batch
|
58
|
+
|
59
|
+
logger.info("** Evaluating on validation dataset **")
|
60
|
+
val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
|
61
|
+
val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
|
62
|
+
|
63
|
+
epoch_summary_loss = {
|
64
|
+
"train_loss": train_loss,
|
65
|
+
"val_loss": val_loss
|
66
|
+
}
|
67
|
+
epoch_summary_metrics = {
|
68
|
+
"val_micro_f1": val_metrics.micro_f1,
|
69
|
+
"val_precision": val_metrics.precision,
|
70
|
+
"val_recall": val_metrics.recall
|
71
|
+
}
|
72
|
+
|
73
|
+
logger.info(
|
74
|
+
"Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
|
75
|
+
epoch_index,
|
76
|
+
self.current_timestep,
|
77
|
+
train_loss,
|
78
|
+
val_loss,
|
79
|
+
val_metrics.micro_f1
|
80
|
+
)
|
81
|
+
|
82
|
+
if val_loss < best_val_loss:
|
83
|
+
patience = self.patience
|
84
|
+
best_val_loss = val_loss
|
85
|
+
logger.info("** Validation improved, evaluating test data **")
|
86
|
+
test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
|
87
|
+
self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
|
88
|
+
test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
|
89
|
+
|
90
|
+
epoch_summary_loss["test_loss"] = test_loss
|
91
|
+
epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
|
92
|
+
epoch_summary_metrics["test_precision"] = test_metrics.precision
|
93
|
+
epoch_summary_metrics["test_recall"] = test_metrics.recall
|
94
|
+
|
95
|
+
logger.info(
|
96
|
+
f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
|
97
|
+
epoch_index,
|
98
|
+
self.current_timestep,
|
99
|
+
test_loss,
|
100
|
+
test_metrics.micro_f1
|
101
|
+
)
|
102
|
+
|
103
|
+
self.save()
|
104
|
+
else:
|
105
|
+
patience -= 1
|
106
|
+
|
107
|
+
# No improvements, terminating early
|
108
|
+
if patience == 0:
|
109
|
+
logger.info("Early termination triggered")
|
110
|
+
break
|
111
|
+
|
112
|
+
self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
|
113
|
+
self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
|
114
|
+
|
115
|
+
def tag(self, dataloader, is_train=True):
|
116
|
+
"""
|
117
|
+
Given a dataloader containing segments, predict the tags
|
118
|
+
:param dataloader: torch.utils.data.DataLoader
|
119
|
+
:param is_train: boolean - True for training model, False for evaluation
|
120
|
+
:return: Iterator
|
121
|
+
subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
|
122
|
+
gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
|
123
|
+
tokens - List[arabiner.data.dataset.Token] - list of tokens
|
124
|
+
valid_len (B x 1) - int - valiud length of each sequence
|
125
|
+
logits (B x T x NUM_LABELS) - logits for each token and each tag
|
126
|
+
"""
|
127
|
+
for subwords, gold_tags, tokens, mask, valid_len in dataloader:
|
128
|
+
self.model.train(is_train)
|
129
|
+
|
130
|
+
if torch.cuda.is_available():
|
131
|
+
subwords = subwords.cuda()
|
132
|
+
gold_tags = gold_tags.cuda()
|
133
|
+
|
134
|
+
if is_train:
|
135
|
+
self.optimizer.zero_grad()
|
136
|
+
logits = self.model(subwords)
|
137
|
+
else:
|
138
|
+
with torch.no_grad():
|
139
|
+
logits = self.model(subwords)
|
140
|
+
|
141
|
+
yield subwords, gold_tags, tokens, valid_len, logits
|
142
|
+
|
143
|
+
def eval(self, dataloader):
|
144
|
+
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
145
|
+
num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
|
146
|
+
loss = 0
|
147
|
+
|
148
|
+
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
149
|
+
dataloader, is_train=False
|
150
|
+
):
|
151
|
+
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
152
|
+
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
153
|
+
for i, l in enumerate(num_labels)]
|
154
|
+
loss += sum(losses)
|
155
|
+
preds += torch.argmax(logits, dim=3)
|
156
|
+
segments += tokens
|
157
|
+
valid_lens += list(valid_len)
|
158
|
+
|
159
|
+
loss /= len(dataloader)
|
160
|
+
|
161
|
+
# Update segments, attach predicted tags to each token
|
162
|
+
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
163
|
+
|
164
|
+
return preds, segments, valid_lens, loss
|
165
|
+
|
166
|
+
def infer(self, dataloader):
|
167
|
+
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
168
|
+
|
169
|
+
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
170
|
+
dataloader, is_train=False
|
171
|
+
):
|
172
|
+
preds += torch.argmax(logits, dim=3)
|
173
|
+
segments += tokens
|
174
|
+
valid_lens += list(valid_len)
|
175
|
+
|
176
|
+
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
177
|
+
return segments
|
178
|
+
|
179
|
+
def to_segments(self, segments, preds, valid_lens, vocab):
|
180
|
+
if vocab is None:
|
181
|
+
vocab = self.vocab
|
182
|
+
|
183
|
+
tagged_segments = list()
|
184
|
+
tokens_stoi = vocab.tokens.get_stoi()
|
185
|
+
unk_id = tokens_stoi["UNK"]
|
186
|
+
|
187
|
+
for segment, pred, valid_len in zip(segments, preds, valid_lens):
|
188
|
+
# First, the token at 0th index [CLS] and token at nth index [SEP]
|
189
|
+
# Combine the tokens with their corresponding predictions
|
190
|
+
segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
|
191
|
+
|
192
|
+
# Ignore the sub-tokens/subwords, which are identified with text being UNK
|
193
|
+
segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
|
194
|
+
|
195
|
+
# Attach the predicted tags to each token
|
196
|
+
list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
|
197
|
+
for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
|
198
|
+
|
199
|
+
# We are only interested in the tagged tokens, we do no longer need raw model predictions
|
200
|
+
tagged_segment = [t for t, _ in segment_pred]
|
201
|
+
tagged_segments.append(tagged_segment)
|
202
|
+
|
203
|
+
return tagged_segments
|