SinaTools 0.1.36__py2.py3-none-any.whl → 0.1.38__py2.py3-none-any.whl
- {SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/METADATA +62 -64
- {SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/RECORD +13 -13
- {SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/WHEEL +6 -6
- {SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/entry_points.txt +0 -1
- sinatools/VERSION +1 -1
- sinatools/ner/trainers/BertNestedTrainer.py +203 -203
- sinatools/ner/trainers/BertTrainer.py +163 -163
- sinatools/ner/trainers/__init__.py +2 -2
- sinatools/utils/similarity.py +62 -27
- {SinaTools-0.1.36.data → SinaTools-0.1.38.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/top_level.txt +0 -0
{SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/METADATA
CHANGED
@@ -1,64 +1,62 @@
-Metadata-Version: 2.1
-Name: SinaTools
-Version: 0.1.
-Summary: Open-source Python toolkit for Arabic Natural Understanding, allowing people to integrate it in their system workflow.
-Home-page: https://github.com/SinaLab/sinatools
-License: MIT license
-Keywords: sinatools
-
-
-
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist: pathlib
-Requires-Dist: torch
-Requires-Dist: transformers
-Requires-Dist: torchtext
-Requires-Dist: torchvision
-Requires-Dist: seqeval
-Requires-Dist: natsort
-
-SinaTools
-======================
-Open Source Toolkit for Arabic NLP and NLU developed by [SinaLab](http://sina.birzeit.edu/) at Birzeit University. SinaTools is available through Python APIs, command lines, colabs, and online demos.
-
-See the full list of [Available Packages](https://sina.birzeit.edu/sinatools/), which include: (1) [Morphology Tagging](https://sina.birzeit.edu/sinatools/index.html#morph), (2) [Named Entity Recognition (NER)](https://sina.birzeit.edu/sinatools/index.html#ner), (3) [Word Sense Disambiguation (WSD)](https://sina.birzeit.edu/sinatools/index.html#wsd), (4) [Semantic Relatedness](https://sina.birzeit.edu/sinatools/index.html#sr), (5) [Synonymy Extraction and Evaluation](https://sina.birzeit.edu/sinatools/index.html#se), (6) [Relation Extraction](https://sina.birzeit.edu/sinatools/index.html#re), (7) [Utilities](https://sina.birzeit.edu/sinatools/index.html#u) (diacritic-based word matching, Jaccard similarly, parser, tokenizers, corpora processing, transliteration, etc).
-
-See [Demo Pages](https://sina.birzeit.edu/sinatools/).
-
-See the [benchmarking](https://www.jarrar.info/publications/HJK24.pdf), which shows that SinaTools outperformed all related toolkits.
-
-Installation
---------- 
-To install SinaTools, ensure you are using Python version 3.10.8, then clone the [GitHub](git://github.com/SinaLab/SinaTools) repository.
-
-Alternatively, you can execute the following command:
-
-```bash
-pip install sinatools
-```
-
-Installing Models and Data Files
----------
-Some modules in SinaTools require some data files and fine-tuned models to be downloaded. To download these models, please consult the [DataDownload](https://sina.birzeit.edu/sinatools/documentation/cli_tools/DataDownload/DataDownload.html).
-
-Documentation
----------
-For information, please refer to the [main page](https://sina.birzeit.edu/sinatools) or the [online domuementation](https://sina.birzeit.edu/sinatools/documentation).
-
-Citation
---------
-Tymaa Hammouda, Mustafa Jarrar, Mohammed Khalilia: [SinaTools: Open Source Toolkit for Arabic Natural Language Understanding](http://www.jarrar.info/publications/HJK24.pdf). In Proceedings of the 2024 AI in Computational Linguistics (ACLing 2024), Procedia Computer Science, Dubai. ELSEVIER.
-
-License
----------
-SinaTools is available under the MIT License. See the [LICENSE](https://github.com/SinaLab/sinatools/blob/main/LICENSE) file for more information.
-
-Reporting Issues
----------
-To report any issues or bugs, please contact us at "sina.institute.bzu@gmail.com" or visit [SinaTools Issues](https://github.com/SinaLab/sinatools/issues).
-
-
-
+Metadata-Version: 2.1
+Name: SinaTools
+Version: 0.1.38
+Summary: Open-source Python toolkit for Arabic Natural Understanding, allowing people to integrate it in their system workflow.
+Home-page: https://github.com/SinaLab/sinatools
+License: MIT license
+Keywords: sinatools
+Description-Content-Type: text/markdown
+License-File: LICENSE
+License-File: AUTHORS.rst
+Requires-Dist: six
+Requires-Dist: farasapy
+Requires-Dist: tqdm
+Requires-Dist: requests
+Requires-Dist: pathlib
+Requires-Dist: torch ==1.13.0
+Requires-Dist: transformers ==4.24.0
+Requires-Dist: torchtext ==0.14.0
+Requires-Dist: torchvision ==0.14.0
+Requires-Dist: seqeval ==1.2.2
+Requires-Dist: natsort ==7.1.1
+
+SinaTools
+======================
+Open Source Toolkit for Arabic NLP and NLU developed by [SinaLab](http://sina.birzeit.edu/) at Birzeit University. SinaTools is available through Python APIs, command lines, colabs, and online demos.
+
+See the full list of [Available Packages](https://sina.birzeit.edu/sinatools/), which include: (1) [Morphology Tagging](https://sina.birzeit.edu/sinatools/index.html#morph), (2) [Named Entity Recognition (NER)](https://sina.birzeit.edu/sinatools/index.html#ner), (3) [Word Sense Disambiguation (WSD)](https://sina.birzeit.edu/sinatools/index.html#wsd), (4) [Semantic Relatedness](https://sina.birzeit.edu/sinatools/index.html#sr), (5) [Synonymy Extraction and Evaluation](https://sina.birzeit.edu/sinatools/index.html#se), (6) [Relation Extraction](https://sina.birzeit.edu/sinatools/index.html#re), (7) [Utilities](https://sina.birzeit.edu/sinatools/index.html#u) (diacritic-based word matching, Jaccard similarly, parser, tokenizers, corpora processing, transliteration, etc).
+
+See [Demo Pages](https://sina.birzeit.edu/sinatools/).
+
+See the [benchmarking](https://www.jarrar.info/publications/HJK24.pdf), which shows that SinaTools outperformed all related toolkits.
+
+Installation
+--------
+To install SinaTools, ensure you are using Python version 3.10.8, then clone the [GitHub](git://github.com/SinaLab/SinaTools) repository.
+
+Alternatively, you can execute the following command:
+
+```bash
+pip install sinatools
+```
+
+Installing Models and Data Files
+--------
+Some modules in SinaTools require some data files and fine-tuned models to be downloaded. To download these models, please consult the [DataDownload](https://sina.birzeit.edu/sinatools/documentation/cli_tools/DataDownload/DataDownload.html).
+
+Documentation
+--------
+For information, please refer to the [main page](https://sina.birzeit.edu/sinatools) or the [online domuementation](https://sina.birzeit.edu/sinatools/documentation).
+
+Citation
+-------
+Tymaa Hammouda, Mustafa Jarrar, Mohammed Khalilia: [SinaTools: Open Source Toolkit for Arabic Natural Language Understanding](http://www.jarrar.info/publications/HJK24.pdf). In Proceedings of the 2024 AI in Computational Linguistics (ACLing 2024), Procedia Computer Science, Dubai. ELSEVIER.
+
+License
+--------
+SinaTools is available under the MIT License. See the [LICENSE](https://github.com/SinaLab/sinatools/blob/main/LICENSE) file for more information.
+
+Reporting Issues
+--------
+To report any issues or bugs, please contact us at "sina.institute.bzu@gmail.com" or visit [SinaTools Issues](https://github.com/SinaLab/sinatools/issues).
+
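Note on the pinned requirements: the 0.1.38 METADATA lists exact `==` versions for torch, transformers, torchtext, torchvision, seqeval and natsort. The following is a minimal sketch, not part of SinaTools, of how an environment could be compared against those pins using the standard-library `importlib.metadata`. The names `PINS`, `_installed_version` and `find_mismatches` are hypothetical helpers introduced only for illustration, and the comparison is a strict string match, so a build that reports a local suffix after the version number would be flagged as a mismatch.

```python
from importlib import metadata

# Pins copied from the 0.1.38 Requires-Dist lines in the METADATA diff.
PINS = {
    "torch": "1.13.0",
    "transformers": "4.24.0",
    "torchtext": "0.14.0",
    "torchvision": "0.14.0",
    "seqeval": "1.2.2",
    "natsort": "7.1.1",
}


def _installed_version(name):
    """Return the installed version of `name`, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, get_version=_installed_version):
    """Map each package whose installed version differs from its pin to (pinned, found)."""
    mismatches = {}
    for name, pinned in pins.items():
        found = get_version(name)
        if found != pinned:
            mismatches[name] = (pinned, found)
    return mismatches


if __name__ == "__main__":
    for name, (pinned, found) in find_mismatches(PINS).items():
        print(f"{name}: pinned {pinned}, found {found or 'not installed'}")
```

The `get_version` parameter is injectable so the comparison can be exercised without the packages being installed.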
{SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
-SinaTools-0.1.
-sinatools/VERSION,sha256=
+SinaTools-0.1.38.data/data/sinatools/environment.yml,sha256=OzilhLjZbo_3nU93EQNUFX-6G5O3newiSWrwxvMH2Os,7231
+sinatools/VERSION,sha256=IG8zXDtajZ6W0rgxySeHulP0aoaEpnkET2yOuT5wRks,6
 sinatools/__init__.py,sha256=bEosTU1o-FSpyytS6iVP_82BXHF2yHnzpJxPLYRbeII,135
 sinatools/environment.yml,sha256=OzilhLjZbo_3nU93EQNUFX-6G5O3newiSWrwxvMH2Os,7231
 sinatools/install_env.py,sha256=EODeeE0ZzfM_rz33_JSIruX03Nc4ghyVOM5BHVhsZaQ,404
@@ -91,9 +91,9 @@ sinatools/ner/nn/BertNestedTagger.py,sha256=_fwAn1kiKmXe6m5y16Ipty3kvXIEFEmiUq74
 sinatools/ner/nn/BertSeqTagger.py,sha256=dFcBBiMw2QCWsyy7aQDe_PS3aRuNn4DOxKIHgTblFvc,504
 sinatools/ner/nn/__init__.py,sha256=UgQD_XLNzQGBNSYc_Bw1aRJZjq4PJsnMT1iZwnJemqE,170
 sinatools/ner/trainers/BaseTrainer.py,sha256=Ifz4SeTxJwVn1_uWZ3I9KbcSo2hLPN3ojsIYuoKE9wE,4050
-sinatools/ner/trainers/BertNestedTrainer.py,sha256=
-sinatools/ner/trainers/BertTrainer.py,sha256=
-sinatools/ner/trainers/__init__.py,sha256=
+sinatools/ner/trainers/BertNestedTrainer.py,sha256=iJOah69tXZsAXBimqP0odEsk8SPX4A355riePzW2BFs,8632
+sinatools/ner/trainers/BertTrainer.py,sha256=BtttsrHPolmK3eRDqrgVUuv6lVMuImIeskxhi02Q-44,6596
+sinatools/ner/trainers/__init__.py,sha256=Xnbi_M4KKJRqV7FJe1vklyT0nEW2Q2obxgcWkbR0ZbA,190
 sinatools/relations/__init__.py,sha256=cYjsP2mlTYvAwVIEFtgA6i9gLUSkGVOuDggMs7TvG5k,272
 sinatools/relations/relation_extractor.py,sha256=UuDlaaR0ch9BFv4sBF1tr7P-P9xq8oRZF41tAze6_ok,9751
 sinatools/semantic_relatedness/__init__.py,sha256=S0xrmqtl72L02N56nbNMudPoebnYQgsaIyyX-587DsU,830
@@ -104,7 +104,7 @@ sinatools/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sinatools/utils/charsets.py,sha256=rs82oZJqRqosZdTKXfFAJfJ5t4PxjMM_oAPsiWSWuwU,2817
 sinatools/utils/parser.py,sha256=qvHdln5R5CAv_0UOJWe0mcp8JCsGqgazoeIIkoALH88,6259
 sinatools/utils/readfile.py,sha256=xE4LEaCqXJIk9v37QUSSmWb-aY3UnCFUNb7uVdx3cpM,133
-sinatools/utils/similarity.py,sha256=
+sinatools/utils/similarity.py,sha256=HAK6OmyVnfjPm0GWL3z9s4ZoUwpZHVKxt3CeSMfqLIQ,11990
 sinatools/utils/text_dublication_detector.py,sha256=FeSkbfWGMQluz23H4CBHXION-walZPgjueX6AL8u_Q0,5660
 sinatools/utils/text_transliteration.py,sha256=F3smhr2AEJtySE6wGQsiXXOslTvSDzLivTYu0btgc10,8769
 sinatools/utils/tokenizer.py,sha256=nyk6lh5-p38wrU62hvh4wg7ni9ammkdqqIgcjbbBxxo,6965
@@ -114,10 +114,10 @@ sinatools/wsd/__init__.py,sha256=mwmCUurOV42rsNRpIUP3luG0oEzeTfEx3oeDl93Oif8,306
 sinatools/wsd/disambiguator.py,sha256=h-3idc5rPPbMDSE_QVJAsEVkDHwzYY3L2SEPNXIdOcc,20104
 sinatools/wsd/settings.py,sha256=6XflVTFKD8SVySX9Wj7zYQtV26WDTcQ2-uW8-gDNHKE,747
 sinatools/wsd/wsd.py,sha256=gHIBUFXegoY1z3rRnIlK6TduhYq2BTa_dHakOjOlT4k,4434
-SinaTools-0.1.
-SinaTools-0.1.
-SinaTools-0.1.
-SinaTools-0.1.
-SinaTools-0.1.
-SinaTools-0.1.
-SinaTools-0.1.
+SinaTools-0.1.38.dist-info/AUTHORS.rst,sha256=aTWeWlIdfLi56iLJfIUAwIrmqDcgxXKLji75_Fjzjyg,174
+SinaTools-0.1.38.dist-info/LICENSE,sha256=uwsKYG4TayHXNANWdpfMN2lVW4dimxQjA_7vuCVhD70,1088
+SinaTools-0.1.38.dist-info/METADATA,sha256=sMasvTcuV4-3WpBTyGKHkm9nTFfXuZkf4uXTHDh5_I8,3324
+SinaTools-0.1.38.dist-info/WHEEL,sha256=DZajD4pwLWue70CAfc7YaxT1wLUciNBvN_TTcvXpltE,110
+SinaTools-0.1.38.dist-info/entry_points.txt,sha256=_CsRKM_tSCWV5hefBNUsWf9_6DrJnzFlxeAo1wm5XqY,1302
+SinaTools-0.1.38.dist-info/top_level.txt,sha256=8tNdPTeJKw3TQCaua8IJIx6N6WpgZZmVekf1OdBNJpE,10
+SinaTools-0.1.38.dist-info/RECORD,,
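Note on reading the RECORD entries: each line has the form `path,sha256=<digest>,<size>`, where the digest is the SHA-256 of the file encoded as URL-safe base64 with the trailing `=` padding removed. The sketch below shows one way such an entry could be checked against file contents. The helper names (`record_digest`, `parse_record_line`, `matches`) are hypothetical and not taken from SinaTools or from any packaging tool; the parser assumes an unquoted path and does not handle RECORD's own entry, whose hash and size fields are empty.

```python
import base64
import hashlib


def record_digest(data: bytes) -> str:
    """Return the sha256 digest of `data` in the form used by RECORD entries."""
    raw = hashlib.sha256(data).digest()
    # RECORD uses URL-safe base64 with the trailing "=" padding removed.
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def parse_record_line(line: str):
    """Split 'path,sha256=<digest>,<size>' into (path, digest, size)."""
    path, hash_field, size = line.rsplit(",", 2)
    algorithm, _, digest = hash_field.partition("=")
    if algorithm != "sha256":
        raise ValueError(f"unsupported hash algorithm: {algorithm!r}")
    return path, digest, int(size)


def matches(line: str, data: bytes) -> bool:
    """True when both the size and the digest of `data` agree with the entry."""
    _path, digest, size = parse_record_line(line)
    return len(data) == size and record_digest(data) == digest
```

The `sinatools/utils/__init__.py` entry in the RECORD hunks is a convenient check: its recorded size is 0, so its digest is that of empty input, and `matches(entry, b"")` holds for it.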
{SinaTools-0.1.36.dist-info → SinaTools-0.1.38.dist-info}/WHEEL
CHANGED
@@ -1,6 +1,6 @@
-Wheel-Version: 1.0
-Generator: bdist_wheel (0.
-Root-Is-Purelib: true
-Tag: py2-none-any
-Tag: py3-none-any
-
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.43.0)
+Root-Is-Purelib: true
+Tag: py2-none-any
+Tag: py3-none-any
+
sinatools/VERSION
CHANGED
@@ -1 +1 @@
-0.1.
+0.1.38
sinatools/ner/trainers/BertNestedTrainer.py
CHANGED
@@ -1,203 +1,203 @@
|
|
1
|
-
import os
|
2
|
-
import logging
|
3
|
-
import torch
|
4
|
-
import numpy as np
|
5
|
-
from sinatools.ner.trainers import BaseTrainer
|
6
|
-
from sinatools.ner.metrics import compute_nested_metrics
|
7
|
-
|
8
|
-
logger = logging.getLogger(__name__)
|
9
|
-
|
10
|
-
|
11
|
-
class BertNestedTrainer(BaseTrainer):
|
12
|
-
def __init__(self, **kwargs):
|
13
|
-
super().__init__(**kwargs)
|
14
|
-
|
15
|
-
def train(self):
|
16
|
-
best_val_loss, test_loss = np.inf, np.inf
|
17
|
-
num_train_batch = len(self.train_dataloader)
|
18
|
-
num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
|
19
|
-
patience = self.patience
|
20
|
-
|
21
|
-
for epoch_index in range(self.max_epochs):
|
22
|
-
self.current_epoch = epoch_index
|
23
|
-
train_loss = 0
|
24
|
-
|
25
|
-
for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
|
26
|
-
self.train_dataloader, is_train=True
|
27
|
-
), 1):
|
28
|
-
self.current_timestep += 1
|
29
|
-
|
30
|
-
# Compute loses for each output
|
31
|
-
# logits = B x T x L x C
|
32
|
-
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
33
|
-
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
34
|
-
for i, l in enumerate(num_labels)]
|
35
|
-
|
36
|
-
torch.autograd.backward(losses)
|
37
|
-
|
38
|
-
# Avoid exploding gradient by doing gradient clipping
|
39
|
-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
|
40
|
-
|
41
|
-
self.optimizer.step()
|
42
|
-
self.scheduler.step()
|
43
|
-
batch_loss = sum(l.item() for l in losses)
|
44
|
-
train_loss += batch_loss
|
45
|
-
|
46
|
-
if self.current_timestep % self.log_interval == 0:
|
47
|
-
logger.info(
|
48
|
-
"Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
|
49
|
-
epoch_index,
|
50
|
-
batch_index,
|
51
|
-
num_train_batch,
|
52
|
-
self.current_timestep,
|
53
|
-
self.optimizer.param_groups[0]['lr'],
|
54
|
-
batch_loss
|
55
|
-
)
|
56
|
-
|
57
|
-
train_loss /= num_train_batch
|
58
|
-
|
59
|
-
logger.info("** Evaluating on validation dataset **")
|
60
|
-
val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
|
61
|
-
val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
|
62
|
-
|
63
|
-
epoch_summary_loss = {
|
64
|
-
"train_loss": train_loss,
|
65
|
-
"val_loss": val_loss
|
66
|
-
}
|
67
|
-
epoch_summary_metrics = {
|
68
|
-
"val_micro_f1": val_metrics.micro_f1,
|
69
|
-
"val_precision": val_metrics.precision,
|
70
|
-
"val_recall": val_metrics.recall
|
71
|
-
}
|
72
|
-
|
73
|
-
logger.info(
|
74
|
-
"Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
|
75
|
-
epoch_index,
|
76
|
-
self.current_timestep,
|
77
|
-
train_loss,
|
78
|
-
val_loss,
|
79
|
-
val_metrics.micro_f1
|
80
|
-
)
|
81
|
-
|
82
|
-
if val_loss < best_val_loss:
|
83
|
-
patience = self.patience
|
84
|
-
best_val_loss = val_loss
|
85
|
-
logger.info("** Validation improved, evaluating test data **")
|
86
|
-
test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
|
87
|
-
self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
|
88
|
-
test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
|
89
|
-
|
90
|
-
epoch_summary_loss["test_loss"] = test_loss
|
91
|
-
epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
|
92
|
-
epoch_summary_metrics["test_precision"] = test_metrics.precision
|
93
|
-
epoch_summary_metrics["test_recall"] = test_metrics.recall
|
94
|
-
|
95
|
-
logger.info(
|
96
|
-
f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
|
97
|
-
epoch_index,
|
98
|
-
self.current_timestep,
|
99
|
-
test_loss,
|
100
|
-
test_metrics.micro_f1
|
101
|
-
)
|
102
|
-
|
103
|
-
self.save()
|
104
|
-
else:
|
105
|
-
patience -= 1
|
106
|
-
|
107
|
-
# No improvements, terminating early
|
108
|
-
if patience == 0:
|
109
|
-
logger.info("Early termination triggered")
|
110
|
-
break
|
111
|
-
|
112
|
-
self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
|
113
|
-
self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
|
114
|
-
|
115
|
-
def tag(self, dataloader, is_train=True):
|
116
|
-
"""
|
117
|
-
Given a dataloader containing segments, predict the tags
|
118
|
-
:param dataloader: torch.utils.data.DataLoader
|
119
|
-
:param is_train: boolean - True for training model, False for evaluation
|
120
|
-
:return: Iterator
|
121
|
-
subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
|
122
|
-
gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
|
123
|
-
tokens - List[arabiner.data.dataset.Token] - list of tokens
|
124
|
-
valid_len (B x 1) - int - valiud length of each sequence
|
125
|
-
logits (B x T x NUM_LABELS) - logits for each token and each tag
|
126
|
-
"""
|
127
|
-
for subwords, gold_tags, tokens, mask, valid_len in dataloader:
|
128
|
-
self.model.train(is_train)
|
129
|
-
|
130
|
-
if torch.cuda.is_available():
|
131
|
-
subwords = subwords.cuda()
|
132
|
-
gold_tags = gold_tags.cuda()
|
133
|
-
|
134
|
-
if is_train:
|
135
|
-
self.optimizer.zero_grad()
|
136
|
-
logits = self.model(subwords)
|
137
|
-
else:
|
138
|
-
with torch.no_grad():
|
139
|
-
logits = self.model(subwords)
|
140
|
-
|
141
|
-
yield subwords, gold_tags, tokens, valid_len, logits
|
142
|
-
|
143
|
-
def eval(self, dataloader):
|
144
|
-
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
145
|
-
num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
|
146
|
-
loss = 0
|
147
|
-
|
148
|
-
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
149
|
-
dataloader, is_train=False
|
150
|
-
):
|
151
|
-
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
152
|
-
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
153
|
-
for i, l in enumerate(num_labels)]
|
154
|
-
loss += sum(losses)
|
155
|
-
preds += torch.argmax(logits, dim=3)
|
156
|
-
segments += tokens
|
157
|
-
valid_lens += list(valid_len)
|
158
|
-
|
159
|
-
loss /= len(dataloader)
|
160
|
-
|
161
|
-
# Update segments, attach predicted tags to each token
|
162
|
-
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
163
|
-
|
164
|
-
return preds, segments, valid_lens, loss
|
165
|
-
|
166
|
-
def infer(self, dataloader):
|
167
|
-
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
168
|
-
|
169
|
-
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
170
|
-
dataloader, is_train=False
|
171
|
-
):
|
172
|
-
preds += torch.argmax(logits, dim=3)
|
173
|
-
segments += tokens
|
174
|
-
valid_lens += list(valid_len)
|
175
|
-
|
176
|
-
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
177
|
-
return segments
|
178
|
-
|
179
|
-
def to_segments(self, segments, preds, valid_lens, vocab):
|
180
|
-
if vocab is None:
|
181
|
-
vocab = self.vocab
|
182
|
-
|
183
|
-
tagged_segments = list()
|
184
|
-
tokens_stoi = vocab.tokens.get_stoi()
|
185
|
-
unk_id = tokens_stoi["UNK"]
|
186
|
-
|
187
|
-
for segment, pred, valid_len in zip(segments, preds, valid_lens):
|
188
|
-
# First, the token at 0th index [CLS] and token at nth index [SEP]
|
189
|
-
# Combine the tokens with their corresponding predictions
|
190
|
-
segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
|
191
|
-
|
192
|
-
# Ignore the sub-tokens/subwords, which are identified with text being UNK
|
193
|
-
segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
|
194
|
-
|
195
|
-
# Attach the predicted tags to each token
|
196
|
-
list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
|
197
|
-
for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
|
198
|
-
|
199
|
-
# We are only interested in the tagged tokens, we do no longer need raw model predictions
|
200
|
-
tagged_segment = [t for t, _ in segment_pred]
|
201
|
-
tagged_segments.append(tagged_segment)
|
202
|
-
|
203
|
-
return tagged_segments
|
1
|
+
import os
|
2
|
+
import logging
|
3
|
+
import torch
|
4
|
+
import numpy as np
|
5
|
+
from sinatools.ner.trainers import BaseTrainer
|
6
|
+
from sinatools.ner.metrics import compute_nested_metrics
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
class BertNestedTrainer(BaseTrainer):
|
12
|
+
def __init__(self, **kwargs):
|
13
|
+
super().__init__(**kwargs)
|
14
|
+
|
15
|
+
def train(self):
|
16
|
+
best_val_loss, test_loss = np.inf, np.inf
|
17
|
+
num_train_batch = len(self.train_dataloader)
|
18
|
+
num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
|
19
|
+
patience = self.patience
|
20
|
+
|
21
|
+
for epoch_index in range(self.max_epochs):
|
22
|
+
self.current_epoch = epoch_index
|
23
|
+
train_loss = 0
|
24
|
+
|
25
|
+
for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
|
26
|
+
self.train_dataloader, is_train=True
|
27
|
+
), 1):
|
28
|
+
self.current_timestep += 1
|
29
|
+
|
30
|
+
# Compute loses for each output
|
31
|
+
# logits = B x T x L x C
|
32
|
+
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
33
|
+
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
34
|
+
for i, l in enumerate(num_labels)]
|
35
|
+
|
36
|
+
torch.autograd.backward(losses)
|
37
|
+
|
38
|
+
# Avoid exploding gradient by doing gradient clipping
|
39
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
|
40
|
+
|
41
|
+
self.optimizer.step()
|
42
|
+
self.scheduler.step()
|
43
|
+
batch_loss = sum(l.item() for l in losses)
|
44
|
+
train_loss += batch_loss
|
45
|
+
|
46
|
+
if self.current_timestep % self.log_interval == 0:
|
47
|
+
logger.info(
|
48
|
+
"Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
|
49
|
+
epoch_index,
|
50
|
+
batch_index,
|
51
|
+
num_train_batch,
|
52
|
+
self.current_timestep,
|
53
|
+
self.optimizer.param_groups[0]['lr'],
|
54
|
+
batch_loss
|
55
|
+
)
|
56
|
+
|
57
|
+
train_loss /= num_train_batch
|
58
|
+
|
59
|
+
logger.info("** Evaluating on validation dataset **")
|
60
|
+
val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
|
61
|
+
val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
|
62
|
+
|
63
|
+
epoch_summary_loss = {
|
64
|
+
"train_loss": train_loss,
|
65
|
+
"val_loss": val_loss
|
66
|
+
}
|
67
|
+
epoch_summary_metrics = {
|
68
|
+
"val_micro_f1": val_metrics.micro_f1,
|
69
|
+
"val_precision": val_metrics.precision,
|
70
|
+
"val_recall": val_metrics.recall
|
71
|
+
}
|
72
|
+
|
73
|
+
logger.info(
|
74
|
+
"Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
|
75
|
+
epoch_index,
|
76
|
+
self.current_timestep,
|
77
|
+
train_loss,
|
78
|
+
val_loss,
|
79
|
+
val_metrics.micro_f1
|
80
|
+
)
|
81
|
+
|
82
|
+
if val_loss < best_val_loss:
|
83
|
+
patience = self.patience
|
84
|
+
best_val_loss = val_loss
|
85
|
+
logger.info("** Validation improved, evaluating test data **")
|
86
|
+
test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
|
87
|
+
self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
|
88
|
+
test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
|
89
|
+
|
90
|
+
epoch_summary_loss["test_loss"] = test_loss
|
91
|
+
epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
|
92
|
+
epoch_summary_metrics["test_precision"] = test_metrics.precision
|
93
|
+
epoch_summary_metrics["test_recall"] = test_metrics.recall
|
94
|
+
|
95
|
+
logger.info(
|
96
|
+
f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
|
97
|
+
epoch_index,
|
98
|
+
self.current_timestep,
|
99
|
+
test_loss,
|
100
|
+
test_metrics.micro_f1
|
101
|
+
)
|
102
|
+
|
103
|
+
self.save()
|
104
|
+
else:
|
105
|
+
patience -= 1
|
106
|
+
|
107
|
+
# No improvements, terminating early
|
108
|
+
if patience == 0:
|
109
|
+
logger.info("Early termination triggered")
|
110
|
+
break
|
111
|
+
|
112
|
+
self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
|
113
|
+
self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
|
114
|
+
|
115
|
+
def tag(self, dataloader, is_train=True):
|
116
|
+
"""
|
117
|
+
Given a dataloader containing segments, predict the tags
|
118
|
+
:param dataloader: torch.utils.data.DataLoader
|
119
|
+
:param is_train: boolean - True for training model, False for evaluation
|
120
|
+
:return: Iterator
|
121
|
+
subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
|
122
|
+
gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
|
123
|
+
tokens - List[arabiner.data.dataset.Token] - list of tokens
|
124
|
+
valid_len (B x 1) - int - valiud length of each sequence
|
125
|
+
logits (B x T x NUM_LABELS) - logits for each token and each tag
|
126
|
+
"""
|
127
|
+
for subwords, gold_tags, tokens, mask, valid_len in dataloader:
|
128
|
+
self.model.train(is_train)
|
129
|
+
|
130
|
+
if torch.cuda.is_available():
|
131
|
+
subwords = subwords.cuda()
|
132
|
+
gold_tags = gold_tags.cuda()
|
133
|
+
|
134
|
+
if is_train:
|
135
|
+
self.optimizer.zero_grad()
|
136
|
+
logits = self.model(subwords)
|
137
|
+
else:
|
138
|
+
with torch.no_grad():
|
139
|
+
logits = self.model(subwords)
|
140
|
+
|
141
|
+
yield subwords, gold_tags, tokens, valid_len, logits
|
142
|
+
|
143
|
+
def eval(self, dataloader):
|
144
|
+
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
145
|
+
num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
|
146
|
+
loss = 0
|
147
|
+
|
148
|
+
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
149
|
+
dataloader, is_train=False
|
150
|
+
):
|
151
|
+
losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
|
152
|
+
torch.reshape(gold_tags[:, i, :], (-1,)).long())
|
153
|
+
for i, l in enumerate(num_labels)]
|
154
|
+
loss += sum(losses)
|
155
|
+
preds += torch.argmax(logits, dim=3)
|
156
|
+
segments += tokens
|
157
|
+
valid_lens += list(valid_len)
|
158
|
+
|
159
|
+
loss /= len(dataloader)
|
160
|
+
|
161
|
+
# Update segments, attach predicted tags to each token
|
162
|
+
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
163
|
+
|
164
|
+
return preds, segments, valid_lens, loss
|
165
|
+
|
166
|
+
def infer(self, dataloader):
|
167
|
+
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
168
|
+
|
169
|
+
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
170
|
+
dataloader, is_train=False
|
171
|
+
):
|
172
|
+
preds += torch.argmax(logits, dim=3)
|
173
|
+
segments += tokens
|
174
|
+
valid_lens += list(valid_len)
|
175
|
+
|
176
|
+
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
177
|
+
return segments
|
178
|
+
|
179
|
+
def to_segments(self, segments, preds, valid_lens, vocab):
|
180
|
+
if vocab is None:
|
181
|
+
vocab = self.vocab
|
182
|
+
|
183
|
+
tagged_segments = list()
|
184
|
+
tokens_stoi = vocab.tokens.get_stoi()
|
185
|
+
unk_id = tokens_stoi["UNK"]
|
186
|
+
|
187
|
+
for segment, pred, valid_len in zip(segments, preds, valid_lens):
|
188
|
+
# First, the token at 0th index [CLS] and token at nth index [SEP]
|
189
|
+
# Combine the tokens with their corresponding predictions
|
190
|
+
segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
|
191
|
+
|
192
|
+
# Ignore the sub-tokens/subwords, which are identified with text being UNK
|
193
|
+
segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
|
194
|
+
|
195
|
+
# Attach the predicted tags to each token
|
196
|
+
list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
|
197
|
+
for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
|
198
|
+
|
199
|
+
# We are only interested in the tagged tokens, we do no longer need raw model predictions
|
200
|
+
tagged_segment = [t for t, _ in segment_pred]
|
201
|
+
tagged_segments.append(tagged_segment)
|
202
|
+
|
203
|
+
return tagged_segments
|
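Note on the training loop in this file: `train` uses a patience counter for early stopping. The counter is reset to `self.patience` whenever the validation loss improves, decremented otherwise, and the loop breaks when it reaches zero. The sketch below isolates that control flow in plain Python so it can be followed without the model or data loaders. It is an illustration under assumptions, not the trainer itself: `run_with_patience` and its arguments are hypothetical names, and the per-epoch validation losses are supplied as a list instead of being computed.

```python
def run_with_patience(val_losses, patience):
    """Replay per-epoch validation losses.

    Returns (best_loss, epochs_run, stopped_early).
    """
    best = float("inf")
    remaining = patience
    epochs_run = 0
    for loss in val_losses:
        epochs_run += 1
        if loss < best:
            # Improvement: remember the loss and reset the counter.
            best = loss
            remaining = patience
        else:
            # No improvement: use up one unit of patience.
            remaining -= 1
            if remaining == 0:
                return best, epochs_run, True
    return best, epochs_run, False
```

With a patience of 2, two consecutive epochs without improvement end the run, while an improvement in between restores the full budget.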
sinatools/ner/trainers/BertTrainer.py
CHANGED
@@ -1,163 +1,163 @@
|
|
1
|
-
import os
|
2
|
-
import logging
|
3
|
-
import torch
|
4
|
-
import numpy as np
|
5
|
-
from sinatools.ner.trainers import BaseTrainer
|
6
|
-
from sinatools.ner.metrics import compute_single_label_metrics
|
7
|
-
|
8
|
-
logger = logging.getLogger(__name__)
|
9
|
-
|
10
|
-
|
11
|
-
class BertTrainer(BaseTrainer):
|
12
|
-
def __init__(self, **kwargs):
|
13
|
-
super().__init__(**kwargs)
|
14
|
-
|
15
|
-
def train(self):
|
16
|
-
best_val_loss, test_loss = np.inf, np.inf
|
17
|
-
num_train_batch = len(self.train_dataloader)
|
18
|
-
patience = self.patience
|
19
|
-
|
20
|
-
for epoch_index in range(self.max_epochs):
|
21
|
-
self.current_epoch = epoch_index
|
22
|
-
train_loss = 0
|
23
|
-
|
24
|
-
for batch_index, (_, gold_tags, _, _, logits) in enumerate(self.tag(
|
25
|
-
self.train_dataloader, is_train=True
|
26
|
-
), 1):
|
27
|
-
self.current_timestep += 1
|
28
|
-
batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
|
29
|
-
batch_loss.backward()
|
30
|
-
|
31
|
-
# Avoid exploding gradient by doing gradient clipping
|
32
|
-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
|
33
|
-
|
34
|
-
self.optimizer.step()
|
35
|
-
self.scheduler.step()
|
36
|
-
train_loss += batch_loss.item()
|
37
|
-
|
38
|
-
if self.current_timestep % self.log_interval == 0:
|
39
|
-
logger.info(
|
40
|
-
"Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
|
41
|
-
epoch_index,
|
42
|
-
batch_index,
|
43
|
-
num_train_batch,
|
44
|
-
self.current_timestep,
|
45
|
-
self.optimizer.param_groups[0]['lr'],
|
46
|
-
batch_loss.item()
|
47
|
-
)
|
48
|
-
|
49
|
-
train_loss /= num_train_batch
|
50
|
-
|
51
|
-
logger.info("** Evaluating on validation dataset **")
|
52
|
-
val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
|
53
|
-
val_metrics = compute_single_label_metrics(segments)
|
54
|
-
|
55
|
-
epoch_summary_loss = {
|
56
|
-
"train_loss": train_loss,
|
57
|
-
"val_loss": val_loss
|
58
|
-
}
|
59
|
-
epoch_summary_metrics = {
|
60
|
-
"val_micro_f1": val_metrics.micro_f1,
|
61
|
-
"val_precision": val_metrics.precision,
|
62
|
-
"val_recall": val_metrics.recall
|
63
|
-
}
|
64
|
-
|
65
|
-
logger.info(
|
66
|
-
"Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
|
67
|
-
epoch_index,
|
68
|
-
self.current_timestep,
|
69
|
-
train_loss,
|
70
|
-
val_loss,
|
71
|
-
val_metrics.micro_f1
|
72
|
-
)
|
73
|
-
|
74
|
-
if val_loss < best_val_loss:
|
75
|
-
patience = self.patience
|
76
|
-
best_val_loss = val_loss
|
77
|
-
logger.info("** Validation improved, evaluating test data **")
|
78
|
-
test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
|
79
|
-
self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
|
80
|
-
test_metrics = compute_single_label_metrics(segments)
|
81
|
-
|
82
|
-
epoch_summary_loss["test_loss"] = test_loss
|
83
|
-
epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
|
84
|
-
epoch_summary_metrics["test_precision"] = test_metrics.precision
|
85
|
-
epoch_summary_metrics["test_recall"] = test_metrics.recall
|
86
|
-
|
87
|
-
logger.info(
|
88
|
-
f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
|
89
|
-
epoch_index,
|
90
|
-
self.current_timestep,
|
91
|
-
test_loss,
|
92
|
-
test_metrics.micro_f1
|
93
|
-
)
|
94
|
-
|
95
|
-
self.save()
|
96
|
-
else:
|
97
|
-
patience -= 1
|
98
|
-
|
99
|
-
# No improvements, terminating early
|
100
|
-
if patience == 0:
|
101
|
-
logger.info("Early termination triggered")
|
102
|
-
break
|
103
|
-
|
104
|
-
self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
|
105
|
-
self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
|
106
|
-
|
107
|
-
def eval(self, dataloader):
|
108
|
-
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
109
|
-
loss = 0
|
110
|
-
|
111
|
-
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
112
|
-
dataloader, is_train=False
|
113
|
-
):
|
114
|
-
loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
|
115
|
-
preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
|
116
|
-
segments += tokens
|
117
|
-
valid_lens += list(valid_len)
|
118
|
-
|
119
|
-
loss /= len(dataloader)
|
120
|
-
|
121
|
-
# Update segments, attach predicted tags to each token
|
122
|
-
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
123
|
-
|
124
|
-
return preds, segments, valid_lens, loss.item()
|
125
|
-
|
126
|
-
def infer(self, dataloader):
|
127
|
-
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
128
|
-
|
129
|
-
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
130
|
-
dataloader, is_train=False
|
131
|
-
):
|
132
|
-
preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
|
133
|
-
segments += tokens
|
134
|
-
valid_lens += list(valid_len)
|
135
|
-
|
136
|
-
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
137
|
-
return segments
|
138
|
-
|
139
|
-
def to_segments(self, segments, preds, valid_lens, vocab):
|
140
|
-
if vocab is None:
|
141
|
-
vocab = self.vocab
|
142
|
-
|
143
|
-
tagged_segments = list()
|
144
|
-
tokens_stoi = vocab.tokens.get_stoi()
|
145
|
-
tags_itos = vocab.tags[0].get_itos()
|
146
|
-
unk_id = tokens_stoi["UNK"]
|
147
|
-
|
148
|
-
for segment, pred, valid_len in zip(segments, preds, valid_lens):
|
149
|
-
# First, the token at 0th index [CLS] and token at nth index [SEP]
|
150
|
-
# Combine the tokens with their corresponding predictions
|
151
|
-
segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
|
152
|
-
|
153
|
-
# Ignore the sub-tokens/subwords, which are identified with text being UNK
|
154
|
-
segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
|
155
|
-
|
156
|
-
# Attach the predicted tags to each token
|
157
|
-
list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": tags_itos[t[1]]}]), segment_pred))
|
158
|
-
|
159
|
-
# We are only interested in the tagged tokens, we do no longer need raw model predictions
|
160
|
-
tagged_segment = [t for t, _ in segment_pred]
|
161
|
-
tagged_segments.append(tagged_segment)
|
162
|
-
|
163
|
-
return tagged_segments
|
1
|
+
import os
|
2
|
+
import logging
|
3
|
+
import torch
|
4
|
+
import numpy as np
|
5
|
+
from sinatools.ner.trainers import BaseTrainer
|
6
|
+
from sinatools.ner.metrics import compute_single_label_metrics
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
class BertTrainer(BaseTrainer):
|
12
|
+
def __init__(self, **kwargs):
|
13
|
+
super().__init__(**kwargs)
|
14
|
+
|
15
|
+
def train(self):
|
16
|
+
best_val_loss, test_loss = np.inf, np.inf
|
17
|
+
num_train_batch = len(self.train_dataloader)
|
18
|
+
patience = self.patience
|
19
|
+
|
20
|
+
for epoch_index in range(self.max_epochs):
|
21
|
+
self.current_epoch = epoch_index
|
22
|
+
train_loss = 0
|
23
|
+
|
24
|
+
for batch_index, (_, gold_tags, _, _, logits) in enumerate(self.tag(
|
25
|
+
self.train_dataloader, is_train=True
|
26
|
+
), 1):
|
27
|
+
self.current_timestep += 1
|
28
|
+
batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
|
29
|
+
batch_loss.backward()
|
30
|
+
|
31
|
+
# Avoid exploding gradient by doing gradient clipping
|
32
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
|
33
|
+
|
34
|
+
self.optimizer.step()
|
35
|
+
self.scheduler.step()
|
36
|
+
train_loss += batch_loss.item()
|
37
|
+
|
38
|
+
if self.current_timestep % self.log_interval == 0:
|
39
|
+
logger.info(
|
40
|
+
"Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
|
41
|
+
epoch_index,
|
42
|
+
batch_index,
|
43
|
+
num_train_batch,
|
44
|
+
self.current_timestep,
|
45
|
+
self.optimizer.param_groups[0]['lr'],
|
46
|
+
batch_loss.item()
|
47
|
+
)
|
48
|
+
|
49
|
+
train_loss /= num_train_batch
|
50
|
+
|
51
|
+
logger.info("** Evaluating on validation dataset **")
|
52
|
+
val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
|
53
|
+
val_metrics = compute_single_label_metrics(segments)
|
54
|
+
|
55
|
+
epoch_summary_loss = {
|
56
|
+
"train_loss": train_loss,
|
57
|
+
"val_loss": val_loss
|
58
|
+
}
|
59
|
+
epoch_summary_metrics = {
|
60
|
+
"val_micro_f1": val_metrics.micro_f1,
|
61
|
+
"val_precision": val_metrics.precision,
|
62
|
+
"val_recall": val_metrics.recall
|
63
|
+
}
|
64
|
+
|
65
|
+
logger.info(
|
66
|
+
"Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
|
67
|
+
epoch_index,
|
68
|
+
self.current_timestep,
|
69
|
+
train_loss,
|
70
|
+
val_loss,
|
71
|
+
val_metrics.micro_f1
|
72
|
+
)
|
73
|
+
|
74
|
+
if val_loss < best_val_loss:
|
75
|
+
patience = self.patience
|
76
|
+
best_val_loss = val_loss
|
77
|
+
logger.info("** Validation improved, evaluating test data **")
|
78
|
+
test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
|
79
|
+
self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
|
80
|
+
test_metrics = compute_single_label_metrics(segments)
|
81
|
+
|
82
|
+
epoch_summary_loss["test_loss"] = test_loss
|
83
|
+
epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
|
84
|
+
epoch_summary_metrics["test_precision"] = test_metrics.precision
|
85
|
+
epoch_summary_metrics["test_recall"] = test_metrics.recall
|
86
|
+
|
87
|
+
logger.info(
|
88
|
+
f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
|
89
|
+
epoch_index,
|
90
|
+
self.current_timestep,
|
91
|
+
test_loss,
|
92
|
+
test_metrics.micro_f1
|
93
|
+
)
|
94
|
+
|
95
|
+
self.save()
|
96
|
+
else:
|
97
|
+
patience -= 1
|
98
|
+
|
99
|
+
# No improvements, terminating early
|
100
|
+
if patience == 0:
|
101
|
+
logger.info("Early termination triggered")
|
102
|
+
break
|
103
|
+
|
104
|
+
self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
|
105
|
+
self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
|
106
|
+
|
107
|
+
def eval(self, dataloader):
|
108
|
+
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
109
|
+
loss = 0
|
110
|
+
|
111
|
+
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
112
|
+
dataloader, is_train=False
|
113
|
+
):
|
114
|
+
loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
|
115
|
+
preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
|
116
|
+
segments += tokens
|
117
|
+
valid_lens += list(valid_len)
|
118
|
+
|
119
|
+
loss /= len(dataloader)
|
120
|
+
|
121
|
+
# Update segments, attach predicted tags to each token
|
122
|
+
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
123
|
+
|
124
|
+
return preds, segments, valid_lens, loss.item()
|
125
|
+
|
126
|
+
def infer(self, dataloader):
|
127
|
+
golds, preds, segments, valid_lens = list(), list(), list(), list()
|
128
|
+
|
129
|
+
for _, gold_tags, tokens, valid_len, logits in self.tag(
|
130
|
+
dataloader, is_train=False
|
131
|
+
):
|
132
|
+
preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
|
133
|
+
segments += tokens
|
134
|
+
valid_lens += list(valid_len)
|
135
|
+
|
136
|
+
segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
|
137
|
+
return segments
|
138
|
+
|
139
|
+
def to_segments(self, segments, preds, valid_lens, vocab):
|
140
|
+
if vocab is None:
|
141
|
+
vocab = self.vocab
|
142
|
+
|
143
|
+
tagged_segments = list()
|
144
|
+
tokens_stoi = vocab.tokens.get_stoi()
|
145
|
+
tags_itos = vocab.tags[0].get_itos()
|
146
|
+
unk_id = tokens_stoi["UNK"]
|
147
|
+
|
148
|
+
for segment, pred, valid_len in zip(segments, preds, valid_lens):
|
149
|
+
# First, the token at 0th index [CLS] and token at nth index [SEP]
|
150
|
+
# Combine the tokens with their corresponding predictions
|
151
|
+
segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
|
152
|
+
|
153
|
+
# Ignore the sub-tokens/subwords, which are identified with text being UNK
|
154
|
+
segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
|
155
|
+
|
156
|
+
# Attach the predicted tags to each token
|
157
|
+
list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": tags_itos[t[1]]}]), segment_pred))
|
158
|
+
|
159
|
+
# We are only interested in the tagged tokens, we do no longer need raw model predictions
|
160
|
+
tagged_segment = [t for t, _ in segment_pred]
|
161
|
+
tagged_segments.append(tagged_segment)
|
162
|
+
|
163
|
+
return tagged_segments
|
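The `train` method in `BertTrainer.py` resets `patience` whenever the validation loss improves, decrements it otherwise, and stops once it reaches zero. Below is a minimal sketch of that pattern in isolation, assuming a plain list of per-epoch validation losses. The name `run_with_patience` and its arguments are hypothetical and are not part of SinaTools.

```python
import math


def run_with_patience(val_losses, patience):
    """Return (best_loss, epochs_run) for a sequence of per-epoch validation losses.

    Hypothetical helper: mirrors the patience bookkeeping of the training loop,
    without any model, optimizer, or dataloader.
    """
    best = math.inf
    remaining = patience
    epochs_run = 0
    for loss in val_losses:
        epochs_run += 1
        if loss < best:
            best = loss
            remaining = patience  # an improvement resets the counter
        else:
            remaining -= 1
            if remaining == 0:  # no improvement for `patience` epochs in a row
                break
    return best, epochs_run
```

The counter tracks consecutive epochs without improvement, so one good epoch restores the full budget.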
@@ -1,3 +1,3 @@
|
|
1
|
-
from sinatools.ner.trainers.BaseTrainer import BaseTrainer
|
2
|
-
from sinatools.ner.trainers.BertTrainer import BertTrainer
|
1
|
+
from sinatools.ner.trainers.BaseTrainer import BaseTrainer
|
2
|
+
from sinatools.ner.trainers.BertTrainer import BertTrainer
|
3
3
|
from sinatools.ner.trainers.BertNestedTrainer import BertNestedTrainer
|
sinatools/utils/similarity.py
CHANGED
@@ -101,56 +101,91 @@ def get_intersection(list1, list2, ignore_all_diacritics_but_not_shadda=False, i
|
|
101
101
|
|
102
102
|
|
103
103
|
|
104
|
-
def get_union(list1, list2, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic):
|
105
|
-
|
106
|
-
|
104
|
+
# def get_union(list1, list2, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic):
|
105
|
+
# """
|
106
|
+
# Computes the union of two sets of Arabic words, considering the differences in their diacritization. The method provides two options for handling diacritics: (i) ignore all diacritics except for shadda, and (ii) ignore the shadda diacritic as well. You can try the demo online.
|
107
107
|
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
108
|
+
# Args:
|
109
|
+
# list1 (:obj:`list`): The first list.
|
110
|
+
# list2 (:obj:`bool`): The second list.
|
111
|
+
# ignore_all_diacratics_but_not_shadda (:obj:`bool`, optional) – A flag to ignore all diacratics except for the shadda. Defaults to False.
|
112
|
+
# ignore_shadda_diacritic (:obj:`bool`, optional) – A flag to ignore the shadda diacritic. Defaults to False.
|
113
113
|
|
114
|
-
|
115
|
-
|
114
|
+
# Returns:
|
115
|
+
# :obj:`list`: The union of the two lists, ignoring diacritics if flags are true.
|
116
116
|
|
117
|
-
|
117
|
+
# **Example:**
|
118
118
|
|
119
|
-
|
120
|
-
|
119
|
+
# .. highlight:: python
|
120
|
+
# .. code-block:: python
|
121
121
|
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
122
|
+
# from sinatools.utils.similarity import get_union
|
123
|
+
# list1 = ["كتب","فَعل","فَعَلَ"]
|
124
|
+
# list2 = ["كتب","فَعّل"]
|
125
|
+
# print(get_union(list1, list2, False, True))
|
126
|
+
# #output: ["كتب" ,"فَعل" ,"فَعَلَ"]
|
127
|
+
# """
|
128
|
+
# list1 = [str(i) for i in list1 if i not in (None, ' ', '')]
|
129
129
|
|
130
|
+
# list2 = [str(i) for i in list2 if i not in (None, ' ', '')]
|
131
|
+
|
132
|
+
# union_list = []
|
133
|
+
|
134
|
+
# for list1_word in list1:
|
135
|
+
# word1 = normalize_word(list1_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
|
136
|
+
# union_list.append(word1)
|
137
|
+
|
138
|
+
# for list2_word in list2:
|
139
|
+
# word2 = normalize_word(list2_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
|
140
|
+
# union_list.append(word2)
|
141
|
+
|
142
|
+
# i = 0
|
143
|
+
# while i < len(union_list):
|
144
|
+
# j = i + 1
|
145
|
+
# while j < len(union_list):
|
146
|
+
# non_preferred_word = get_non_preferred_word(union_list[i], union_list[j])
|
147
|
+
# if (non_preferred_word != "#"):
|
148
|
+
# union_list.remove(non_preferred_word)
|
149
|
+
# j = j + 1
|
150
|
+
# i = i + 1
|
151
|
+
|
152
|
+
# return union_list
|
153
|
+
def get_union(list1, list2, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic):
|
154
|
+
|
155
|
+
|
156
|
+
list1 = [str(i) for i in list1 if i not in (None, ' ', '')]
|
130
157
|
list2 = [str(i) for i in list2 if i not in (None, ' ', '')]
|
131
158
|
|
159
|
+
|
132
160
|
union_list = []
|
133
161
|
|
162
|
+
# Normalize and add words from list1
|
134
163
|
for list1_word in list1:
|
135
164
|
word1 = normalize_word(list1_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
|
136
|
-
union_list
|
165
|
+
if word1 not in union_list:
|
166
|
+
union_list.append(word1)
|
137
167
|
|
168
|
+
# Normalize and add words from list2
|
138
169
|
for list2_word in list2:
|
139
170
|
word2 = normalize_word(list2_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
|
140
|
-
union_list
|
171
|
+
if word2 not in union_list:
|
172
|
+
union_list.append(word2)
|
141
173
|
|
174
|
+
|
142
175
|
i = 0
|
143
176
|
while i < len(union_list):
|
144
177
|
j = i + 1
|
145
178
|
while j < len(union_list):
|
146
179
|
non_preferred_word = get_non_preferred_word(union_list[i], union_list[j])
|
147
|
-
if
|
180
|
+
if non_preferred_word != "#":
|
148
181
|
union_list.remove(non_preferred_word)
|
149
|
-
|
150
|
-
|
182
|
+
j -= 1
|
183
|
+
j += 1
|
184
|
+
i += 1
|
151
185
|
|
152
186
|
return union_list
|
153
|
-
|
187
|
+
|
188
|
+
|
154
189
|
|
155
190
|
|
156
191
|
def get_jaccard_similarity(list1: list, list2: list, ignore_all_diacritics_but_not_shadda: bool, ignore_shadda_diacritic: bool) -> float:
|
@@ -184,7 +219,7 @@ def get_jaccard_similarity(list1: list, list2: list, ignore_all_diacritics_but_n
|
|
184
219
|
|
185
220
|
return float(len(intersection_list)) / float(len(union_list))
|
186
221
|
|
187
|
-
def get_jaccard(delimiter, str1, str2,
|
222
|
+
def get_jaccard(delimiter, selection, str1, str2, ignoreAllDiacriticsButNotShadda=True, ignoreShaddaDiacritic=True):
|
188
223
|
"""
|
189
224
|
Calculates and returns the Jaccard similarity values (union, intersection, or Jaccard similarity) between two lists of Arabic words, considering the differences in their diacritization. The method provides two options for handling diacritics: (i) ignore all diacritics except for shadda, and (ii) ignore the shadda diacritic as well. You can try the demo online.
|
190
225
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
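The `get_union` change in `sinatools/utils/similarity.py` skips duplicates while merging the two lists, and adjusts `j` after a removal so the pairwise scan does not skip an element. The sketch below illustrates the same idea under stated assumptions. `toy_normalize` and `toy_non_preferred` are hypothetical stand-ins for `normalize_word` and `get_non_preferred_word`, whose real behavior is not shown in this diff. Unlike the diff, the sketch also rescans from position `i` when the dropped word is the one at `i`.

```python
import re

SHADDA = "\u0651"
# Short vowels, tanween and sukun (U+064B-U+0650, U+0652); shadda is U+0651.
HARAKAT = re.compile("[\u064B-\u0650\u0652]")
ALL_MARKS = re.compile("[\u064B-\u0652]")


def toy_normalize(word, drop_harakat, drop_shadda):
    """Hypothetical stand-in for normalize_word."""
    if drop_harakat:
        word = HARAKAT.sub("", word)
    if drop_shadda:
        word = word.replace(SHADDA, "")
    return word


def toy_non_preferred(a, b):
    """Hypothetical stand-in for get_non_preferred_word.

    Returns the word carrying fewer diacritics when both share the same
    bare letters, or the sentinel "#" when neither should be dropped.
    """
    if ALL_MARKS.sub("", a) != ALL_MARKS.sub("", b):
        return "#"
    marks_a = len(ALL_MARKS.findall(a))
    marks_b = len(ALL_MARKS.findall(b))
    if marks_a == marks_b:
        return "#"
    return a if marks_a < marks_b else b


def union_sketch(list1, list2, drop_harakat=False, drop_shadda=False):
    merged = []
    for word in list(list1) + list(list2):
        if word in (None, "", " "):
            continue
        norm = toy_normalize(str(word), drop_harakat, drop_shadda)
        if norm not in merged:  # skip duplicates while merging
            merged.append(norm)

    i = 0
    while i < len(merged):
        j = i + 1
        removed_i = False
        while j < len(merged):
            loser = toy_non_preferred(merged[i], merged[j])
            if loser == "#":
                j += 1
            elif loser == merged[j]:
                del merged[j]  # the next candidate slides into position j
            else:
                del merged[i]  # a new word now sits at position i; rescan it
                removed_i = True
                break
        if not removed_i:
            i += 1
    return merged
```

Deleting by index rather than with `list.remove` makes it explicit which position shifts, which is what the index adjustment has to account for.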