SinaTools-0.1.41-py2.py3-none-any.whl → SinaTools-1.0.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {SinaTools-0.1.41.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
  2. SinaTools-1.0.1.dist-info/RECORD +73 -0
  3. sinatools/VERSION +1 -1
  4. sinatools/ner/trainers/BertNestedTrainer.py +203 -203
  5. sinatools/ner/trainers/BertTrainer.py +163 -163
  6. sinatools/ner/trainers/__init__.py +2 -2
  7. SinaTools-0.1.41.dist-info/RECORD +0 -123
  8. sinatools/arabert/arabert/__init__.py +0 -14
  9. sinatools/arabert/arabert/create_classification_data.py +0 -260
  10. sinatools/arabert/arabert/create_pretraining_data.py +0 -534
  11. sinatools/arabert/arabert/extract_features.py +0 -444
  12. sinatools/arabert/arabert/lamb_optimizer.py +0 -158
  13. sinatools/arabert/arabert/modeling.py +0 -1027
  14. sinatools/arabert/arabert/optimization.py +0 -202
  15. sinatools/arabert/arabert/run_classifier.py +0 -1078
  16. sinatools/arabert/arabert/run_pretraining.py +0 -593
  17. sinatools/arabert/arabert/run_squad.py +0 -1440
  18. sinatools/arabert/arabert/tokenization.py +0 -414
  19. sinatools/arabert/araelectra/__init__.py +0 -1
  20. sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
  21. sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
  22. sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
  23. sinatools/arabert/araelectra/configure_finetuning.py +0 -172
  24. sinatools/arabert/araelectra/configure_pretraining.py +0 -143
  25. sinatools/arabert/araelectra/finetune/__init__.py +0 -14
  26. sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
  27. sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
  28. sinatools/arabert/araelectra/finetune/scorer.py +0 -54
  29. sinatools/arabert/araelectra/finetune/task.py +0 -74
  30. sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
  31. sinatools/arabert/araelectra/flops_computation.py +0 -215
  32. sinatools/arabert/araelectra/model/__init__.py +0 -14
  33. sinatools/arabert/araelectra/model/modeling.py +0 -1029
  34. sinatools/arabert/araelectra/model/optimization.py +0 -193
  35. sinatools/arabert/araelectra/model/tokenization.py +0 -355
  36. sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
  37. sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
  38. sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
  39. sinatools/arabert/araelectra/run_finetuning.py +0 -323
  40. sinatools/arabert/araelectra/run_pretraining.py +0 -469
  41. sinatools/arabert/araelectra/util/__init__.py +0 -14
  42. sinatools/arabert/araelectra/util/training_utils.py +0 -112
  43. sinatools/arabert/araelectra/util/utils.py +0 -109
  44. sinatools/arabert/aragpt2/__init__.py +0 -2
  45. sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
  46. sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
  47. sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
  48. sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
  49. sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
  50. sinatools/arabert/aragpt2/grover/__init__.py +0 -0
  51. sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
  52. sinatools/arabert/aragpt2/grover/modeling.py +0 -803
  53. sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
  54. sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
  55. sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
  56. sinatools/arabert/aragpt2/grover/utils.py +0 -234
  57. sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
  58. {SinaTools-0.1.41.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
  59. {SinaTools-0.1.41.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
  60. {SinaTools-0.1.41.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
  61. {SinaTools-0.1.41.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
  62. {SinaTools-0.1.41.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
  63. {SinaTools-0.1.41.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
{SinaTools-0.1.41.dist-info → SinaTools-1.0.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: SinaTools
- Version: 0.1.41
+ Version: 1.0.1
  Summary: Open-source Python toolkit for Arabic Natural Understanding, allowing people to integrate it in their system workflow.
  Home-page: https://github.com/SinaLab/sinatools
  License: MIT license
SinaTools-1.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,73 @@
+ SinaTools-1.0.1.data/data/sinatools/environment.yml,sha256=i0UFZc-vwU9ZwnI8hBdz7vi-x22vG-HR8ojWBUAOkno,5422
+ sinatools/VERSION,sha256=1R5uyUBYVUqEVYpbQC7m71_fVFXjXJAv7aYc2odSlDo,5
+ sinatools/__init__.py,sha256=bEosTU1o-FSpyytS6iVP_82BXHF2yHnzpJxPLYRbeII,135
+ sinatools/environment.yml,sha256=i0UFZc-vwU9ZwnI8hBdz7vi-x22vG-HR8ojWBUAOkno,5422
+ sinatools/install_env.py,sha256=EODeeE0ZzfM_rz33_JSIruX03Nc4ghyVOM5BHVhsZaQ,404
+ sinatools/sinatools.py,sha256=vR5AaF0iel21LvsdcqwheoBz0SIj9K9I_Ub8M8oA98Y,20
+ sinatools/CLI/DataDownload/download_files.py,sha256=EezvbukR3pZ8s6mGZnzTcjsbo3CBDlC0g6KhJWlYp1w,2686
+ sinatools/CLI/morphology/ALMA_multi_word.py,sha256=rmpa72twwIJHme_kpQ1lu3_7y_Jorj70QTvOnQMJRuI,1274
+ sinatools/CLI/morphology/morph_analyzer.py,sha256=HPamEKos_JRYCJv_2q6c12N--da58_JXTno9haww5Ao,3497
+ sinatools/CLI/ner/corpus_entity_extractor.py,sha256=DdvigsDQzko5nJBjzUXlIDqoBMBTVzktjSo7JfEXTIA,4778
+ sinatools/CLI/ner/entity_extractor.py,sha256=G9j-t0WKm2CRORhqARJM-pI-KArQ2IXIvnBK_NHxlHs,2885
+ sinatools/CLI/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sinatools/CLI/utils/arStrip.py,sha256=NLyp8vOu2xv80tL9jiKRvyptmbkRZVg-wcAr-9YyvNY,3264
+ sinatools/CLI/utils/corpus_tokenizer.py,sha256=nH0T4h6urr_0Qy6-wN3PquOtnwybj0REde5Ts_OE4U8,1650
+ sinatools/CLI/utils/implication.py,sha256=AojpkCwUQJiQjxhyEUWKRHmBnIt1tVqr485cAF7Thq0,2857
+ sinatools/CLI/utils/jaccard.py,sha256=w56N_cNEFJ0A7WtunmY_xtms4srFagKBzrW_0YhH2DE,4216
+ sinatools/CLI/utils/remove_latin.py,sha256=NOaTm2RHxt5IQrV98ySTmD8rTXTmcqSmfbPAwTyaXqU,848
+ sinatools/CLI/utils/remove_punctuation.py,sha256=vJAZlEn7WGftZAFVFYnddkRrxdJ_rMmKB9vFZkY-jN4,1097
+ sinatools/CLI/utils/sentence_tokenizer.py,sha256=Wli8eiDbWSd_Z8UKpu_JkaS8jImowa1vnRL0oYCSfqw,2823
+ sinatools/CLI/utils/text_dublication_detector.py,sha256=dW70O5O20GxeUDDF6zVYn52wWLmJF-HBZgvqIeVL2rQ,1661
+ sinatools/CLI/utils/text_transliteration.py,sha256=vz-3kxWf8pNYVCqNAtBAiA6u_efrS5NtWT-ofN1NX6I,2014
+ sinatools/DataDownload/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sinatools/DataDownload/downloader.py,sha256=VdUNgSqMKz1J-DuQD_eS1U2KWqEpy94WlSJ0pPODLig,7833
+ sinatools/arabert/__init__.py,sha256=ely2PttjgSv7vKdzskuD1rtK_l_UOpmxJSz8isrveD0,16
+ sinatools/arabert/preprocess.py,sha256=qI0FsuMTOzdRlYGCtLrjpXgikNElUZPv9bnjaKDZKJ4,33024
+ sinatools/morphology/ALMA_multi_word.py,sha256=hj_-8ojrYYHnfCGk8WKtJdUR8mauzQdma4WUm-okDps,1346
+ sinatools/morphology/__init__.py,sha256=I4wVBh8BhyNl-CySVdiI_nUSn6gj1j-gmLKP300RpE0,1216
+ sinatools/morphology/morph_analyzer.py,sha256=JOH2UWKNQWo5UzpWNzP9R1D3B3qLSogIiMp8n0N_56o,7177
+ sinatools/ner/__init__.py,sha256=59kLMX6UQhF6JpE10RhaDYC3a2_jiWOIVPuejsoflFE,1050
+ sinatools/ner/data_format.py,sha256=VmFshZbEPOsWxsb4tgSkwvbM1k7yCce4kmtPkCiWgwM,4513
+ sinatools/ner/datasets.py,sha256=mG1iwqSm3lXCFHLqE-b4wNi176cpuzNBz8tKaBU6z6M,5059
+ sinatools/ner/entity_extractor.py,sha256=O2epRwRFUUcQs3SnFIYHVBI4zVhr8hRcj0XJYeby4ts,3588
+ sinatools/ner/helpers.py,sha256=sX6ezVbuVQxk_xJqZwhUzJVFVuVmFGmei_kd6r3sPHE,3652
+ sinatools/ner/metrics.py,sha256=Irz6SsIvpOzGIA2lWxrEV86xnTnm0TzKm9SUVT4SXUU,2734
+ sinatools/ner/transforms.py,sha256=vti3mDdi-IRP8i0aTQ37QqpPlP9hdMmJ6_bAMa0uL-s,4871
+ sinatools/ner/data/__init__.py,sha256=W0C1ge_XxTfmdEGz0hkclz57aLI5VFS5t6BjByCfkFk,57
+ sinatools/ner/data/datasets.py,sha256=_uUlvBAhnTtPwKLj0wIbmB04VCBidfwffxKorLGHq_g,5134
+ sinatools/ner/data/transforms.py,sha256=URMz1dHzkHjgUGAkDOenCWvQThO1ha8XeQVjoLL9RXM,4874
+ sinatools/ner/nn/BaseModel.py,sha256=3GmujQasTZZunOBuFXpY2p1W8W256iI_Uu4hxhOY2Z0,608
+ sinatools/ner/nn/BertNestedTagger.py,sha256=_fwAn1kiKmXe6m5y16Ipty3kvXIEFEmiUq74Ad1818U,1219
+ sinatools/ner/nn/BertSeqTagger.py,sha256=dFcBBiMw2QCWsyy7aQDe_PS3aRuNn4DOxKIHgTblFvc,504
+ sinatools/ner/nn/__init__.py,sha256=UgQD_XLNzQGBNSYc_Bw1aRJZjq4PJsnMT1iZwnJemqE,170
+ sinatools/ner/trainers/BaseTrainer.py,sha256=Uar8HxtgBXCVhKa85sEN622d9P7JiFBcWfs46uRG4aA,4068
+ sinatools/ner/trainers/BertNestedTrainer.py,sha256=Pb4O2WeBmTvV3hHMT6DXjxrTzgtuh3OrKQZnogYy8RQ,8429
+ sinatools/ner/trainers/BertTrainer.py,sha256=B_uVtUwfv_eFwMMPsKQvZgW_ZNLy6XEsX5ePR0s8d-k,6433
+ sinatools/ner/trainers/__init__.py,sha256=UDok8pDDpYOpwRBBKVLKaOgSUlmqqb-zHZI1p0xPxzI,188
+ sinatools/relations/__init__.py,sha256=cYjsP2mlTYvAwVIEFtgA6i9gLUSkGVOuDggMs7TvG5k,272
+ sinatools/relations/relation_extractor.py,sha256=UuDlaaR0ch9BFv4sBF1tr7P-P9xq8oRZF41tAze6_ok,9751
+ sinatools/semantic_relatedness/__init__.py,sha256=S0xrmqtl72L02N56nbNMudPoebnYQgsaIyyX-587DsU,830
+ sinatools/semantic_relatedness/compute_relatedness.py,sha256=_9HFPs3nQBLklHFfkc9o3gEjEI6Bd34Ha4E1Kvv1RIg,2256
+ sinatools/synonyms/__init__.py,sha256=yMuphNZrm5XLOR2T0weOHcUysJm-JKHUmVLoLQO8390,548
+ sinatools/synonyms/synonyms_generator.py,sha256=jRd0D3_kn-jYBaZzqY-7oOy0SFjSJ-mjM7JhsySzX58,9037
+ sinatools/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sinatools/utils/charsets.py,sha256=rs82oZJqRqosZdTKXfFAJfJ5t4PxjMM_oAPsiWSWuwU,2817
+ sinatools/utils/parser.py,sha256=qvHdln5R5CAv_0UOJWe0mcp8JCsGqgazoeIIkoALH88,6259
+ sinatools/utils/readfile.py,sha256=xE4LEaCqXJIk9v37QUSSmWb-aY3UnCFUNb7uVdx3cpM,133
+ sinatools/utils/similarity.py,sha256=HAK6OmyVnfjPm0GWL3z9s4ZoUwpZHVKxt3CeSMfqLIQ,11990
+ sinatools/utils/text_dublication_detector.py,sha256=FeSkbfWGMQluz23H4CBHXION-walZPgjueX6AL8u_Q0,5660
+ sinatools/utils/text_transliteration.py,sha256=F3smhr2AEJtySE6wGQsiXXOslTvSDzLivTYu0btgc10,8769
+ sinatools/utils/tokenizer.py,sha256=nyk6lh5-p38wrU62hvh4wg7ni9ammkdqqIgcjbbBxxo,6965
+ sinatools/utils/tokenizers_words.py,sha256=efNfOil9qDNVJ9yynk_8sqf65PsL-xtsHG7y2SZCkjQ,656
+ sinatools/utils/word_compare.py,sha256=rS2Z74sf7R-7MTXyrFj5miRi2TnSG9OdTDp_qQYuo2Y,28200
+ sinatools/wsd/__init__.py,sha256=mwmCUurOV42rsNRpIUP3luG0oEzeTfEx3oeDl93Oif8,306
+ sinatools/wsd/disambiguator.py,sha256=h-3idc5rPPbMDSE_QVJAsEVkDHwzYY3L2SEPNXIdOcc,20104
+ sinatools/wsd/settings.py,sha256=6XflVTFKD8SVySX9Wj7zYQtV26WDTcQ2-uW8-gDNHKE,747
+ sinatools/wsd/wsd.py,sha256=gHIBUFXegoY1z3rRnIlK6TduhYq2BTa_dHakOjOlT4k,4434
+ SinaTools-1.0.1.dist-info/AUTHORS.rst,sha256=aTWeWlIdfLi56iLJfIUAwIrmqDcgxXKLji75_Fjzjyg,174
+ SinaTools-1.0.1.dist-info/LICENSE,sha256=uwsKYG4TayHXNANWdpfMN2lVW4dimxQjA_7vuCVhD70,1088
+ SinaTools-1.0.1.dist-info/METADATA,sha256=8EnFO3dSqtJ8JJ4r_-ji5tX_h04_vNTnPvfubqceaQ4,3409
+ SinaTools-1.0.1.dist-info/WHEEL,sha256=9Hm2OB-j1QcCUq9Jguht7ayGIIZBRTdOXD1qg9cCgPM,109
+ SinaTools-1.0.1.dist-info/entry_points.txt,sha256=_CsRKM_tSCWV5hefBNUsWf9_6DrJnzFlxeAo1wm5XqY,1302
+ SinaTools-1.0.1.dist-info/top_level.txt,sha256=8tNdPTeJKw3TQCaua8IJIx6N6WpgZZmVekf1OdBNJpE,10
+ SinaTools-1.0.1.dist-info/RECORD,,
sinatools/VERSION CHANGED
@@ -1 +1 @@
- 0.1.41
+ 1.0.1
sinatools/ner/trainers/BertNestedTrainer.py CHANGED
@@ -1,203 +1,203 @@
- import os
- import logging
- import torch
- import numpy as np
- from sinatools.ner.trainers import BaseTrainer
- from sinatools.ner.metrics import compute_nested_metrics
-
- logger = logging.getLogger(__name__)
-
-
- class BertNestedTrainer(BaseTrainer):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-
-     def train(self):
-         best_val_loss, test_loss = np.inf, np.inf
-         num_train_batch = len(self.train_dataloader)
-         num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
-         patience = self.patience
-
-         for epoch_index in range(self.max_epochs):
-             self.current_epoch = epoch_index
-             train_loss = 0
-
-             for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
-                 self.train_dataloader, is_train=True
-             ), 1):
-                 self.current_timestep += 1
-
-                 # Compute loses for each output
-                 # logits = B x T x L x C
-                 losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
-                                     torch.reshape(gold_tags[:, i, :], (-1,)).long())
-                           for i, l in enumerate(num_labels)]
-
-                 torch.autograd.backward(losses)
-
-                 # Avoid exploding gradient by doing gradient clipping
-                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
-
-                 self.optimizer.step()
-                 self.scheduler.step()
-                 batch_loss = sum(l.item() for l in losses)
-                 train_loss += batch_loss
-
-                 if self.current_timestep % self.log_interval == 0:
-                     logger.info(
-                         "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
-                         epoch_index,
-                         batch_index,
-                         num_train_batch,
-                         self.current_timestep,
-                         self.optimizer.param_groups[0]['lr'],
-                         batch_loss
-                     )
-
-             train_loss /= num_train_batch
-
-             logger.info("** Evaluating on validation dataset **")
-             val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
-             val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
-
-             epoch_summary_loss = {
-                 "train_loss": train_loss,
-                 "val_loss": val_loss
-             }
-             epoch_summary_metrics = {
-                 "val_micro_f1": val_metrics.micro_f1,
-                 "val_precision": val_metrics.precision,
-                 "val_recall": val_metrics.recall
-             }
-
-             logger.info(
-                 "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
-                 epoch_index,
-                 self.current_timestep,
-                 train_loss,
-                 val_loss,
-                 val_metrics.micro_f1
-             )
-
-             if val_loss < best_val_loss:
-                 patience = self.patience
-                 best_val_loss = val_loss
-                 logger.info("** Validation improved, evaluating test data **")
-                 test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
-                 self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
-                 test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
-
-                 epoch_summary_loss["test_loss"] = test_loss
-                 epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
-                 epoch_summary_metrics["test_precision"] = test_metrics.precision
-                 epoch_summary_metrics["test_recall"] = test_metrics.recall
-
-                 logger.info(
-                     f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
-                     epoch_index,
-                     self.current_timestep,
-                     test_loss,
-                     test_metrics.micro_f1
-                 )
-
-                 self.save()
-             else:
-                 patience -= 1
-
-                 # No improvements, terminating early
-                 if patience == 0:
-                     logger.info("Early termination triggered")
-                     break
-
-             self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
-             self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
-
-     def tag(self, dataloader, is_train=True):
-         """
-         Given a dataloader containing segments, predict the tags
-         :param dataloader: torch.utils.data.DataLoader
-         :param is_train: boolean - True for training model, False for evaluation
-         :return: Iterator
-             subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
-             gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
-             tokens - List[arabiner.data.dataset.Token] - list of tokens
-             valid_len (B x 1) - int - valiud length of each sequence
-             logits (B x T x NUM_LABELS) - logits for each token and each tag
-         """
-         for subwords, gold_tags, tokens, mask, valid_len in dataloader:
-             self.model.train(is_train)
-
-             if torch.cuda.is_available():
-                 subwords = subwords.cuda()
-                 gold_tags = gold_tags.cuda()
-
-             if is_train:
-                 self.optimizer.zero_grad()
-                 logits = self.model(subwords)
-             else:
-                 with torch.no_grad():
-                     logits = self.model(subwords)
-
-             yield subwords, gold_tags, tokens, valid_len, logits
-
-     def eval(self, dataloader):
-         golds, preds, segments, valid_lens = list(), list(), list(), list()
-         num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
-         loss = 0
-
-         for _, gold_tags, tokens, valid_len, logits in self.tag(
-             dataloader, is_train=False
-         ):
-             losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
-                                 torch.reshape(gold_tags[:, i, :], (-1,)).long())
-                       for i, l in enumerate(num_labels)]
-             loss += sum(losses)
-             preds += torch.argmax(logits, dim=3)
-             segments += tokens
-             valid_lens += list(valid_len)
-
-         loss /= len(dataloader)
-
-         # Update segments, attach predicted tags to each token
-         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
-
-         return preds, segments, valid_lens, loss
-
-     def infer(self, dataloader):
-         golds, preds, segments, valid_lens = list(), list(), list(), list()
-
-         for _, gold_tags, tokens, valid_len, logits in self.tag(
-             dataloader, is_train=False
-         ):
-             preds += torch.argmax(logits, dim=3)
-             segments += tokens
-             valid_lens += list(valid_len)
-
-         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
-         return segments
-
-     def to_segments(self, segments, preds, valid_lens, vocab):
-         if vocab is None:
-             vocab = self.vocab
-
-         tagged_segments = list()
-         tokens_stoi = vocab.tokens.get_stoi()
-         unk_id = tokens_stoi["UNK"]
-
-         for segment, pred, valid_len in zip(segments, preds, valid_lens):
-             # First, the token at 0th index [CLS] and token at nth index [SEP]
-             # Combine the tokens with their corresponding predictions
-             segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
-
-             # Ignore the sub-tokens/subwords, which are identified with text being UNK
-             segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
-
-             # Attach the predicted tags to each token
-             list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
-                                                           for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
-
-             # We are only interested in the tagged tokens, we do no longer need raw model predictions
-             tagged_segment = [t for t, _ in segment_pred]
-             tagged_segments.append(tagged_segment)
-
-         return tagged_segments
+ import os
+ import logging
+ import torch
+ import numpy as np
+ from sinatools.ner.trainers import BaseTrainer
+ from sinatools.ner.metrics import compute_nested_metrics
+
+ logger = logging.getLogger(__name__)
+
+
+ class BertNestedTrainer(BaseTrainer):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def train(self):
+         best_val_loss, test_loss = np.inf, np.inf
+         num_train_batch = len(self.train_dataloader)
+         num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
+         patience = self.patience
+
+         for epoch_index in range(self.max_epochs):
+             self.current_epoch = epoch_index
+             train_loss = 0
+
+             for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
+                 self.train_dataloader, is_train=True
+             ), 1):
+                 self.current_timestep += 1
+
+                 # Compute loses for each output
+                 # logits = B x T x L x C
+                 losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                     torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                           for i, l in enumerate(num_labels)]
+
+                 torch.autograd.backward(losses)
+
+                 # Avoid exploding gradient by doing gradient clipping
+                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+                 self.optimizer.step()
+                 self.scheduler.step()
+                 batch_loss = sum(l.item() for l in losses)
+                 train_loss += batch_loss
+
+                 if self.current_timestep % self.log_interval == 0:
+                     logger.info(
+                         "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
+                         epoch_index,
+                         batch_index,
+                         num_train_batch,
+                         self.current_timestep,
+                         self.optimizer.param_groups[0]['lr'],
+                         batch_loss
+                     )
+
+             train_loss /= num_train_batch
+
+             logger.info("** Evaluating on validation dataset **")
+             val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
+             val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
+
+             epoch_summary_loss = {
+                 "train_loss": train_loss,
+                 "val_loss": val_loss
+             }
+             epoch_summary_metrics = {
+                 "val_micro_f1": val_metrics.micro_f1,
+                 "val_precision": val_metrics.precision,
+                 "val_recall": val_metrics.recall
+             }
+
+             logger.info(
+                 "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
+                 epoch_index,
+                 self.current_timestep,
+                 train_loss,
+                 val_loss,
+                 val_metrics.micro_f1
+             )
+
+             if val_loss < best_val_loss:
+                 patience = self.patience
+                 best_val_loss = val_loss
+                 logger.info("** Validation improved, evaluating test data **")
+                 test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
+                 self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
+                 test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
+
+                 epoch_summary_loss["test_loss"] = test_loss
+                 epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
+                 epoch_summary_metrics["test_precision"] = test_metrics.precision
+                 epoch_summary_metrics["test_recall"] = test_metrics.recall
+
+                 logger.info(
+                     f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
+                     epoch_index,
+                     self.current_timestep,
+                     test_loss,
+                     test_metrics.micro_f1
+                 )
+
+                 self.save()
+             else:
+                 patience -= 1
+
+                 # No improvements, terminating early
+                 if patience == 0:
+                     logger.info("Early termination triggered")
+                     break
+
+             self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
+             self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
+
+     def tag(self, dataloader, is_train=True):
+         """
+         Given a dataloader containing segments, predict the tags
+         :param dataloader: torch.utils.data.DataLoader
+         :param is_train: boolean - True for training model, False for evaluation
+         :return: Iterator
+             subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
+             gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
+             tokens - List[arabiner.data.dataset.Token] - list of tokens
+             valid_len (B x 1) - int - valiud length of each sequence
+             logits (B x T x NUM_LABELS) - logits for each token and each tag
+         """
+         for subwords, gold_tags, tokens, mask, valid_len in dataloader:
+             self.model.train(is_train)
+
+             if torch.cuda.is_available():
+                 subwords = subwords.cuda()
+                 gold_tags = gold_tags.cuda()
+
+             if is_train:
+                 self.optimizer.zero_grad()
+                 logits = self.model(subwords)
+             else:
+                 with torch.no_grad():
+                     logits = self.model(subwords)
+
+             yield subwords, gold_tags, tokens, valid_len, logits
+
+     def eval(self, dataloader):
+         golds, preds, segments, valid_lens = list(), list(), list(), list()
+         num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
+         loss = 0
+
+         for _, gold_tags, tokens, valid_len, logits in self.tag(
+             dataloader, is_train=False
+         ):
+             losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                 torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                       for i, l in enumerate(num_labels)]
+             loss += sum(losses)
+             preds += torch.argmax(logits, dim=3)
+             segments += tokens
+             valid_lens += list(valid_len)
+
+         loss /= len(dataloader)
+
+         # Update segments, attach predicted tags to each token
+         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+
+         return preds, segments, valid_lens, loss
+
+     def infer(self, dataloader):
+         golds, preds, segments, valid_lens = list(), list(), list(), list()
+
+         for _, gold_tags, tokens, valid_len, logits in self.tag(
+             dataloader, is_train=False
+         ):
+             preds += torch.argmax(logits, dim=3)
+             segments += tokens
+             valid_lens += list(valid_len)
+
+         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+         return segments
+
+     def to_segments(self, segments, preds, valid_lens, vocab):
+         if vocab is None:
+             vocab = self.vocab
+
+         tagged_segments = list()
+         tokens_stoi = vocab.tokens.get_stoi()
+         unk_id = tokens_stoi["UNK"]
+
+         for segment, pred, valid_len in zip(segments, preds, valid_lens):
+             # First, the token at 0th index [CLS] and token at nth index [SEP]
+             # Combine the tokens with their corresponding predictions
+             segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
+
+             # Ignore the sub-tokens/subwords, which are identified with text being UNK
+             segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
+
+             # Attach the predicted tags to each token
+             list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
+                                                           for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
+
+             # We are only interested in the tagged tokens, we do no longer need raw model predictions
+             tagged_segment = [t for t, _ in segment_pred]
+             tagged_segments.append(tagged_segment)
+
+         return tagged_segments