SinaTools-0.1.40-py2.py3-none-any.whl → SinaTools-1.0.1-py2.py3-none-any.whl

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (64)
  1. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
  2. SinaTools-1.0.1.dist-info/RECORD +73 -0
  3. sinatools/VERSION +1 -1
  4. sinatools/ner/__init__.py +5 -7
  5. sinatools/ner/trainers/BertNestedTrainer.py +203 -203
  6. sinatools/ner/trainers/BertTrainer.py +163 -163
  7. sinatools/ner/trainers/__init__.py +2 -2
  8. SinaTools-0.1.40.dist-info/RECORD +0 -123
  9. sinatools/arabert/arabert/__init__.py +0 -14
  10. sinatools/arabert/arabert/create_classification_data.py +0 -260
  11. sinatools/arabert/arabert/create_pretraining_data.py +0 -534
  12. sinatools/arabert/arabert/extract_features.py +0 -444
  13. sinatools/arabert/arabert/lamb_optimizer.py +0 -158
  14. sinatools/arabert/arabert/modeling.py +0 -1027
  15. sinatools/arabert/arabert/optimization.py +0 -202
  16. sinatools/arabert/arabert/run_classifier.py +0 -1078
  17. sinatools/arabert/arabert/run_pretraining.py +0 -593
  18. sinatools/arabert/arabert/run_squad.py +0 -1440
  19. sinatools/arabert/arabert/tokenization.py +0 -414
  20. sinatools/arabert/araelectra/__init__.py +0 -1
  21. sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
  22. sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
  23. sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
  24. sinatools/arabert/araelectra/configure_finetuning.py +0 -172
  25. sinatools/arabert/araelectra/configure_pretraining.py +0 -143
  26. sinatools/arabert/araelectra/finetune/__init__.py +0 -14
  27. sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
  28. sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
  29. sinatools/arabert/araelectra/finetune/scorer.py +0 -54
  30. sinatools/arabert/araelectra/finetune/task.py +0 -74
  31. sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
  32. sinatools/arabert/araelectra/flops_computation.py +0 -215
  33. sinatools/arabert/araelectra/model/__init__.py +0 -14
  34. sinatools/arabert/araelectra/model/modeling.py +0 -1029
  35. sinatools/arabert/araelectra/model/optimization.py +0 -193
  36. sinatools/arabert/araelectra/model/tokenization.py +0 -355
  37. sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
  38. sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
  39. sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
  40. sinatools/arabert/araelectra/run_finetuning.py +0 -323
  41. sinatools/arabert/araelectra/run_pretraining.py +0 -469
  42. sinatools/arabert/araelectra/util/__init__.py +0 -14
  43. sinatools/arabert/araelectra/util/training_utils.py +0 -112
  44. sinatools/arabert/araelectra/util/utils.py +0 -109
  45. sinatools/arabert/aragpt2/__init__.py +0 -2
  46. sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
  47. sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
  48. sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
  49. sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
  50. sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
  51. sinatools/arabert/aragpt2/grover/__init__.py +0 -0
  52. sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
  53. sinatools/arabert/aragpt2/grover/modeling.py +0 -803
  54. sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
  55. sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
  56. sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
  57. sinatools/arabert/aragpt2/grover/utils.py +0 -234
  58. sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
  59. {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
  60. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
  61. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
  62. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
  63. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
  64. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
{SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: SinaTools
- Version: 0.1.40
+ Version: 1.0.1
  Summary: Open-source Python toolkit for Arabic Natural Understanding, allowing people to integrate it in their system workflow.
  Home-page: https://github.com/SinaLab/sinatools
  License: MIT license
SinaTools-1.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,73 @@
+ SinaTools-1.0.1.data/data/sinatools/environment.yml,sha256=i0UFZc-vwU9ZwnI8hBdz7vi-x22vG-HR8ojWBUAOkno,5422
+ sinatools/VERSION,sha256=1R5uyUBYVUqEVYpbQC7m71_fVFXjXJAv7aYc2odSlDo,5
+ sinatools/__init__.py,sha256=bEosTU1o-FSpyytS6iVP_82BXHF2yHnzpJxPLYRbeII,135
+ sinatools/environment.yml,sha256=i0UFZc-vwU9ZwnI8hBdz7vi-x22vG-HR8ojWBUAOkno,5422
+ sinatools/install_env.py,sha256=EODeeE0ZzfM_rz33_JSIruX03Nc4ghyVOM5BHVhsZaQ,404
+ sinatools/sinatools.py,sha256=vR5AaF0iel21LvsdcqwheoBz0SIj9K9I_Ub8M8oA98Y,20
+ sinatools/CLI/DataDownload/download_files.py,sha256=EezvbukR3pZ8s6mGZnzTcjsbo3CBDlC0g6KhJWlYp1w,2686
+ sinatools/CLI/morphology/ALMA_multi_word.py,sha256=rmpa72twwIJHme_kpQ1lu3_7y_Jorj70QTvOnQMJRuI,1274
+ sinatools/CLI/morphology/morph_analyzer.py,sha256=HPamEKos_JRYCJv_2q6c12N--da58_JXTno9haww5Ao,3497
+ sinatools/CLI/ner/corpus_entity_extractor.py,sha256=DdvigsDQzko5nJBjzUXlIDqoBMBTVzktjSo7JfEXTIA,4778
+ sinatools/CLI/ner/entity_extractor.py,sha256=G9j-t0WKm2CRORhqARJM-pI-KArQ2IXIvnBK_NHxlHs,2885
+ sinatools/CLI/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sinatools/CLI/utils/arStrip.py,sha256=NLyp8vOu2xv80tL9jiKRvyptmbkRZVg-wcAr-9YyvNY,3264
+ sinatools/CLI/utils/corpus_tokenizer.py,sha256=nH0T4h6urr_0Qy6-wN3PquOtnwybj0REde5Ts_OE4U8,1650
+ sinatools/CLI/utils/implication.py,sha256=AojpkCwUQJiQjxhyEUWKRHmBnIt1tVqr485cAF7Thq0,2857
+ sinatools/CLI/utils/jaccard.py,sha256=w56N_cNEFJ0A7WtunmY_xtms4srFagKBzrW_0YhH2DE,4216
+ sinatools/CLI/utils/remove_latin.py,sha256=NOaTm2RHxt5IQrV98ySTmD8rTXTmcqSmfbPAwTyaXqU,848
+ sinatools/CLI/utils/remove_punctuation.py,sha256=vJAZlEn7WGftZAFVFYnddkRrxdJ_rMmKB9vFZkY-jN4,1097
+ sinatools/CLI/utils/sentence_tokenizer.py,sha256=Wli8eiDbWSd_Z8UKpu_JkaS8jImowa1vnRL0oYCSfqw,2823
+ sinatools/CLI/utils/text_dublication_detector.py,sha256=dW70O5O20GxeUDDF6zVYn52wWLmJF-HBZgvqIeVL2rQ,1661
+ sinatools/CLI/utils/text_transliteration.py,sha256=vz-3kxWf8pNYVCqNAtBAiA6u_efrS5NtWT-ofN1NX6I,2014
+ sinatools/DataDownload/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sinatools/DataDownload/downloader.py,sha256=VdUNgSqMKz1J-DuQD_eS1U2KWqEpy94WlSJ0pPODLig,7833
+ sinatools/arabert/__init__.py,sha256=ely2PttjgSv7vKdzskuD1rtK_l_UOpmxJSz8isrveD0,16
+ sinatools/arabert/preprocess.py,sha256=qI0FsuMTOzdRlYGCtLrjpXgikNElUZPv9bnjaKDZKJ4,33024
+ sinatools/morphology/ALMA_multi_word.py,sha256=hj_-8ojrYYHnfCGk8WKtJdUR8mauzQdma4WUm-okDps,1346
+ sinatools/morphology/__init__.py,sha256=I4wVBh8BhyNl-CySVdiI_nUSn6gj1j-gmLKP300RpE0,1216
+ sinatools/morphology/morph_analyzer.py,sha256=JOH2UWKNQWo5UzpWNzP9R1D3B3qLSogIiMp8n0N_56o,7177
+ sinatools/ner/__init__.py,sha256=59kLMX6UQhF6JpE10RhaDYC3a2_jiWOIVPuejsoflFE,1050
+ sinatools/ner/data_format.py,sha256=VmFshZbEPOsWxsb4tgSkwvbM1k7yCce4kmtPkCiWgwM,4513
+ sinatools/ner/datasets.py,sha256=mG1iwqSm3lXCFHLqE-b4wNi176cpuzNBz8tKaBU6z6M,5059
+ sinatools/ner/entity_extractor.py,sha256=O2epRwRFUUcQs3SnFIYHVBI4zVhr8hRcj0XJYeby4ts,3588
+ sinatools/ner/helpers.py,sha256=sX6ezVbuVQxk_xJqZwhUzJVFVuVmFGmei_kd6r3sPHE,3652
+ sinatools/ner/metrics.py,sha256=Irz6SsIvpOzGIA2lWxrEV86xnTnm0TzKm9SUVT4SXUU,2734
+ sinatools/ner/transforms.py,sha256=vti3mDdi-IRP8i0aTQ37QqpPlP9hdMmJ6_bAMa0uL-s,4871
+ sinatools/ner/data/__init__.py,sha256=W0C1ge_XxTfmdEGz0hkclz57aLI5VFS5t6BjByCfkFk,57
+ sinatools/ner/data/datasets.py,sha256=_uUlvBAhnTtPwKLj0wIbmB04VCBidfwffxKorLGHq_g,5134
+ sinatools/ner/data/transforms.py,sha256=URMz1dHzkHjgUGAkDOenCWvQThO1ha8XeQVjoLL9RXM,4874
+ sinatools/ner/nn/BaseModel.py,sha256=3GmujQasTZZunOBuFXpY2p1W8W256iI_Uu4hxhOY2Z0,608
+ sinatools/ner/nn/BertNestedTagger.py,sha256=_fwAn1kiKmXe6m5y16Ipty3kvXIEFEmiUq74Ad1818U,1219
+ sinatools/ner/nn/BertSeqTagger.py,sha256=dFcBBiMw2QCWsyy7aQDe_PS3aRuNn4DOxKIHgTblFvc,504
+ sinatools/ner/nn/__init__.py,sha256=UgQD_XLNzQGBNSYc_Bw1aRJZjq4PJsnMT1iZwnJemqE,170
+ sinatools/ner/trainers/BaseTrainer.py,sha256=Uar8HxtgBXCVhKa85sEN622d9P7JiFBcWfs46uRG4aA,4068
+ sinatools/ner/trainers/BertNestedTrainer.py,sha256=Pb4O2WeBmTvV3hHMT6DXjxrTzgtuh3OrKQZnogYy8RQ,8429
+ sinatools/ner/trainers/BertTrainer.py,sha256=B_uVtUwfv_eFwMMPsKQvZgW_ZNLy6XEsX5ePR0s8d-k,6433
+ sinatools/ner/trainers/__init__.py,sha256=UDok8pDDpYOpwRBBKVLKaOgSUlmqqb-zHZI1p0xPxzI,188
+ sinatools/relations/__init__.py,sha256=cYjsP2mlTYvAwVIEFtgA6i9gLUSkGVOuDggMs7TvG5k,272
+ sinatools/relations/relation_extractor.py,sha256=UuDlaaR0ch9BFv4sBF1tr7P-P9xq8oRZF41tAze6_ok,9751
+ sinatools/semantic_relatedness/__init__.py,sha256=S0xrmqtl72L02N56nbNMudPoebnYQgsaIyyX-587DsU,830
+ sinatools/semantic_relatedness/compute_relatedness.py,sha256=_9HFPs3nQBLklHFfkc9o3gEjEI6Bd34Ha4E1Kvv1RIg,2256
+ sinatools/synonyms/__init__.py,sha256=yMuphNZrm5XLOR2T0weOHcUysJm-JKHUmVLoLQO8390,548
+ sinatools/synonyms/synonyms_generator.py,sha256=jRd0D3_kn-jYBaZzqY-7oOy0SFjSJ-mjM7JhsySzX58,9037
+ sinatools/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sinatools/utils/charsets.py,sha256=rs82oZJqRqosZdTKXfFAJfJ5t4PxjMM_oAPsiWSWuwU,2817
+ sinatools/utils/parser.py,sha256=qvHdln5R5CAv_0UOJWe0mcp8JCsGqgazoeIIkoALH88,6259
+ sinatools/utils/readfile.py,sha256=xE4LEaCqXJIk9v37QUSSmWb-aY3UnCFUNb7uVdx3cpM,133
+ sinatools/utils/similarity.py,sha256=HAK6OmyVnfjPm0GWL3z9s4ZoUwpZHVKxt3CeSMfqLIQ,11990
+ sinatools/utils/text_dublication_detector.py,sha256=FeSkbfWGMQluz23H4CBHXION-walZPgjueX6AL8u_Q0,5660
+ sinatools/utils/text_transliteration.py,sha256=F3smhr2AEJtySE6wGQsiXXOslTvSDzLivTYu0btgc10,8769
+ sinatools/utils/tokenizer.py,sha256=nyk6lh5-p38wrU62hvh4wg7ni9ammkdqqIgcjbbBxxo,6965
+ sinatools/utils/tokenizers_words.py,sha256=efNfOil9qDNVJ9yynk_8sqf65PsL-xtsHG7y2SZCkjQ,656
+ sinatools/utils/word_compare.py,sha256=rS2Z74sf7R-7MTXyrFj5miRi2TnSG9OdTDp_qQYuo2Y,28200
+ sinatools/wsd/__init__.py,sha256=mwmCUurOV42rsNRpIUP3luG0oEzeTfEx3oeDl93Oif8,306
+ sinatools/wsd/disambiguator.py,sha256=h-3idc5rPPbMDSE_QVJAsEVkDHwzYY3L2SEPNXIdOcc,20104
+ sinatools/wsd/settings.py,sha256=6XflVTFKD8SVySX9Wj7zYQtV26WDTcQ2-uW8-gDNHKE,747
+ sinatools/wsd/wsd.py,sha256=gHIBUFXegoY1z3rRnIlK6TduhYq2BTa_dHakOjOlT4k,4434
+ SinaTools-1.0.1.dist-info/AUTHORS.rst,sha256=aTWeWlIdfLi56iLJfIUAwIrmqDcgxXKLji75_Fjzjyg,174
+ SinaTools-1.0.1.dist-info/LICENSE,sha256=uwsKYG4TayHXNANWdpfMN2lVW4dimxQjA_7vuCVhD70,1088
+ SinaTools-1.0.1.dist-info/METADATA,sha256=8EnFO3dSqtJ8JJ4r_-ji5tX_h04_vNTnPvfubqceaQ4,3409
+ SinaTools-1.0.1.dist-info/WHEEL,sha256=9Hm2OB-j1QcCUq9Jguht7ayGIIZBRTdOXD1qg9cCgPM,109
+ SinaTools-1.0.1.dist-info/entry_points.txt,sha256=_CsRKM_tSCWV5hefBNUsWf9_6DrJnzFlxeAo1wm5XqY,1302
+ SinaTools-1.0.1.dist-info/top_level.txt,sha256=8tNdPTeJKw3TQCaua8IJIx6N6WpgZZmVekf1OdBNJpE,10
+ SinaTools-1.0.1.dist-info/RECORD,,
sinatools/VERSION CHANGED
@@ -1 +1 @@
- 0.1.40
+ 1.0.1
sinatools/ner/__init__.py CHANGED
@@ -11,7 +11,7 @@ from argparse import Namespace
  tagger = None
  tag_vocab = None
  train_config = None
- print("ner started")
+
  filename = 'Wj27012000.tar'
  path =downloader.get_appdatadir()
  model_path = os.path.join(path, filename)
@@ -20,21 +20,19 @@ _path = os.path.join(model_path, "tag_vocab.pkl")
 
  with open(_path, "rb") as fh:
      tag_vocab = pickle.load(fh)
- print("tag_vocab loaded")
 
  train_config = Namespace()
  args_path = os.path.join(model_path, "args.json")
- print("args loaded")
+
  with open(args_path, "r") as fh:
      train_config.__dict__ = json.load(fh)
- print("steps 1")
+
  model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
  model = torch.nn.DataParallel(model)
- print("steps 2")
+
  if torch.cuda.is_available():
      model = model.cuda()
- print("steps 3")
+
  train_config.trainer_config["kwargs"]["model"] = model
  tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
  tagger.load(os.path.join(model_path,"checkpoints"))
- print("steps 4")
sinatools/ner/trainers/BertNestedTrainer.py CHANGED
@@ -1,203 +1,203 @@
- import os
- import logging
- import torch
- import numpy as np
- from sinatools.ner.trainers import BaseTrainer
- from sinatools.ner.metrics import compute_nested_metrics
-
- logger = logging.getLogger(__name__)
-
-
- class BertNestedTrainer(BaseTrainer):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-
-     def train(self):
-         best_val_loss, test_loss = np.inf, np.inf
-         num_train_batch = len(self.train_dataloader)
-         num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
-         patience = self.patience
-
-         for epoch_index in range(self.max_epochs):
-             self.current_epoch = epoch_index
-             train_loss = 0
-
-             for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
-                 self.train_dataloader, is_train=True
-             ), 1):
-                 self.current_timestep += 1
-
-                 # Compute loses for each output
-                 # logits = B x T x L x C
-                 losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
-                                     torch.reshape(gold_tags[:, i, :], (-1,)).long())
-                           for i, l in enumerate(num_labels)]
-
-                 torch.autograd.backward(losses)
-
-                 # Avoid exploding gradient by doing gradient clipping
-                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
-
-                 self.optimizer.step()
-                 self.scheduler.step()
-                 batch_loss = sum(l.item() for l in losses)
-                 train_loss += batch_loss
-
-                 if self.current_timestep % self.log_interval == 0:
-                     logger.info(
-                         "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
-                         epoch_index,
-                         batch_index,
-                         num_train_batch,
-                         self.current_timestep,
-                         self.optimizer.param_groups[0]['lr'],
-                         batch_loss
-                     )
-
-             train_loss /= num_train_batch
-
-             logger.info("** Evaluating on validation dataset **")
-             val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
-             val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
-
-             epoch_summary_loss = {
-                 "train_loss": train_loss,
-                 "val_loss": val_loss
-             }
-             epoch_summary_metrics = {
-                 "val_micro_f1": val_metrics.micro_f1,
-                 "val_precision": val_metrics.precision,
-                 "val_recall": val_metrics.recall
-             }
-
-             logger.info(
-                 "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
-                 epoch_index,
-                 self.current_timestep,
-                 train_loss,
-                 val_loss,
-                 val_metrics.micro_f1
-             )
-
-             if val_loss < best_val_loss:
-                 patience = self.patience
-                 best_val_loss = val_loss
-                 logger.info("** Validation improved, evaluating test data **")
-                 test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
-                 self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
-                 test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
-
-                 epoch_summary_loss["test_loss"] = test_loss
-                 epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
-                 epoch_summary_metrics["test_precision"] = test_metrics.precision
-                 epoch_summary_metrics["test_recall"] = test_metrics.recall
-
-                 logger.info(
-                     f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
-                     epoch_index,
-                     self.current_timestep,
-                     test_loss,
-                     test_metrics.micro_f1
-                 )
-
-                 self.save()
-             else:
-                 patience -= 1
-
-                 # No improvements, terminating early
-                 if patience == 0:
-                     logger.info("Early termination triggered")
-                     break
-
-             self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
-             self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
-
-     def tag(self, dataloader, is_train=True):
-         """
-         Given a dataloader containing segments, predict the tags
-         :param dataloader: torch.utils.data.DataLoader
-         :param is_train: boolean - True for training model, False for evaluation
-         :return: Iterator
-                  subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
-                  gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
-                  tokens - List[arabiner.data.dataset.Token] - list of tokens
-                  valid_len (B x 1) - int - valiud length of each sequence
-                  logits (B x T x NUM_LABELS) - logits for each token and each tag
-         """
-         for subwords, gold_tags, tokens, mask, valid_len in dataloader:
-             self.model.train(is_train)
-
-             if torch.cuda.is_available():
-                 subwords = subwords.cuda()
-                 gold_tags = gold_tags.cuda()
-
-             if is_train:
-                 self.optimizer.zero_grad()
-                 logits = self.model(subwords)
-             else:
-                 with torch.no_grad():
-                     logits = self.model(subwords)
-
-             yield subwords, gold_tags, tokens, valid_len, logits
-
-     def eval(self, dataloader):
-         golds, preds, segments, valid_lens = list(), list(), list(), list()
-         num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
-         loss = 0
-
-         for _, gold_tags, tokens, valid_len, logits in self.tag(
-             dataloader, is_train=False
-         ):
-             losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
-                                 torch.reshape(gold_tags[:, i, :], (-1,)).long())
-                       for i, l in enumerate(num_labels)]
-             loss += sum(losses)
-             preds += torch.argmax(logits, dim=3)
-             segments += tokens
-             valid_lens += list(valid_len)
-
-         loss /= len(dataloader)
-
-         # Update segments, attach predicted tags to each token
-         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
-
-         return preds, segments, valid_lens, loss
-
-     def infer(self, dataloader):
-         golds, preds, segments, valid_lens = list(), list(), list(), list()
-
-         for _, gold_tags, tokens, valid_len, logits in self.tag(
-             dataloader, is_train=False
-         ):
-             preds += torch.argmax(logits, dim=3)
-             segments += tokens
-             valid_lens += list(valid_len)
-
-         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
-         return segments
-
-     def to_segments(self, segments, preds, valid_lens, vocab):
-         if vocab is None:
-             vocab = self.vocab
-
-         tagged_segments = list()
-         tokens_stoi = vocab.tokens.get_stoi()
-         unk_id = tokens_stoi["UNK"]
-
-         for segment, pred, valid_len in zip(segments, preds, valid_lens):
-             # First, the token at 0th index [CLS] and token at nth index [SEP]
-             # Combine the tokens with their corresponding predictions
-             segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
-
-             # Ignore the sub-tokens/subwords, which are identified with text being UNK
-             segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
-
-             # Attach the predicted tags to each token
-             list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
-                                                           for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
-
-             # We are only interested in the tagged tokens, we do no longer need raw model predictions
-             tagged_segment = [t for t, _ in segment_pred]
-             tagged_segments.append(tagged_segment)
-
-         return tagged_segments
+ import os
+ import logging
+ import torch
+ import numpy as np
+ from sinatools.ner.trainers import BaseTrainer
+ from sinatools.ner.metrics import compute_nested_metrics
+
+ logger = logging.getLogger(__name__)
+
+
+ class BertNestedTrainer(BaseTrainer):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def train(self):
+         best_val_loss, test_loss = np.inf, np.inf
+         num_train_batch = len(self.train_dataloader)
+         num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
+         patience = self.patience
+
+         for epoch_index in range(self.max_epochs):
+             self.current_epoch = epoch_index
+             train_loss = 0
+
+             for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
+                 self.train_dataloader, is_train=True
+             ), 1):
+                 self.current_timestep += 1
+
+                 # Compute loses for each output
+                 # logits = B x T x L x C
+                 losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                     torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                           for i, l in enumerate(num_labels)]
+
+                 torch.autograd.backward(losses)
+
+                 # Avoid exploding gradient by doing gradient clipping
+                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+                 self.optimizer.step()
+                 self.scheduler.step()
+                 batch_loss = sum(l.item() for l in losses)
+                 train_loss += batch_loss
+
+                 if self.current_timestep % self.log_interval == 0:
+                     logger.info(
+                         "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
+                         epoch_index,
+                         batch_index,
+                         num_train_batch,
+                         self.current_timestep,
+                         self.optimizer.param_groups[0]['lr'],
+                         batch_loss
+                     )
+
+             train_loss /= num_train_batch
+
+             logger.info("** Evaluating on validation dataset **")
+             val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
+             val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
+
+             epoch_summary_loss = {
+                 "train_loss": train_loss,
+                 "val_loss": val_loss
+             }
+             epoch_summary_metrics = {
+                 "val_micro_f1": val_metrics.micro_f1,
+                 "val_precision": val_metrics.precision,
+                 "val_recall": val_metrics.recall
+             }
+
+             logger.info(
+                 "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
+                 epoch_index,
+                 self.current_timestep,
+                 train_loss,
+                 val_loss,
+                 val_metrics.micro_f1
+             )
+
+             if val_loss < best_val_loss:
+                 patience = self.patience
+                 best_val_loss = val_loss
+                 logger.info("** Validation improved, evaluating test data **")
+                 test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
+                 self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
+                 test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
+
+                 epoch_summary_loss["test_loss"] = test_loss
+                 epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
+                 epoch_summary_metrics["test_precision"] = test_metrics.precision
+                 epoch_summary_metrics["test_recall"] = test_metrics.recall
+
+                 logger.info(
+                     f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
+                     epoch_index,
+                     self.current_timestep,
+                     test_loss,
+                     test_metrics.micro_f1
+                 )
+
+                 self.save()
+             else:
+                 patience -= 1
+
+                 # No improvements, terminating early
+                 if patience == 0:
+                     logger.info("Early termination triggered")
+                     break
+
+             self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
+             self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
+
+     def tag(self, dataloader, is_train=True):
+         """
+         Given a dataloader containing segments, predict the tags
+         :param dataloader: torch.utils.data.DataLoader
+         :param is_train: boolean - True for training model, False for evaluation
+         :return: Iterator
+                  subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
+                  gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
+                  tokens - List[arabiner.data.dataset.Token] - list of tokens
+                  valid_len (B x 1) - int - valiud length of each sequence
+                  logits (B x T x NUM_LABELS) - logits for each token and each tag
+         """
+         for subwords, gold_tags, tokens, mask, valid_len in dataloader:
+             self.model.train(is_train)
+
+             if torch.cuda.is_available():
+                 subwords = subwords.cuda()
+                 gold_tags = gold_tags.cuda()
+
+             if is_train:
+                 self.optimizer.zero_grad()
+                 logits = self.model(subwords)
+             else:
+                 with torch.no_grad():
+                     logits = self.model(subwords)
+
+             yield subwords, gold_tags, tokens, valid_len, logits
+
+     def eval(self, dataloader):
+         golds, preds, segments, valid_lens = list(), list(), list(), list()
+         num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
+         loss = 0
+
+         for _, gold_tags, tokens, valid_len, logits in self.tag(
+             dataloader, is_train=False
+         ):
+             losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                 torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                       for i, l in enumerate(num_labels)]
+             loss += sum(losses)
+             preds += torch.argmax(logits, dim=3)
+             segments += tokens
+             valid_lens += list(valid_len)
+
+         loss /= len(dataloader)
+
+         # Update segments, attach predicted tags to each token
+         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+
+         return preds, segments, valid_lens, loss
+
+     def infer(self, dataloader):
+         golds, preds, segments, valid_lens = list(), list(), list(), list()
+
+         for _, gold_tags, tokens, valid_len, logits in self.tag(
+             dataloader, is_train=False
+         ):
+             preds += torch.argmax(logits, dim=3)
+             segments += tokens
+             valid_lens += list(valid_len)
+
+         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+         return segments
+
+     def to_segments(self, segments, preds, valid_lens, vocab):
+         if vocab is None:
+             vocab = self.vocab
+
+         tagged_segments = list()
+         tokens_stoi = vocab.tokens.get_stoi()
+         unk_id = tokens_stoi["UNK"]
+
+         for segment, pred, valid_len in zip(segments, preds, valid_lens):
+             # First, the token at 0th index [CLS] and token at nth index [SEP]
+             # Combine the tokens with their corresponding predictions
+             segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
+
+             # Ignore the sub-tokens/subwords, which are identified with text being UNK
+             segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
+
+             # Attach the predicted tags to each token
+             list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
+                                                           for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
+
+             # We are only interested in the tagged tokens, we do no longer need raw model predictions
+             tagged_segment = [t for t, _ in segment_pred]
+             tagged_segments.append(tagged_segment)
+
+         return tagged_segments