SinaTools 0.1.11__py2.py3-none-any.whl → 0.1.12__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {SinaTools-0.1.11.dist-info → SinaTools-0.1.12.dist-info}/METADATA +2 -3
  2. {SinaTools-0.1.11.dist-info → SinaTools-0.1.12.dist-info}/RECORD +47 -26
  3. {SinaTools-0.1.11.dist-info → SinaTools-0.1.12.dist-info}/entry_points.txt +7 -3
  4. sinatools/CLI/DataDownload/download_files.py +0 -10
  5. sinatools/CLI/ner/corpus_entity_extractor.py +6 -6
  6. sinatools/CLI/ner/entity_extractor.py +18 -42
  7. sinatools/CLI/utils/arStrip.py +8 -8
  8. sinatools/CLI/utils/implication.py +0 -8
  9. sinatools/CLI/utils/jaccard.py +5 -14
  10. sinatools/CLI/utils/remove_latin.py +2 -2
  11. sinatools/CLI/utils/text_dublication_detector.py +25 -0
  12. sinatools/VERSION +1 -1
  13. sinatools/morphology/ALMA_multi_word.py +14 -16
  14. sinatools/morphology/__init__.py +32 -31
  15. sinatools/ner/__init__.py +28 -2
  16. sinatools/ner/data/__init__.py +1 -0
  17. sinatools/ner/data/datasets.py +146 -0
  18. sinatools/ner/data/transforms.py +118 -0
  19. sinatools/ner/data.py +124 -0
  20. sinatools/ner/data_format.py +124 -0
  21. sinatools/ner/datasets.py +146 -0
  22. sinatools/ner/entity_extractor.py +34 -54
  23. sinatools/ner/helpers.py +86 -0
  24. sinatools/ner/metrics.py +69 -0
  25. sinatools/ner/nn/BaseModel.py +22 -0
  26. sinatools/ner/nn/BertNestedTagger.py +34 -0
  27. sinatools/ner/nn/BertSeqTagger.py +17 -0
  28. sinatools/ner/nn/__init__.py +3 -0
  29. sinatools/ner/trainers/BaseTrainer.py +117 -0
  30. sinatools/ner/trainers/BertNestedTrainer.py +203 -0
  31. sinatools/ner/trainers/BertTrainer.py +163 -0
  32. sinatools/ner/trainers/__init__.py +3 -0
  33. sinatools/ner/transforms.py +119 -0
  34. sinatools/semantic_relatedness/__init__.py +20 -0
  35. sinatools/semantic_relatedness/compute_relatedness.py +31 -0
  36. sinatools/synonyms/__init__.py +18 -0
  37. sinatools/synonyms/synonyms_generator.py +192 -0
  38. sinatools/utils/text_dublication_detector.py +110 -0
  39. sinatools/wsd/__init__.py +11 -0
  40. sinatools/{salma/views.py → wsd/disambiguator.py} +135 -94
  41. sinatools/{salma → wsd}/wsd.py +1 -1
  42. sinatools/CLI/salma/salma_tools.py +0 -68
  43. sinatools/salma/__init__.py +0 -12
  44. sinatools/utils/utils.py +0 -2
  45. {SinaTools-0.1.11.data → SinaTools-0.1.12.data}/data/sinatools/environment.yml +0 -0
  46. {SinaTools-0.1.11.dist-info → SinaTools-0.1.12.dist-info}/AUTHORS.rst +0 -0
  47. {SinaTools-0.1.11.dist-info → SinaTools-0.1.12.dist-info}/LICENSE +0 -0
  48. {SinaTools-0.1.11.dist-info → SinaTools-0.1.12.dist-info}/WHEEL +0 -0
  49. {SinaTools-0.1.11.dist-info → SinaTools-0.1.12.dist-info}/top_level.txt +0 -0
  50. sinatools/{salma → wsd}/settings.py +0 -0
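The renames above (items 40, 41 and 44) move the SALMA word sense disambiguation code under sinatools.wsd, and the diff below renames the NER entry point from ner() to extract(). The following is a minimal sketch of what the 0.1.12 import paths look like compared to 0.1.11, based only on the paths listed here; it is illustrative, not an official migration guide.

    # Sketch: import paths implied by the 0.1.11 -> 0.1.12 renames.

    # 0.1.11 (old layout)
    # from sinatools.ner.entity_extractor import ner   # old entry point
    # from sinatools.salma import views                 # WSD lived under sinatools.salma

    # 0.1.12 (new layout, per the file list and the diff below)
    from sinatools.ner.entity_extractor import extract  # renamed from ner()
    from sinatools.wsd import disambiguator             # salma/views.py -> wsd/disambiguator.py

    print(extract('ذهب محمد إلى جامعة بيرزيت'))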
sinatools/ner/entity_extractor.py
@@ -1,51 +1,51 @@
 import os
 from collections import namedtuple
-from sinatools.ner.utils.helpers import load_checkpoint
-from sinatools.ner.utils.data import get_dataloaders, text2segments
-from sinatools.DataDownload import downloader
-from . import tag_vocab, train_config, tagger
+from sinatools.ner.data_format import get_dataloaders, text2segments
+from . import tagger, tag_vocab, train_config

-def ner(text, batch_size=32):
+def extract(text, batch_size=32):
     """
-    This method takes a text as input, and a batch size, then performs named entity recognition (NER) on the input text and returns a list of tagged mentions.
+    This method processes an input text and returns the named entities for each token in the text, based on the specified batch size, as follows:

     Args:
-        text (str): The input text to perform NER on.
+        text (:obj:`str`): The Arabic text to be tagged.
         batch_size (int, optional): Batch size for inference. Default is 32.

     Returns:
-        list: A list of lists containing token and label pairs for each segment.
-              Each inner list has the format ['token', 'label1 label2 ...'].
+        list (:obj:`list`): A list of JSON objects, where each JSON object contains:
+            token: The token from the original text.
+            NER tag: The NER label(s) assigned to the token.
+
     **Example:**

     .. highlight:: python
     .. code-block:: python

-        from sinatools.arabiner.bin import infer
-        infer.ner('ذهب محمد الى جامعة بيرزيت')
-
-        #the output
-        [['ذهب', 'O'],
-        ['محمد', 'B-PERS'],
-        ['الى', 'O'],
-        ['جامعة', 'B-ORG'],
-        ['بيرزيت', 'B-GPE I-ORG']]
+        from sinatools.ner.entity_extractor import extract
+        extract('ذهب محمد إلى جامعة بيرزيت')
+        [{
+            "word":"ذهب",
+            "tags":"O"
+        },{
+            "word":"محمد",
+            "tags":"B-PERS"
+        },{
+            "word":"إلى",
+            "tags":"O"
+        },{
+            "word":"جامعة",
+            "tags":"B-ORG"
+        },{
+            "word":"بيرزيت",
+            "tags":"B-GPE I-ORG"
+        }]
     """
-    # Load tagger
-    # filename = 'Wj27012000.tar'
-    # path =downloader.get_appdatadir()
-    # model_path = os.path.join(path, filename)
-    # print('1',model_path)
-    # tagger, tag_vocab, train_config = load_checkpoint(model_path)

-
-    # Convert text to a tagger dataset and index the tokens in args.text
     dataset, token_vocab = text2segments(text)

     vocabs = namedtuple("Vocab", ["tags", "tokens"])
     vocab = vocabs(tokens=token_vocab, tags=tag_vocab)

-    # From the datasets generate the dataloaders
     dataloader = get_dataloaders(
         (dataset,),
         vocab,
@@ -54,39 +54,19 @@ def ner(text, batch_size=32):
         shuffle=(False,),
     )[0]

-    # Perform inference on the text and get back the tagged segments
+
     segments = tagger.infer(dataloader)
     segments_lists = []
-    # Print results
+
     for segment in segments:
         for token in segment:
-            segments_list = []
-            segments_list.append(token.text)
-            #print([t['tag'] for t in token.pred_tag])
+            segments_list = {}
+            segments_list["word"] = token.text
             list_of_tags = [t['tag'] for t in token.pred_tag]
             list_of_tags = [i for i in list_of_tags if i not in('O',' ','')]
-            #print(list_of_tags)
             if list_of_tags == []:
-                segments_list.append(' '.join(['O']))
+                segments_list["tag"] = ' '.join(['O'])
             else:
-                segments_list.append(' '.join(list_of_tags))
-            segments_lists.append(segments_list)
+                segments_list["tag"] = ' '.join(list_of_tags)
+            segments_lists.append(segments_list)
     return segments_lists
-
-    #Print results
-    # for segment in segments:
-    #     s = [
-    #         (token.text, token.pred_tag[0]['tag'])
-    #         for token in segment
-    #         if token.pred_tag[0]['tag'] != 'O'
-    #     ]
-    #     print(", ".join([f"({token}, {tag})" for token, tag in s]))
-
-def extract_tags(text):
-    tags = []
-    tokens = text.split()
-    for token in tokens:
-        tag = token.split("(")[-1].split(")")[0]
-        if tag != "O":
-            tags.append(tag)
-    return " ".join(tags)
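The removed extract_tags() helper has no direct replacement in this file. Note also that the implementation above stores the label under the "tag" key, while the new docstring example shows "tags". The following is a minimal sketch (not part of the package) of how the old behavior could be rebuilt on top of the new extract() output, following the key used by the implementation:

    # Sketch only: collect the non-O labels from the new extract() output.
    from sinatools.ner.entity_extractor import extract

    def non_o_tags(text):
        tokens = extract(text)                      # list of {"word": ..., "tag": ...}
        return " ".join(t["tag"] for t in tokens if t["tag"] != "O")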
sinatools/ner/helpers.py (new file)
@@ -0,0 +1,86 @@
+import os
+import sys
+import logging
+import importlib
+import shutil
+import torch
+import random
+import numpy as np
+
+
+def logging_config(log_file=None):
+    """
+    Initialize custom logger
+    :param log_file: str - path to log file, full path
+    :return: None
+    """
+    handlers = [logging.StreamHandler(sys.stdout)]
+
+    if log_file:
+        handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
+        print("Logging to {}".format(log_file))
+
+    logging.basicConfig(
+        level=logging.INFO,
+        handlers=handlers,
+        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
+        datefmt="%a, %d %b %Y %H:%M:%S",
+        force=True
+    )
+
+
+def load_object(name, kwargs):
+
+    try:
+        object_module, object_name = name.rsplit(".", 1)
+        object_module = importlib.import_module(object_module)
+        obj = getattr(object_module, object_name)
+        if callable(obj):
+            fn = obj(**kwargs)
+            return fn
+        else:
+            raise TypeError(f"{name} is not a callable object.")
+    except (ImportError, ModuleNotFoundError) as e:
+        print(f"Error importing module: {e}")
+    except AttributeError as e:
+        print(f"Attribute error: {e}")
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+    return None
+
+def make_output_dirs(path, subdirs=[], overwrite=True):
+    """
+    Create root directory and any other sub-directories
+    :param path: str - root directory
+    :param subdirs: List[str] - list of sub-directories
+    :param overwrite: boolean - to overwrite the directory or not
+    :return: None
+    """
+    if overwrite:
+        shutil.rmtree(path, ignore_errors=True)
+
+    os.makedirs(path)
+
+    for subdir in subdirs:
+        os.makedirs(os.path.join(path, subdir))
+
+
+
+def set_seed(seed):
+    """
+    Set the seed for random initialization and set
+    cuDNN parameters to ensure deterministic results across
+    multiple runs with the same seed
+
+    :param seed: int
+    """
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.enabled = False
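A short usage sketch for the helpers above; the log file path, seed value, and output directory are illustrative only. The checkpoints/ sub-directory matches what BaseTrainer.save() (added further below) writes into.

    # Sketch: typical use of the new helpers (illustrative values, not from the package).
    from sinatools.ner.helpers import logging_config, set_seed, make_output_dirs

    logging_config("train.log")            # log to stdout and to train.log
    set_seed(42)                           # deterministic numpy/random/torch state
    make_output_dirs("./output", subdirs=["checkpoints"], overwrite=True)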
sinatools/ner/metrics.py (new file)
@@ -0,0 +1,69 @@
+from seqeval.metrics import (
+    classification_report,
+    precision_score,
+    recall_score,
+    f1_score,
+    accuracy_score,
+)
+from seqeval.scheme import IOB2
+from types import SimpleNamespace
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def compute_nested_metrics(segments, vocabs):
+    """
+    Compute metrics for nested NER
+    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments
+    :return: metrics - SimpleNamespace - F1/micro/macro/weights, recall, precision, accuracy
+    """
+    y, y_hat = list(), list()
+
+    # We duplicate the dataset N times, where N is the number of entity types
+    # For each copy, we create y and y_hat
+    # Example: first copy, will create pairs of ground truth and predicted labels for entity type GPE
+    # another copy will create pairs for LOC, etc.
+    for i, vocab in enumerate(vocabs):
+        vocab_tags = [tag for tag in vocab.get_itos() if "-" in tag]
+        r = re.compile("|".join(vocab_tags))
+
+        y += [[(list(filter(r.match, token.gold_tag)) or ["O"])[0] for token in segment] for segment in segments]
+        y_hat += [[token.pred_tag[i]["tag"] for token in segment] for segment in segments]
+
+    logging.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))
+
+    metrics = {
+        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
+        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
+        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
+        "precision": precision_score(y, y_hat, scheme=IOB2),
+        "recall": recall_score(y, y_hat, scheme=IOB2),
+        "accuracy": accuracy_score(y, y_hat),
+    }
+
+    return SimpleNamespace(**metrics)
+
+
+def compute_single_label_metrics(segments):
+    """
+    Compute metrics for flat NER
+    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments
+    :return: metrics - SimpleNamespace - F1/micro/macro/weights, recall, precision, accuracy
+    """
+    y = [[token.gold_tag[0] for token in segment] for segment in segments]
+    y_hat = [[token.pred_tag[0]["tag"] for token in segment] for segment in segments]
+
+    logging.info("\n" + classification_report(y, y_hat, scheme=IOB2))
+
+    metrics = {
+        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
+        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
+        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
+        "precision": precision_score(y, y_hat, scheme=IOB2),
+        "recall": recall_score(y, y_hat, scheme=IOB2),
+        "accuracy": accuracy_score(y, y_hat),
+    }
+
+    return SimpleNamespace(**metrics)
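Both functions wrap seqeval. The toy snippet below shows the underlying calls on made-up IOB2 label sequences, so the returned keys (micro_f1, precision, recall, ...) are easy to interpret; the labels are invented for illustration.

    # Sketch: what the seqeval calls above compute, on toy IOB2 sequences.
    from seqeval.metrics import classification_report, f1_score
    from seqeval.scheme import IOB2

    y     = [["O", "B-PERS", "O", "B-ORG", "I-ORG"]]   # gold labels (one segment)
    y_hat = [["O", "B-PERS", "O", "B-ORG", "O"]]       # predicted labels

    print(f1_score(y, y_hat, average="micro", scheme=IOB2))
    print(classification_report(y, y_hat, scheme=IOB2, digits=4))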
sinatools/ner/nn/BaseModel.py (new file)
@@ -0,0 +1,22 @@
+from torch import nn
+from transformers import BertModel
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BaseModel(nn.Module):
+    def __init__(self,
+                 bert_model="aubmindlab/bert-base-arabertv2",
+                 num_labels=2,
+                 dropout=0.1,
+                 num_types=0):
+        super().__init__()
+
+        self.bert_model = bert_model
+        self.num_labels = num_labels
+        self.num_types = num_types
+        self.dropout = dropout
+
+        self.bert = BertModel.from_pretrained(bert_model)
+        self.dropout = nn.Dropout(dropout)
sinatools/ner/nn/BertNestedTagger.py (new file)
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from sinatools.ner.nn import BaseModel
+
+
+class BertNestedTagger(BaseModel):
+    def __init__(self, **kwargs):
+        super(BertNestedTagger, self).__init__(**kwargs)
+
+        self.max_num_labels = max(self.num_labels)
+        classifiers = [nn.Linear(768, num_labels) for num_labels in self.num_labels]
+        self.classifiers = torch.nn.Sequential(*classifiers)
+
+    def forward(self, x):
+        y = self.bert(x)
+        y = self.dropout(y["last_hidden_state"])
+        output = list()
+
+        for i, classifier in enumerate(self.classifiers):
+            logits = classifier(y)
+
+            # Pad logits to allow Multi-GPU/DataParallel training to work
+            # We will truncate the padded dimensions when we compute the loss in the trainer
+            logits = torch.nn.ConstantPad1d((0, self.max_num_labels - logits.shape[-1]), 0)(logits)
+            output.append(logits)
+
+        # Return tensor of the shape B x T x L x C
+        # B: batch size
+        # T: sequence length
+        # L: number of tag types
+        # C: number of classes per tag type
+        output = torch.stack(output).permute((1, 2, 0, 3))
+        return output
+
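BertNestedTagger inherits the BERT encoder and dropout from BaseModel and expects num_labels to be a sequence with one entry per tag layer. A minimal shape-check sketch, with invented label counts (loading the pretrained encoder requires a network connection):

    # Sketch: instantiating the nested tagger (label counts are invented for illustration).
    import torch
    from sinatools.ner.nn import BertNestedTagger

    model = BertNestedTagger(bert_model="aubmindlab/bert-base-arabertv2",
                             num_labels=[9, 5, 5, 3])   # one classifier per tag layer
    input_ids = torch.randint(0, 30000, (2, 16))        # B=2 segments, T=16 subwords
    logits = model(input_ids)
    print(logits.shape)                                 # torch.Size([2, 16, 4, 9]) -> B x T x L x C(max)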
sinatools/ner/nn/BertSeqTagger.py (new file)
@@ -0,0 +1,17 @@
+import torch.nn as nn
+from transformers import BertModel
+
+
+class BertSeqTagger(nn.Module):
+    def __init__(self, bert_model, num_labels=2, dropout=0.1):
+        super().__init__()
+
+        self.bert = BertModel.from_pretrained(bert_model)
+        self.dropout = nn.Dropout(dropout)
+        self.linear = nn.Linear(768, num_labels)
+
+    def forward(self, x):
+        y = self.bert(x)
+        y = self.dropout(y["last_hidden_state"])
+        logits = self.linear(y)
+        return logits
sinatools/ner/nn/__init__.py (new file)
@@ -0,0 +1,3 @@
+from sinatools.ner.nn.BaseModel import BaseModel
+from sinatools.ner.nn.BertSeqTagger import BertSeqTagger
+from sinatools.ner.nn.BertNestedTagger import BertNestedTagger
sinatools/ner/trainers/BaseTrainer.py (new file)
@@ -0,0 +1,117 @@
+import os
+import torch
+import logging
+import natsort
+import glob
+
+logger = logging.getLogger(__name__)
+
+
+class BaseTrainer:
+    def __init__(
+        self,
+        model=None,
+        max_epochs=50,
+        optimizer=None,
+        scheduler=None,
+        loss=None,
+        train_dataloader=None,
+        val_dataloader=None,
+        test_dataloader=None,
+        log_interval=10,
+        summary_writer=None,
+        output_path=None,
+        clip=5,
+        patience=5
+    ):
+        self.model = model
+        self.max_epochs = max_epochs
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
+        self.test_dataloader = test_dataloader
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.loss = loss
+        self.log_interval = log_interval
+        self.summary_writer = summary_writer
+        self.output_path = output_path
+        self.current_timestep = 0
+        self.current_epoch = 0
+        self.clip = clip
+        self.patience = patience
+
+    def tag(self, dataloader, is_train=True):
+        """
+        Given a dataloader containing segments, predict the tags
+        :param dataloader: torch.utils.data.DataLoader
+        :param is_train: boolean - True for training mode, False for evaluation
+        :return: Iterator
+                 subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
+                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
+                 tokens - List[arabiner.data.dataset.Token] - list of tokens
+                 valid_len (B x 1) - int - valid length of each sequence
+                 logits (B x T x NUM_LABELS) - logits for each token and each tag
+        """
+        for subwords, gold_tags, tokens, valid_len in dataloader:
+            self.model.train(is_train)
+
+            if torch.cuda.is_available():
+                subwords = subwords.cuda()
+                gold_tags = gold_tags.cuda()
+
+            if is_train:
+                self.optimizer.zero_grad()
+                logits = self.model(subwords)
+            else:
+                with torch.no_grad():
+                    logits = self.model(subwords)
+
+            yield subwords, gold_tags, tokens, valid_len, logits
+
+    def segments_to_file(self, segments, filename):
+        """
+        Write segments to file
+        :param segments: [List[arabiner.data.dataset.Token]] - list of list of tokens
+        :param filename: str - output filename
+        :return: None
+        """
+        with open(filename, "w") as fh:
+            results = "\n\n".join(["\n".join([t.__str__() for t in segment]) for segment in segments])
+            fh.write("Token\tGold Tag\tPredicted Tag\n")
+            fh.write(results)
+        logging.info("Predictions written to %s", filename)
+
+    def save(self):
+        """
+        Save model checkpoint
+        :return:
+        """
+        filename = os.path.join(
+            self.output_path,
+            "checkpoints",
+            "checkpoint_{}.pt".format(self.current_epoch),
+        )
+
+        checkpoint = {
+            "model": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "epoch": self.current_epoch
+        }
+
+        logger.info("Saving checkpoint to %s", filename)
+        torch.save(checkpoint, filename)
+
+    def load(self, checkpoint_path):
+        """
+        Load model checkpoint
+        :param checkpoint_path: str - path/to/checkpoints
+        :return: None
+        """
+        checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
+        checkpoint_path = checkpoint_path[-1]
+
+        logger.info("Loading checkpoint %s", checkpoint_path)
+
+        device = None if torch.cuda.is_available() else torch.device('cpu')
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        self.model.load_state_dict(checkpoint["model"])
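BaseTrainer.load() picks the latest checkpoint by natural-sorting the checkpoint_*.pt names. The sketch below (file names invented) shows why plain lexicographic sorting would pick the wrong "latest" file once epoch numbers reach two digits:

    # Sketch: why load() uses natsort rather than sorted() on checkpoint names.
    import natsort

    names = ["checkpoint_2.pt", "checkpoint_10.pt", "checkpoint_9.pt"]
    print(sorted(names)[-1])             # 'checkpoint_9.pt'  (lexicographic, wrong "latest")
    print(natsort.natsorted(names)[-1])  # 'checkpoint_10.pt' (natural order, intended latest)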
sinatools/ner/trainers/BertNestedTrainer.py (new file)
@@ -0,0 +1,203 @@
+import os
+import logging
+import torch
+import numpy as np
+from sinatools.ner.trainers import BaseTrainer
+from sinatools.ner.metrics import compute_nested_metrics
+
+logger = logging.getLogger(__name__)
+
+
+class BertNestedTrainer(BaseTrainer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def train(self):
+        best_val_loss, test_loss = np.inf, np.inf
+        num_train_batch = len(self.train_dataloader)
+        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
+        patience = self.patience
+
+        for epoch_index in range(self.max_epochs):
+            self.current_epoch = epoch_index
+            train_loss = 0
+
+            for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
+                self.train_dataloader, is_train=True
+            ), 1):
+                self.current_timestep += 1
+
+                # Compute losses for each output
+                # logits = B x T x L x C
+                losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                    torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                          for i, l in enumerate(num_labels)]
+
+                torch.autograd.backward(losses)
+
+                # Avoid exploding gradient by doing gradient clipping
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+                self.optimizer.step()
+                self.scheduler.step()
+                batch_loss = sum(l.item() for l in losses)
+                train_loss += batch_loss
+
+                if self.current_timestep % self.log_interval == 0:
+                    logger.info(
+                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
+                        epoch_index,
+                        batch_index,
+                        num_train_batch,
+                        self.current_timestep,
+                        self.optimizer.param_groups[0]['lr'],
+                        batch_loss
+                    )
+
+            train_loss /= num_train_batch
+
+            logger.info("** Evaluating on validation dataset **")
+            val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
+            val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])
+
+            epoch_summary_loss = {
+                "train_loss": train_loss,
+                "val_loss": val_loss
+            }
+            epoch_summary_metrics = {
+                "val_micro_f1": val_metrics.micro_f1,
+                "val_precision": val_metrics.precision,
+                "val_recall": val_metrics.recall
+            }
+
+            logger.info(
+                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
+                epoch_index,
+                self.current_timestep,
+                train_loss,
+                val_loss,
+                val_metrics.micro_f1
+            )
+
+            if val_loss < best_val_loss:
+                patience = self.patience
+                best_val_loss = val_loss
+                logger.info("** Validation improved, evaluating test data **")
+                test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
+                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
+                test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])
+
+                epoch_summary_loss["test_loss"] = test_loss
+                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
+                epoch_summary_metrics["test_precision"] = test_metrics.precision
+                epoch_summary_metrics["test_recall"] = test_metrics.recall
+
+                logger.info(
+                    f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
+                    epoch_index,
+                    self.current_timestep,
+                    test_loss,
+                    test_metrics.micro_f1
+                )
+
+                self.save()
+            else:
+                patience -= 1
+
+                # No improvements, terminating early
+                if patience == 0:
+                    logger.info("Early termination triggered")
+                    break
+
+            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
+            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
+
+    def tag(self, dataloader, is_train=True):
+        """
+        Given a dataloader containing segments, predict the tags
+        :param dataloader: torch.utils.data.DataLoader
+        :param is_train: boolean - True for training mode, False for evaluation
+        :return: Iterator
+                 subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
+                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
+                 tokens - List[arabiner.data.dataset.Token] - list of tokens
+                 valid_len (B x 1) - int - valid length of each sequence
+                 logits (B x T x NUM_LABELS) - logits for each token and each tag
+        """
+        for subwords, gold_tags, tokens, mask, valid_len in dataloader:
+            self.model.train(is_train)
+
+            if torch.cuda.is_available():
+                subwords = subwords.cuda()
+                gold_tags = gold_tags.cuda()
+
+            if is_train:
+                self.optimizer.zero_grad()
+                logits = self.model(subwords)
+            else:
+                with torch.no_grad():
+                    logits = self.model(subwords)
+
+            yield subwords, gold_tags, tokens, valid_len, logits
+
+    def eval(self, dataloader):
+        golds, preds, segments, valid_lens = list(), list(), list(), list()
+        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
+        loss = 0
+
+        for _, gold_tags, tokens, valid_len, logits in self.tag(
+            dataloader, is_train=False
+        ):
+            losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
+                                torch.reshape(gold_tags[:, i, :], (-1,)).long())
+                      for i, l in enumerate(num_labels)]
+            loss += sum(losses)
+            preds += torch.argmax(logits, dim=3)
+            segments += tokens
+            valid_lens += list(valid_len)
+
+        loss /= len(dataloader)
+
+        # Update segments, attach predicted tags to each token
+        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+
+        return preds, segments, valid_lens, loss
+
+    def infer(self, dataloader):
+        golds, preds, segments, valid_lens = list(), list(), list(), list()
+
+        for _, gold_tags, tokens, valid_len, logits in self.tag(
+            dataloader, is_train=False
+        ):
+            preds += torch.argmax(logits, dim=3)
+            segments += tokens
+            valid_lens += list(valid_len)
+
+        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+        return segments
+
+    def to_segments(self, segments, preds, valid_lens, vocab):
+        if vocab is None:
+            vocab = self.vocab
+
+        tagged_segments = list()
+        tokens_stoi = vocab.tokens.get_stoi()
+        unk_id = tokens_stoi["UNK"]
+
+        for segment, pred, valid_len in zip(segments, preds, valid_lens):
+            # First, drop the token at the 0th index [CLS] and the token at the nth index [SEP]
+            # Combine the remaining tokens with their corresponding predictions
+            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
+
+            # Ignore the sub-tokens/subwords, which are identified with text being UNK
+            segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
+
+            # Attach the predicted tags to each token
+            list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
+                                                          for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))
+
+            # We are only interested in the tagged tokens; we no longer need the raw model predictions
+            tagged_segment = [t for t, _ in segment_pred]
+            tagged_segments.append(tagged_segment)
+
+        return tagged_segments
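The nested trainer scores each tag layer separately: the padded logits (B x T x L x C) are sliced back to each layer's true label count before cross-entropy. A dummy-tensor sketch of that slicing, with invented sizes (not package code):

    # Sketch: per-layer loss slicing on dummy tensors, mirroring the loop above.
    import torch
    import torch.nn as nn

    B, T = 2, 8
    num_labels = [9, 5, 3]                 # invented per-layer label counts
    C = max(num_labels)                    # padded class dimension
    logits = torch.randn(B, T, len(num_labels), C)
    gold_tags = torch.stack([torch.randint(0, l, (B, T)) for l in num_labels], dim=1)  # B x L x T

    loss_fn = nn.CrossEntropyLoss()
    losses = [loss_fn(logits[:, :, i, :l].reshape(-1, l),   # drop the padded classes
                      gold_tags[:, i, :].reshape(-1))
              for i, l in enumerate(num_labels)]
    print(sum(l.item() for l in losses))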