SinaTools 0.1.35__py2.py3-none-any.whl → 0.1.37__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,203 +1,203 @@
1
- import os
2
- import logging
3
- import torch
4
- import numpy as np
5
- from sinatools.ner.trainers import BaseTrainer
6
- from sinatools.ner.metrics import compute_nested_metrics
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
-
11
class BertNestedTrainer(BaseTrainer):
    """Trainer for a nested tagging model with one set of logits per tag type.

    The model output is indexed as ``logits[:, :, i, 0:l]`` where ``i`` is the
    tag-type index and ``l`` the number of labels of that tag type, and a
    separate loss is computed for every tag type.

    NOTE(review): attributes such as ``model``, ``optimizer``, ``scheduler``,
    ``loss``, ``clip``, ``patience``, ``summary_writer`` and the dataloaders
    are not defined here; presumably ``BaseTrainer`` provides them - confirm.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseTrainer``."""
        super().__init__(**kwargs)

    def _nested_losses(self, logits, gold_tags, num_labels):
        """Compute one loss per tag type.

        :param logits: torch.Tensor indexed as B x T x L x C (batch, timestep,
            tag type, class)
        :param gold_tags: torch.Tensor indexed as ``gold_tags[:, i, :]`` for
            tag type ``i``
        :param num_labels: List[int] - number of labels of each tag type
        :return: List of losses, one entry per tag type
        """
        losses = []
        for i, l in enumerate(num_labels):
            # Only the first ``l`` classes are valid for tag type ``i``
            type_logits = logits[:, :, i, 0:l]
            losses.append(self.loss(
                type_logits.view(-1, type_logits.shape[-1]),
                torch.reshape(gold_tags[:, i, :], (-1,)).long()
            ))
        return losses

    def train(self):
        """Run the training loop with validation-based early stopping.

        After every epoch the model is evaluated on the validation data. When
        the validation loss improves, the test data is evaluated, predictions
        are written to ``predictions.txt`` under ``self.output_path`` and the
        model is saved. Otherwise the patience counter is decremented and
        training stops once it reaches zero.
        """
        best_val_loss = np.inf
        num_train_batch = len(self.train_dataloader)
        # The first entry of vocab.tags is skipped, one entry per tag type remains
        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            batches = self.tag(self.train_dataloader, is_train=True)
            for batch_index, (_, gold_tags, _, _, logits) in enumerate(batches, 1):
                self.current_timestep += 1

                # Compute losses for each output
                # logits = B x T x L x C
                losses = self._nested_losses(logits, gold_tags, num_labels)

                torch.autograd.backward(losses)

                # Avoid exploding gradient by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                batch_loss = sum(l.item() for l in losses)
                train_loss += batch_loss

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            _, segments, _, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                _, segments, _, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                # Plain format string: the logger fills in the %-arguments lazily
                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvements, terminating early
            # NOTE: the break happens before the summary writer calls below, so
            # the scalars of the final epoch are not written on early termination
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader yielding
                           (subwords, gold_tags, tokens, mask, valid_len)
        :param is_train: boolean - True for training model, False for evaluation
        :return: Iterator
                    subwords - torch.Tensor - BERT subword ID
                               (presumably B x T - TODO confirm)
                    gold_tags - torch.Tensor - ground truth tags IDs, indexed
                                by callers as gold_tags[:, i, :] for tag type i
                    tokens - list of tokens of each segment in the batch
                    valid_len - valid length of each sequence
                    logits (B x T x L x C) - logits for each token, each tag
                                             type and each class
        """
        # The mask is part of every batch but is not used by this trainer
        for subwords, gold_tags, tokens, _mask, valid_len in dataloader:
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                # Gradients are reset here; the caller runs backward and step
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def eval(self, dataloader):
        """Evaluate the model on a dataloader.

        :param dataloader: torch.utils.data.DataLoader
        :return: Tuple of
                    preds - list of per-sequence predicted tag ID tensors
                    segments - list of segments with predicted tags attached
                    valid_lens - list of valid lengths, one per sequence
                    loss - summed per-tag-type loss averaged over the batches
        """
        preds, segments, valid_lens = list(), list(), list()
        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            loss += sum(self._nested_losses(logits, gold_tags, num_labels))
            # argmax over the class dimension; extending adds one entry per sequence
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss

    def infer(self, dataloader):
        """Predict tags for every segment in a dataloader, without computing a loss.

        :param dataloader: torch.utils.data.DataLoader
        :return: list of segments with predicted tags attached to each token
        """
        preds, segments, valid_lens = list(), list(), list()

        for _, _, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        """Attach the predicted tags to the tokens of each segment.

        :param segments: list of segments, each a list of tokens
        :param preds: list of predicted tag ID tensors, one per segment
        :param valid_lens: list of valid lengths, one per segment
        :param vocab: vocabulary with ``tokens`` and ``tags``; falls back to
                      ``self.vocab`` when None
        :return: list of segments holding only the tagged tokens; every token
                 gets a ``pred_tag`` list with one ``{"tag": ...}`` entry per
                 tag type
        """
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        unk_id = tokens_stoi["UNK"]
        # One ID -> tag lookup table per tag type, built once for all tokens
        tag_itos = [tag_vocab.get_itos() for tag_vocab in vocab.tags[1:]]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # First, the token at 0th index [CLS] and token at nth index [SEP]
            # Combine the tokens with their corresponding predictions
            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])

            # Ignore the sub-tokens/subwords, which are identified with text being UNK
            segment_pred = [
                (token, token_pred)
                for token, token_pred in segment_pred
                if tokens_stoi[token.text] != unk_id
            ]

            # Attach the predicted tags to each token
            for token, token_pred in segment_pred:
                token.pred_tag = [
                    {"tag": itos[tag_id]}
                    for tag_id, itos in zip(token_pred.int().tolist(), tag_itos)
                ]

            # We are only interested in the tagged tokens, we do no longer need raw model predictions
            tagged_segments.append([token for token, _ in segment_pred])

        return tagged_segments
1
+ import os
2
+ import logging
3
+ import torch
4
+ import numpy as np
5
+ from sinatools.ner.trainers import BaseTrainer
6
+ from sinatools.ner.metrics import compute_nested_metrics
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class BertNestedTrainer(BaseTrainer):
    """Trainer for a nested tagging model with one set of logits per tag type.

    The model output is indexed as ``logits[:, :, i, 0:l]`` where ``i`` is the
    tag-type index and ``l`` the number of labels of that tag type, and a
    separate loss is computed for every tag type.

    NOTE(review): attributes such as ``model``, ``optimizer``, ``scheduler``,
    ``loss``, ``clip``, ``patience``, ``summary_writer`` and the dataloaders
    are not defined here; presumably ``BaseTrainer`` provides them - confirm.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseTrainer``."""
        super().__init__(**kwargs)

    def _nested_losses(self, logits, gold_tags, num_labels):
        """Compute one loss per tag type.

        :param logits: torch.Tensor indexed as B x T x L x C (batch, timestep,
            tag type, class)
        :param gold_tags: torch.Tensor indexed as ``gold_tags[:, i, :]`` for
            tag type ``i``
        :param num_labels: List[int] - number of labels of each tag type
        :return: List of losses, one entry per tag type
        """
        losses = []
        for i, l in enumerate(num_labels):
            # Only the first ``l`` classes are valid for tag type ``i``
            type_logits = logits[:, :, i, 0:l]
            losses.append(self.loss(
                type_logits.view(-1, type_logits.shape[-1]),
                torch.reshape(gold_tags[:, i, :], (-1,)).long()
            ))
        return losses

    def train(self):
        """Run the training loop with validation-based early stopping.

        After every epoch the model is evaluated on the validation data. When
        the validation loss improves, the test data is evaluated, predictions
        are written to ``predictions.txt`` under ``self.output_path`` and the
        model is saved. Otherwise the patience counter is decremented and
        training stops once it reaches zero.
        """
        best_val_loss = np.inf
        num_train_batch = len(self.train_dataloader)
        # The first entry of vocab.tags is skipped, one entry per tag type remains
        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            batches = self.tag(self.train_dataloader, is_train=True)
            for batch_index, (_, gold_tags, _, _, logits) in enumerate(batches, 1):
                self.current_timestep += 1

                # Compute losses for each output
                # logits = B x T x L x C
                losses = self._nested_losses(logits, gold_tags, num_labels)

                torch.autograd.backward(losses)

                # Avoid exploding gradient by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                batch_loss = sum(l.item() for l in losses)
                train_loss += batch_loss

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            _, segments, _, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                _, segments, _, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                # Plain format string: the logger fills in the %-arguments lazily
                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvements, terminating early
            # NOTE: the break happens before the summary writer calls below, so
            # the scalars of the final epoch are not written on early termination
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader yielding
                           (subwords, gold_tags, tokens, mask, valid_len)
        :param is_train: boolean - True for training model, False for evaluation
        :return: Iterator
                    subwords - torch.Tensor - BERT subword ID
                               (presumably B x T - TODO confirm)
                    gold_tags - torch.Tensor - ground truth tags IDs, indexed
                                by callers as gold_tags[:, i, :] for tag type i
                    tokens - list of tokens of each segment in the batch
                    valid_len - valid length of each sequence
                    logits (B x T x L x C) - logits for each token, each tag
                                             type and each class
        """
        # The mask is part of every batch but is not used by this trainer
        for subwords, gold_tags, tokens, _mask, valid_len in dataloader:
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                # Gradients are reset here; the caller runs backward and step
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def eval(self, dataloader):
        """Evaluate the model on a dataloader.

        :param dataloader: torch.utils.data.DataLoader
        :return: Tuple of
                    preds - list of per-sequence predicted tag ID tensors
                    segments - list of segments with predicted tags attached
                    valid_lens - list of valid lengths, one per sequence
                    loss - summed per-tag-type loss averaged over the batches
        """
        preds, segments, valid_lens = list(), list(), list()
        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            loss += sum(self._nested_losses(logits, gold_tags, num_labels))
            # argmax over the class dimension; extending adds one entry per sequence
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss

    def infer(self, dataloader):
        """Predict tags for every segment in a dataloader, without computing a loss.

        :param dataloader: torch.utils.data.DataLoader
        :return: list of segments with predicted tags attached to each token
        """
        preds, segments, valid_lens = list(), list(), list()

        for _, _, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        """Attach the predicted tags to the tokens of each segment.

        :param segments: list of segments, each a list of tokens
        :param preds: list of predicted tag ID tensors, one per segment
        :param valid_lens: list of valid lengths, one per segment
        :param vocab: vocabulary with ``tokens`` and ``tags``; falls back to
                      ``self.vocab`` when None
        :return: list of segments holding only the tagged tokens; every token
                 gets a ``pred_tag`` list with one ``{"tag": ...}`` entry per
                 tag type
        """
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        unk_id = tokens_stoi["UNK"]
        # One ID -> tag lookup table per tag type, built once for all tokens
        tag_itos = [tag_vocab.get_itos() for tag_vocab in vocab.tags[1:]]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # First, the token at 0th index [CLS] and token at nth index [SEP]
            # Combine the tokens with their corresponding predictions
            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])

            # Ignore the sub-tokens/subwords, which are identified with text being UNK
            segment_pred = [
                (token, token_pred)
                for token, token_pred in segment_pred
                if tokens_stoi[token.text] != unk_id
            ]

            # Attach the predicted tags to each token
            for token, token_pred in segment_pred:
                token.pred_tag = [
                    {"tag": itos[tag_id]}
                    for tag_id, itos in zip(token_pred.int().tolist(), tag_itos)
                ]

            # We are only interested in the tagged tokens, we do no longer need raw model predictions
            tagged_segments.append([token for token, _ in segment_pred])

        return tagged_segments