SinaTools 0.1.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. SinaTools-0.1.1.data/data/nlptools/environment.yml +227 -0
  2. SinaTools-0.1.1.dist-info/AUTHORS.rst +13 -0
  3. SinaTools-0.1.1.dist-info/LICENSE +22 -0
  4. SinaTools-0.1.1.dist-info/METADATA +72 -0
  5. SinaTools-0.1.1.dist-info/RECORD +122 -0
  6. SinaTools-0.1.1.dist-info/WHEEL +6 -0
  7. SinaTools-0.1.1.dist-info/entry_points.txt +18 -0
  8. SinaTools-0.1.1.dist-info/top_level.txt +1 -0
  9. nlptools/CLI/DataDownload/download_files.py +71 -0
  10. nlptools/CLI/arabiner/bin/infer.py +117 -0
  11. nlptools/CLI/arabiner/bin/infer2.py +81 -0
  12. nlptools/CLI/morphology/ALMA_multi_word.py +75 -0
  13. nlptools/CLI/morphology/morph_analyzer.py +91 -0
  14. nlptools/CLI/salma/salma_tools.py +68 -0
  15. nlptools/CLI/utils/__init__.py +0 -0
  16. nlptools/CLI/utils/arStrip.py +99 -0
  17. nlptools/CLI/utils/corpus_tokenizer.py +74 -0
  18. nlptools/CLI/utils/implication.py +92 -0
  19. nlptools/CLI/utils/jaccard.py +96 -0
  20. nlptools/CLI/utils/latin_remove.py +51 -0
  21. nlptools/CLI/utils/remove_Punc.py +53 -0
  22. nlptools/CLI/utils/sentence_tokenizer.py +90 -0
  23. nlptools/CLI/utils/text_transliteration.py +77 -0
  24. nlptools/DataDownload/__init__.py +0 -0
  25. nlptools/DataDownload/downloader.py +185 -0
  26. nlptools/VERSION +1 -0
  27. nlptools/__init__.py +5 -0
  28. nlptools/arabert/__init__.py +1 -0
  29. nlptools/arabert/arabert/__init__.py +14 -0
  30. nlptools/arabert/arabert/create_classification_data.py +260 -0
  31. nlptools/arabert/arabert/create_pretraining_data.py +534 -0
  32. nlptools/arabert/arabert/extract_features.py +444 -0
  33. nlptools/arabert/arabert/lamb_optimizer.py +158 -0
  34. nlptools/arabert/arabert/modeling.py +1027 -0
  35. nlptools/arabert/arabert/optimization.py +202 -0
  36. nlptools/arabert/arabert/run_classifier.py +1078 -0
  37. nlptools/arabert/arabert/run_pretraining.py +593 -0
  38. nlptools/arabert/arabert/run_squad.py +1440 -0
  39. nlptools/arabert/arabert/tokenization.py +414 -0
  40. nlptools/arabert/araelectra/__init__.py +1 -0
  41. nlptools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +103 -0
  42. nlptools/arabert/araelectra/build_pretraining_dataset.py +230 -0
  43. nlptools/arabert/araelectra/build_pretraining_dataset_single_file.py +90 -0
  44. nlptools/arabert/araelectra/configure_finetuning.py +172 -0
  45. nlptools/arabert/araelectra/configure_pretraining.py +143 -0
  46. nlptools/arabert/araelectra/finetune/__init__.py +14 -0
  47. nlptools/arabert/araelectra/finetune/feature_spec.py +56 -0
  48. nlptools/arabert/araelectra/finetune/preprocessing.py +173 -0
  49. nlptools/arabert/araelectra/finetune/scorer.py +54 -0
  50. nlptools/arabert/araelectra/finetune/task.py +74 -0
  51. nlptools/arabert/araelectra/finetune/task_builder.py +70 -0
  52. nlptools/arabert/araelectra/flops_computation.py +215 -0
  53. nlptools/arabert/araelectra/model/__init__.py +14 -0
  54. nlptools/arabert/araelectra/model/modeling.py +1029 -0
  55. nlptools/arabert/araelectra/model/optimization.py +193 -0
  56. nlptools/arabert/araelectra/model/tokenization.py +355 -0
  57. nlptools/arabert/araelectra/pretrain/__init__.py +14 -0
  58. nlptools/arabert/araelectra/pretrain/pretrain_data.py +160 -0
  59. nlptools/arabert/araelectra/pretrain/pretrain_helpers.py +229 -0
  60. nlptools/arabert/araelectra/run_finetuning.py +323 -0
  61. nlptools/arabert/araelectra/run_pretraining.py +469 -0
  62. nlptools/arabert/araelectra/util/__init__.py +14 -0
  63. nlptools/arabert/araelectra/util/training_utils.py +112 -0
  64. nlptools/arabert/araelectra/util/utils.py +109 -0
  65. nlptools/arabert/aragpt2/__init__.py +2 -0
  66. nlptools/arabert/aragpt2/create_pretraining_data.py +95 -0
  67. nlptools/arabert/aragpt2/gpt2/__init__.py +2 -0
  68. nlptools/arabert/aragpt2/gpt2/lamb_optimizer.py +158 -0
  69. nlptools/arabert/aragpt2/gpt2/optimization.py +225 -0
  70. nlptools/arabert/aragpt2/gpt2/run_pretraining.py +397 -0
  71. nlptools/arabert/aragpt2/grover/__init__.py +0 -0
  72. nlptools/arabert/aragpt2/grover/dataloader.py +161 -0
  73. nlptools/arabert/aragpt2/grover/modeling.py +803 -0
  74. nlptools/arabert/aragpt2/grover/modeling_gpt2.py +1196 -0
  75. nlptools/arabert/aragpt2/grover/optimization_adafactor.py +234 -0
  76. nlptools/arabert/aragpt2/grover/train_tpu.py +187 -0
  77. nlptools/arabert/aragpt2/grover/utils.py +234 -0
  78. nlptools/arabert/aragpt2/train_bpe_tokenizer.py +59 -0
  79. nlptools/arabert/preprocess.py +818 -0
  80. nlptools/arabiner/__init__.py +0 -0
  81. nlptools/arabiner/bin/__init__.py +14 -0
  82. nlptools/arabiner/bin/eval.py +87 -0
  83. nlptools/arabiner/bin/infer.py +91 -0
  84. nlptools/arabiner/bin/process.py +140 -0
  85. nlptools/arabiner/bin/train.py +221 -0
  86. nlptools/arabiner/data/__init__.py +1 -0
  87. nlptools/arabiner/data/datasets.py +146 -0
  88. nlptools/arabiner/data/transforms.py +118 -0
  89. nlptools/arabiner/nn/BaseModel.py +22 -0
  90. nlptools/arabiner/nn/BertNestedTagger.py +34 -0
  91. nlptools/arabiner/nn/BertSeqTagger.py +17 -0
  92. nlptools/arabiner/nn/__init__.py +3 -0
  93. nlptools/arabiner/trainers/BaseTrainer.py +117 -0
  94. nlptools/arabiner/trainers/BertNestedTrainer.py +203 -0
  95. nlptools/arabiner/trainers/BertTrainer.py +163 -0
  96. nlptools/arabiner/trainers/__init__.py +3 -0
  97. nlptools/arabiner/utils/__init__.py +0 -0
  98. nlptools/arabiner/utils/data.py +124 -0
  99. nlptools/arabiner/utils/helpers.py +151 -0
  100. nlptools/arabiner/utils/metrics.py +69 -0
  101. nlptools/environment.yml +227 -0
  102. nlptools/install_env.py +13 -0
  103. nlptools/morphology/ALMA_multi_word.py +34 -0
  104. nlptools/morphology/__init__.py +52 -0
  105. nlptools/morphology/charsets.py +60 -0
  106. nlptools/morphology/morph_analyzer.py +170 -0
  107. nlptools/morphology/settings.py +8 -0
  108. nlptools/morphology/tokenizers_words.py +19 -0
  109. nlptools/nlptools.py +1 -0
  110. nlptools/salma/__init__.py +12 -0
  111. nlptools/salma/settings.py +31 -0
  112. nlptools/salma/views.py +459 -0
  113. nlptools/salma/wsd.py +126 -0
  114. nlptools/utils/__init__.py +0 -0
  115. nlptools/utils/corpus_tokenizer.py +73 -0
  116. nlptools/utils/implication.py +662 -0
  117. nlptools/utils/jaccard.py +247 -0
  118. nlptools/utils/parser.py +147 -0
  119. nlptools/utils/readfile.py +3 -0
  120. nlptools/utils/sentence_tokenizer.py +53 -0
  121. nlptools/utils/text_transliteration.py +232 -0
  122. nlptools/utils/utils.py +2 -0
nlptools/arabiner/data/datasets.py
@@ -0,0 +1,146 @@
+ import logging
+ import torch
+ from torch.utils.data import Dataset
+ from torch.nn.utils.rnn import pad_sequence
+ from nlptools.arabiner.data.transforms import (
+     BertSeqTransform,
+     NestedTagsTransform
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class Token:
+     def __init__(self, text=None, pred_tag=None, gold_tag=None):
+         """
+         Token object to hold token attributes
+         :param text: str
+         :param pred_tag: str
+         :param gold_tag: str
+         """
+         self.text = text
+         self.gold_tag = gold_tag
+         self.pred_tag = pred_tag
+         self.subwords = None
+
+     @property
+     def subwords(self):
+         return self._subwords
+
+     @subwords.setter
+     def subwords(self, value):
+         self._subwords = value
+
+     def __str__(self):
+         """
+         Token text representation
+         :return: str
+         """
+         gold_tags = "|".join(self.gold_tag)
+         pred_tags = "|".join([pred_tag["tag"] for pred_tag in self.pred_tag])
+
+         if self.gold_tag:
+             r = f"{self.text}\t{gold_tags}\t{pred_tags}"
+         else:
+             r = f"{self.text}\t{pred_tags}"
+
+         return r
+
+
+ class DefaultDataset(Dataset):
+     def __init__(
+         self,
+         examples=None,
+         vocab=None,
+         bert_model="aubmindlab/bert-base-arabertv2",
+         max_seq_len=512,
+     ):
+         """
+         The dataset used to transform the segments into training data
+         :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
+                          You can generate examples from arabiner.data.dataset.parse_conll_files
+         :param vocab: vocab object containing indexed tags and tokens
+         :param bert_model: str - BERT model
+         :param max_seq_len: int - maximum sequence length
+         """
+         self.transform = BertSeqTransform(bert_model, vocab, max_seq_len=max_seq_len)
+         self.examples = examples
+         self.vocab = vocab
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, item):
+         subwords, tags, tokens, valid_len = self.transform(self.examples[item])
+         return subwords, tags, tokens, valid_len
+
+     def collate_fn(self, batch):
+         """
+         Collate function that is called when the trainer fetches a batch
+         :param batch: Dataloader batch
+         :return: Same output as the __getitem__ function
+         """
+         subwords, tags, tokens, valid_len = zip(*batch)
+
+         # Pad sequences in this batch
+         # subwords and tokens are padded with zeros
+         # tags are padded with the index of the O tag
+         subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
+         tags = pad_sequence(
+             tags, batch_first=True, padding_value=self.vocab.tags[0].get_stoi()["O"]
+         )
+         return subwords, tags, tokens, valid_len
+
+
+ class NestedTagsDataset(Dataset):
+     def __init__(
+         self,
+         examples=None,
+         vocab=None,
+         bert_model="aubmindlab/bert-base-arabertv2",
+         max_seq_len=512,
+     ):
+         """
+         The dataset used to transform the segments into training data
+         :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
+                          You can generate examples from arabiner.data.dataset.parse_conll_files
+         :param vocab: vocab object containing indexed tags and tokens
+         :param bert_model: str - BERT model
+         :param max_seq_len: int - maximum sequence length
+         """
+         self.transform = NestedTagsTransform(
+             bert_model, vocab, max_seq_len=max_seq_len
+         )
+         self.examples = examples
+         self.vocab = vocab
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, item):
+         subwords, tags, tokens, masks, valid_len = self.transform(self.examples[item])
+         return subwords, tags, tokens, masks, valid_len
+
+     def collate_fn(self, batch):
+         """
+         Collate function that is called when the trainer fetches a batch
+         :param batch: Dataloader batch
+         :return: Same output as the __getitem__ function
+         """
+         subwords, tags, tokens, masks, valid_len = zip(*batch)
+
+         # Pad sequences in this batch
+         # subwords and tokens are padded with zeros
+         # tags are padded with the index of the O tag
+         subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
+
+         masks = [torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), 0)(mask)
+                  for tag, mask in zip(tags, masks)]
+         masks = torch.cat(masks)
+
+         # Pad the tags, do the padding for each tag type
+         tags = [torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), vocab.get_stoi()["<pad>"])(tag)
+                 for tag, vocab in zip(tags, self.vocab.tags[1:])]
+         tags = torch.cat(tags)
+
+         return subwords, tags, tokens, masks, valid_len
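Note on the collate functions above: both datasets rely on torch.nn.utils.rnn.pad_sequence to bring variable-length segments to a common length, padding subwords with 0 and tags with the index of the O tag. A minimal, self-contained sketch of that padding behaviour, using made-up subword and tag IDs and a hypothetical O-tag index of 0:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two segments of different lengths: subword IDs and tag IDs per segment (made-up values).
subwords = [torch.LongTensor([2, 15, 37, 3]), torch.LongTensor([2, 8, 3])]
tags = [torch.LongTensor([0, 4, 5, 0]), torch.LongTensor([0, 6, 0])]

o_tag_id = 0  # hypothetical index of the "O" tag in the tag vocabulary
padded_subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
padded_tags = pad_sequence(tags, batch_first=True, padding_value=o_tag_id)

print(padded_subwords.shape, padded_tags.shape)  # torch.Size([2, 4]) torch.Size([2, 4])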
nlptools/arabiner/data/transforms.py
@@ -0,0 +1,118 @@
+ import torch
+ from transformers import BertTokenizer
+ from functools import partial
+ import re
+ import itertools
+ from nlptools.arabiner.data import datasets
+ class BertSeqTransform:
+     def __init__(self, bert_model, vocab, max_seq_len=512):
+         self.tokenizer = BertTokenizer.from_pretrained(bert_model)
+         self.encoder = partial(
+             self.tokenizer.encode,
+             max_length=max_seq_len,
+             truncation=True,
+         )
+         self.max_seq_len = max_seq_len
+         self.vocab = vocab
+
+     def __call__(self, segment):
+         subwords, tags, tokens = list(), list(), list()
+         unk_token = datasets.Token(text="UNK")
+
+         for token in segment:
+             token_subwords = self.encoder(token.text)[1:-1]
+             subwords += token_subwords
+             tags += [self.vocab.tags[0].get_stoi()[token.gold_tag[0]]] + [self.vocab.tags[0].get_stoi()["O"]] * (len(token_subwords) - 1)
+             tokens += [token] + [unk_token] * (len(token_subwords) - 1)
+
+         # Truncate to max_seq_len
+         if len(subwords) > self.max_seq_len - 2:
+             text = " ".join([t.text for t in tokens if t.text != "UNK"])
+
+             subwords = subwords[:self.max_seq_len - 2]
+             tags = tags[:self.max_seq_len - 2]
+             tokens = tokens[:self.max_seq_len - 2]
+
+         subwords.insert(0, self.tokenizer.cls_token_id)
+         subwords.append(self.tokenizer.sep_token_id)
+
+         tags.insert(0, self.vocab.tags[0].get_stoi()["O"])
+         tags.append(self.vocab.tags[0].get_stoi()["O"])
+
+         tokens.insert(0, unk_token)
+         tokens.append(unk_token)
+
+         return torch.LongTensor(subwords), torch.LongTensor(tags), tokens, len(tokens)
+
+
+ class NestedTagsTransform:
+     def __init__(self, bert_model, vocab, max_seq_len=512):
+         self.tokenizer = BertTokenizer.from_pretrained(bert_model)
+         self.encoder = partial(
+             self.tokenizer.encode,
+             max_length=max_seq_len,
+             truncation=True,
+         )
+         self.max_seq_len = max_seq_len
+         self.vocab = vocab
+
+     def __call__(self, segment):
+         tags, tokens, subwords = list(), list(), list()
+         unk_token = datasets.Token(text="UNK")
+
+         # Encode each token and get its subwords and IDs
+         for token in segment:
+             token.subwords = self.encoder(token.text)[1:-1]
+             subwords += token.subwords
+             tokens += [token] + [unk_token] * (len(token.subwords) - 1)
+
+         # Construct the labels for each tag type
+         # The sequence will have a list of tags for each type
+         # The final tags for a sequence form a matrix NUM_TAG_TYPES x SEQ_LEN
+         # Example:
+         # [
+         #     [O, O, B-PERS, I-PERS, O, O, O]
+         #     [B-ORG, I-ORG, O, O, O, O, O]
+         #     [O, O, O, O, O, O, B-GPE]
+         # ]
+         for vocab in self.vocab.tags[1:]:
+             vocab_tags = "|".join([t for t in vocab.get_itos() if "-" in t])
+             r = re.compile(vocab_tags)
+
+             # This is really messy
+             # For a given token we find a matching tag name, but we might find
+             # multiple matches (i.e. a token can be labeled both B-ORG and I-ORG); in that
+             # case we keep only the first tag, since tags of the same type do not overlap
+             single_type_tags = [[(list(filter(r.match, token.gold_tag))
+                                   or ["O"])[0]] + ["O"] * (len(token.subwords) - 1)
+                                 for token in segment]
+             single_type_tags = list(itertools.chain(*single_type_tags))
+             tags.append([vocab.get_stoi()[tag] for tag in single_type_tags])
+
+         # Truncate to max_seq_len
+         if len(subwords) > self.max_seq_len - 2:
+             text = " ".join([t.text for t in tokens if t.text != "UNK"])
+
+             subwords = subwords[:self.max_seq_len - 2]
+             tags = [t[:self.max_seq_len - 2] for t in tags]
+             tokens = tokens[:self.max_seq_len - 2]
+
+         # Add a dummy token at the start and end of the sequence
+         tokens.insert(0, unk_token)
+         tokens.append(unk_token)
+
+         # Add CLS and SEP at the start and end of subwords
+         subwords.insert(0, self.tokenizer.cls_token_id)
+         subwords.append(self.tokenizer.sep_token_id)
+         subwords = torch.LongTensor(subwords)
+
+         # Add "O" tags for the first and last subwords
+         tags = torch.Tensor(tags)
+         tags = torch.column_stack((
+             torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
+             tags,
+             torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
+         )).unsqueeze(0)
+
+         mask = torch.ones_like(tags)
+         return subwords, tags, tokens, mask, len(tokens)
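The per-type label construction in NestedTagsTransform.__call__ is dense. The sketch below isolates the idea with plain lists and hypothetical tag vocabularies: for each tag type, a regex over that type's tags picks at most one matching gold tag per token, and everything else becomes O, yielding one row per tag type.

import re

# Gold tags per token; a token may carry tags of several types (made-up example).
segment_gold_tags = [["B-PERS"], ["I-PERS"], ["O"], ["B-ORG", "B-GPE"]]

# Hypothetical tag vocabularies, one per tag type.
tag_types = {
    "PERS": ["O", "B-PERS", "I-PERS"],
    "ORG": ["O", "B-ORG", "I-ORG"],
    "GPE": ["O", "B-GPE", "I-GPE"],
}

rows = []
for type_name, vocab_tags in tag_types.items():
    r = re.compile("|".join(t for t in vocab_tags if "-" in t))
    # Keep the first gold tag matching this type, otherwise fall back to "O".
    rows.append([(list(filter(r.match, gold)) or ["O"])[0] for gold in segment_gold_tags])

for row in rows:
    print(row)
# ['B-PERS', 'I-PERS', 'O', 'O']
# ['O', 'O', 'O', 'B-ORG']
# ['O', 'O', 'O', 'B-GPE']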
nlptools/arabiner/nn/BaseModel.py
@@ -0,0 +1,22 @@
+ from torch import nn
+ from transformers import BertModel
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseModel(nn.Module):
+     def __init__(self,
+                  bert_model="aubmindlab/bert-base-arabertv2",
+                  num_labels=2,
+                  dropout=0.1,
+                  num_types=0):
+         super().__init__()
+
+         self.bert_model = bert_model
+         self.num_labels = num_labels
+         self.num_types = num_types
+         self.dropout = dropout
+
+         self.bert = BertModel.from_pretrained(bert_model)
+         self.dropout = nn.Dropout(dropout)
nlptools/arabiner/nn/BertNestedTagger.py
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn as nn
+ from nlptools.arabiner.nn import BaseModel
+
+
+ class BertNestedTagger(BaseModel):
+     def __init__(self, **kwargs):
+         super(BertNestedTagger, self).__init__(**kwargs)
+
+         self.max_num_labels = max(self.num_labels)
+         classifiers = [nn.Linear(768, num_labels) for num_labels in self.num_labels]
+         self.classifiers = torch.nn.Sequential(*classifiers)
+
+     def forward(self, x):
+         y = self.bert(x)
+         y = self.dropout(y["last_hidden_state"])
+         output = list()
+
+         for i, classifier in enumerate(self.classifiers):
+             logits = classifier(y)
+
+             # Pad logits to allow Multi-GPU/DataParallel training to work
+             # We will truncate the padded dimensions when we compute the loss in the trainer
+             logits = torch.nn.ConstantPad1d((0, self.max_num_labels - logits.shape[-1]), 0)(logits)
+             output.append(logits)
+
+         # Return tensor of the shape B x T x L x C
+         # B: batch size
+         # T: sequence length
+         # L: number of tag types
+         # C: number of classes per tag type
+         output = torch.stack(output).permute((1, 2, 0, 3))
+         return output
+
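The comments in forward above describe a B x T x L x C output. A small sketch of the padding and stacking arithmetic with hypothetical sizes (batch 2, sequence length 5, three tag types with 7, 5 and 3 classes), independent of any BERT weights:

import torch
import torch.nn as nn

num_labels = [7, 5, 3]               # classes per tag type (hypothetical)
max_num_labels = max(num_labels)
y = torch.randn(2, 5, 768)           # stand-in for BERT's last_hidden_state (B x T x 768)

output = []
for n in num_labels:
    logits = nn.Linear(768, n)(y)                                                  # B x T x n
    logits = nn.ConstantPad1d((0, max_num_labels - logits.shape[-1]), 0)(logits)   # pad classes to C_max
    output.append(logits)

output = torch.stack(output).permute((1, 2, 0, 3))
print(output.shape)                  # torch.Size([2, 5, 3, 7]) -> B x T x L x C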
nlptools/arabiner/nn/BertSeqTagger.py
@@ -0,0 +1,17 @@
+ import torch.nn as nn
+ from transformers import BertModel
+
+
+ class BertSeqTagger(nn.Module):
+     def __init__(self, bert_model, num_labels=2, dropout=0.1):
+         super().__init__()
+
+         self.bert = BertModel.from_pretrained(bert_model)
+         self.dropout = nn.Dropout(dropout)
+         self.linear = nn.Linear(768, num_labels)
+
+     def forward(self, x):
+         y = self.bert(x)
+         y = self.dropout(y["last_hidden_state"])
+         logits = self.linear(y)
+         return logits
nlptools/arabiner/nn/__init__.py
@@ -0,0 +1,3 @@
+ from nlptools.arabiner.nn.BaseModel import BaseModel
+ from nlptools.arabiner.nn.BertSeqTagger import BertSeqTagger
+ from nlptools.arabiner.nn.BertNestedTagger import BertNestedTagger
nlptools/arabiner/trainers/BaseTrainer.py
@@ -0,0 +1,117 @@
+ import os
+ import torch
+ import logging
+ import natsort
+ import glob
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseTrainer:
+     def __init__(
+         self,
+         model=None,
+         max_epochs=50,
+         optimizer=None,
+         scheduler=None,
+         loss=None,
+         train_dataloader=None,
+         val_dataloader=None,
+         test_dataloader=None,
+         log_interval=10,
+         summary_writer=None,
+         output_path=None,
+         clip=5,
+         patience=5
+     ):
+         self.model = model
+         self.max_epochs = max_epochs
+         self.train_dataloader = train_dataloader
+         self.val_dataloader = val_dataloader
+         self.test_dataloader = test_dataloader
+         self.optimizer = optimizer
+         self.scheduler = scheduler
+         self.loss = loss
+         self.log_interval = log_interval
+         self.summary_writer = summary_writer
+         self.output_path = output_path
+         self.current_timestep = 0
+         self.current_epoch = 0
+         self.clip = clip
+         self.patience = patience
+
+     def tag(self, dataloader, is_train=True):
+         """
+         Given a dataloader containing segments, predict the tags
+         :param dataloader: torch.utils.data.DataLoader
+         :param is_train: boolean - True for training mode, False for evaluation
+         :return: Iterator
+             subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
+             gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
+             tokens - List[arabiner.data.dataset.Token] - list of tokens
+             valid_len (B x 1) - int - valid length of each sequence
+             logits (B x T x NUM_LABELS) - logits for each token and each tag
+         """
+         for subwords, gold_tags, tokens, valid_len in dataloader:
+             self.model.train(is_train)
+
+             if torch.cuda.is_available():
+                 subwords = subwords.cuda()
+                 gold_tags = gold_tags.cuda()
+
+             if is_train:
+                 self.optimizer.zero_grad()
+                 logits = self.model(subwords)
+             else:
+                 with torch.no_grad():
+                     logits = self.model(subwords)
+
+             yield subwords, gold_tags, tokens, valid_len, logits
+
+     def segments_to_file(self, segments, filename):
+         """
+         Write segments to file
+         :param segments: [List[arabiner.data.dataset.Token]] - list of lists of tokens
+         :param filename: str - output filename
+         :return: None
+         """
+         with open(filename, "w") as fh:
+             results = "\n\n".join(["\n".join([t.__str__() for t in segment]) for segment in segments])
+             fh.write("Token\tGold Tag\tPredicted Tag\n")
+             fh.write(results)
+             logging.info("Predictions written to %s", filename)
+
+     def save(self):
+         """
+         Save model checkpoint
+         :return:
+         """
+         filename = os.path.join(
+             self.output_path,
+             "checkpoints",
+             "checkpoint_{}.pt".format(self.current_epoch),
+         )
+
+         checkpoint = {
+             "model": self.model.state_dict(),
+             "optimizer": self.optimizer.state_dict(),
+             "epoch": self.current_epoch
+         }
+
+         logger.info("Saving checkpoint to %s", filename)
+         torch.save(checkpoint, filename)
+
+     def load(self, checkpoint_path):
+         """
+         Load model checkpoint
+         :param checkpoint_path: str - path/to/checkpoints
+         :return: None
+         """
+         checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
+         checkpoint_path = checkpoint_path[-1]
+
+         logger.info("Loading checkpoint %s", checkpoint_path)
+
+         device = None if torch.cuda.is_available() else torch.device('cpu')
+         checkpoint = torch.load(checkpoint_path, map_location=device)
+         self.model.load_state_dict(checkpoint["model"])
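BaseTrainer.load picks the most recent checkpoint by natural-sorting the checkpoint filenames. A short illustration of why natsort is used rather than plain sorted(), with hypothetical paths:

import natsort

paths = [
    "output/checkpoints/checkpoint_2.pt",
    "output/checkpoints/checkpoint_10.pt",
    "output/checkpoints/checkpoint_9.pt",
]

print(sorted(paths)[-1])              # output/checkpoints/checkpoint_9.pt (lexicographic order)
print(natsort.natsorted(paths)[-1])   # output/checkpoints/checkpoint_10.pt (numeric epoch order)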
nlptools/arabiner/trainers/BertNestedTrainer.py
@@ -0,0 +1,203 @@
+ import os
+ import logging
+ import torch
+ import numpy as np
+ from nlptools.arabiner.trainers import BaseTrainer
+ from nlptools.arabiner.utils.metrics import compute_nested_metrics
+
+ logger = logging.getLogger(__name__)
+
+
+ class BertNestedTrainer(BaseTrainer):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def train(self):
+         best_val_loss, test_loss = np.inf, np.inf
+         num_train_batch = len(self.train_dataloader)
+         num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
+         patience = self.patience
+
+         for epoch_index in range(self.max_epochs):
+             self.current_epoch = epoch_index
+             train_loss = 0
+
+             for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
+                 self.train_dataloader, is_train=True
+             ), 1):
+                 self.current_timestep += 1
+
+                 # Compute losses for each output
+                 # logits = B x T x L x C
+                 losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                     torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                           for i, l in enumerate(num_labels)]
+
+                 torch.autograd.backward(losses)
+
+                 # Avoid exploding gradients by applying gradient clipping
+                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+                 self.optimizer.step()
+                 self.scheduler.step()
+                 batch_loss = sum(l.item() for l in losses)
+                 train_loss += batch_loss
+
+                 if self.current_timestep % self.log_interval == 0:
+                     logger.info(
+                         "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
+                         epoch_index,
+                         batch_index,
+                         num_train_batch,
+                         self.current_timestep,
+                         self.optimizer.param_groups[0]['lr'],
+                         batch_loss
+                     )
+
+             train_loss /= num_train_batch
+
+             logger.info("** Evaluating on validation dataset **")
+             val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
+             val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
+
+             epoch_summary_loss = {
+                 "train_loss": train_loss,
+                 "val_loss": val_loss
+             }
+             epoch_summary_metrics = {
+                 "val_micro_f1": val_metrics.micro_f1,
+                 "val_precision": val_metrics.precision,
+                 "val_recall": val_metrics.recall
+             }
+
+             logger.info(
+                 "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
+                 epoch_index,
+                 self.current_timestep,
+                 train_loss,
+                 val_loss,
+                 val_metrics.micro_f1
+             )
+
+             if val_loss < best_val_loss:
+                 patience = self.patience
+                 best_val_loss = val_loss
+                 logger.info("** Validation improved, evaluating test data **")
+                 test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
+                 self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
+                 test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
+
+                 epoch_summary_loss["test_loss"] = test_loss
+                 epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
+                 epoch_summary_metrics["test_precision"] = test_metrics.precision
+                 epoch_summary_metrics["test_recall"] = test_metrics.recall
+
+                 logger.info(
+                     "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
+                     epoch_index,
+                     self.current_timestep,
+                     test_loss,
+                     test_metrics.micro_f1
+                 )
+
+                 self.save()
+             else:
+                 patience -= 1
+
+                 # No improvements, terminating early
+                 if patience == 0:
+                     logger.info("Early termination triggered")
+                     break
+
+             self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
+             self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
+
+     def tag(self, dataloader, is_train=True):
+         """
+         Given a dataloader containing segments, predict the tags
+         :param dataloader: torch.utils.data.DataLoader
+         :param is_train: boolean - True for training mode, False for evaluation
+         :return: Iterator
+             subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
+             gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
+             tokens - List[arabiner.data.dataset.Token] - list of tokens
+             valid_len (B x 1) - int - valid length of each sequence
+             logits (B x T x NUM_LABELS) - logits for each token and each tag
+         """
+         for subwords, gold_tags, tokens, mask, valid_len in dataloader:
+             self.model.train(is_train)
+
+             if torch.cuda.is_available():
+                 subwords = subwords.cuda()
+                 gold_tags = gold_tags.cuda()
+
+             if is_train:
+                 self.optimizer.zero_grad()
+                 logits = self.model(subwords)
+             else:
+                 with torch.no_grad():
+                     logits = self.model(subwords)
+
+             yield subwords, gold_tags, tokens, valid_len, logits
+
+     def eval(self, dataloader):
+         golds, preds, segments, valid_lens = list(), list(), list(), list()
+         num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
+         loss = 0
+
+         for _, gold_tags, tokens, valid_len, logits in self.tag(
+             dataloader, is_train=False
+         ):
+             losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                 torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                       for i, l in enumerate(num_labels)]
+             loss += sum(losses)
+             preds += torch.argmax(logits, dim=3)
+             segments += tokens
+             valid_lens += list(valid_len)
+
+         loss /= len(dataloader)
+
+         # Update segments, attach predicted tags to each token
+         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+
+         return preds, segments, valid_lens, loss
+
+     def infer(self, dataloader):
+         golds, preds, segments, valid_lens = list(), list(), list(), list()
+
+         for _, gold_tags, tokens, valid_len, logits in self.tag(
+             dataloader, is_train=False
+         ):
+             preds += torch.argmax(logits, dim=3)
+             segments += tokens
+             valid_lens += list(valid_len)
+
+         segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+         return segments
+
+     def to_segments(self, segments, preds, valid_lens, vocab):
+         if vocab is None:
+             vocab = self.vocab
+
+         tagged_segments = list()
+         tokens_stoi = vocab.tokens.get_stoi()
+         unk_id = tokens_stoi["UNK"]
+
+         for segment, pred, valid_len in zip(segments, preds, valid_lens):
+             # First, skip the token at the 0th index [CLS] and the token at the nth index [SEP]
+             # Combine the tokens with their corresponding predictions
+             segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
+
+             # Ignore the sub-tokens/subwords, which are identified by their text being UNK
+             segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
+
+             # Attach the predicted tags to each token
+             list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
+                                                           for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
+
+             # We are only interested in the tagged tokens; we no longer need the raw model predictions
+             tagged_segment = [t for t, _ in segment_pred]
+             tagged_segments.append(tagged_segment)
+
+         return tagged_segments
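For reference, the loss computation in train and eval slices the padded B x T x L x C logits back to each tag type's true class count before applying cross-entropy. A self-contained sketch with random tensors and hypothetical shapes (batch 2, sequence length 5, three tag types with 7, 5 and 3 classes); reshape is used here instead of view so the sketch also works on non-contiguous slices:

import torch
import torch.nn as nn

num_labels = [7, 5, 3]                                          # classes per tag type (hypothetical)
logits = torch.randn(2, 5, len(num_labels), max(num_labels))    # B x T x L x C_max
gold_tags = torch.stack(
    [torch.randint(0, l, (2, 5)) for l in num_labels], dim=1
)                                                               # B x L x T

loss_fn = nn.CrossEntropyLoss()
losses = [
    loss_fn(logits[:, :, i, 0:l].reshape(-1, l),                # drop the padded classes for type i
            gold_tags[:, i, :].reshape(-1).long())
    for i, l in enumerate(num_labels)
]
print(sum(l.item() for l in losses))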