SinaTools-0.1.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. SinaTools-0.1.1.data/data/nlptools/environment.yml +227 -0
  2. SinaTools-0.1.1.dist-info/AUTHORS.rst +13 -0
  3. SinaTools-0.1.1.dist-info/LICENSE +22 -0
  4. SinaTools-0.1.1.dist-info/METADATA +72 -0
  5. SinaTools-0.1.1.dist-info/RECORD +122 -0
  6. SinaTools-0.1.1.dist-info/WHEEL +6 -0
  7. SinaTools-0.1.1.dist-info/entry_points.txt +18 -0
  8. SinaTools-0.1.1.dist-info/top_level.txt +1 -0
  9. nlptools/CLI/DataDownload/download_files.py +71 -0
  10. nlptools/CLI/arabiner/bin/infer.py +117 -0
  11. nlptools/CLI/arabiner/bin/infer2.py +81 -0
  12. nlptools/CLI/morphology/ALMA_multi_word.py +75 -0
  13. nlptools/CLI/morphology/morph_analyzer.py +91 -0
  14. nlptools/CLI/salma/salma_tools.py +68 -0
  15. nlptools/CLI/utils/__init__.py +0 -0
  16. nlptools/CLI/utils/arStrip.py +99 -0
  17. nlptools/CLI/utils/corpus_tokenizer.py +74 -0
  18. nlptools/CLI/utils/implication.py +92 -0
  19. nlptools/CLI/utils/jaccard.py +96 -0
  20. nlptools/CLI/utils/latin_remove.py +51 -0
  21. nlptools/CLI/utils/remove_Punc.py +53 -0
  22. nlptools/CLI/utils/sentence_tokenizer.py +90 -0
  23. nlptools/CLI/utils/text_transliteration.py +77 -0
  24. nlptools/DataDownload/__init__.py +0 -0
  25. nlptools/DataDownload/downloader.py +185 -0
  26. nlptools/VERSION +1 -0
  27. nlptools/__init__.py +5 -0
  28. nlptools/arabert/__init__.py +1 -0
  29. nlptools/arabert/arabert/__init__.py +14 -0
  30. nlptools/arabert/arabert/create_classification_data.py +260 -0
  31. nlptools/arabert/arabert/create_pretraining_data.py +534 -0
  32. nlptools/arabert/arabert/extract_features.py +444 -0
  33. nlptools/arabert/arabert/lamb_optimizer.py +158 -0
  34. nlptools/arabert/arabert/modeling.py +1027 -0
  35. nlptools/arabert/arabert/optimization.py +202 -0
  36. nlptools/arabert/arabert/run_classifier.py +1078 -0
  37. nlptools/arabert/arabert/run_pretraining.py +593 -0
  38. nlptools/arabert/arabert/run_squad.py +1440 -0
  39. nlptools/arabert/arabert/tokenization.py +414 -0
  40. nlptools/arabert/araelectra/__init__.py +1 -0
  41. nlptools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +103 -0
  42. nlptools/arabert/araelectra/build_pretraining_dataset.py +230 -0
  43. nlptools/arabert/araelectra/build_pretraining_dataset_single_file.py +90 -0
  44. nlptools/arabert/araelectra/configure_finetuning.py +172 -0
  45. nlptools/arabert/araelectra/configure_pretraining.py +143 -0
  46. nlptools/arabert/araelectra/finetune/__init__.py +14 -0
  47. nlptools/arabert/araelectra/finetune/feature_spec.py +56 -0
  48. nlptools/arabert/araelectra/finetune/preprocessing.py +173 -0
  49. nlptools/arabert/araelectra/finetune/scorer.py +54 -0
  50. nlptools/arabert/araelectra/finetune/task.py +74 -0
  51. nlptools/arabert/araelectra/finetune/task_builder.py +70 -0
  52. nlptools/arabert/araelectra/flops_computation.py +215 -0
  53. nlptools/arabert/araelectra/model/__init__.py +14 -0
  54. nlptools/arabert/araelectra/model/modeling.py +1029 -0
  55. nlptools/arabert/araelectra/model/optimization.py +193 -0
  56. nlptools/arabert/araelectra/model/tokenization.py +355 -0
  57. nlptools/arabert/araelectra/pretrain/__init__.py +14 -0
  58. nlptools/arabert/araelectra/pretrain/pretrain_data.py +160 -0
  59. nlptools/arabert/araelectra/pretrain/pretrain_helpers.py +229 -0
  60. nlptools/arabert/araelectra/run_finetuning.py +323 -0
  61. nlptools/arabert/araelectra/run_pretraining.py +469 -0
  62. nlptools/arabert/araelectra/util/__init__.py +14 -0
  63. nlptools/arabert/araelectra/util/training_utils.py +112 -0
  64. nlptools/arabert/araelectra/util/utils.py +109 -0
  65. nlptools/arabert/aragpt2/__init__.py +2 -0
  66. nlptools/arabert/aragpt2/create_pretraining_data.py +95 -0
  67. nlptools/arabert/aragpt2/gpt2/__init__.py +2 -0
  68. nlptools/arabert/aragpt2/gpt2/lamb_optimizer.py +158 -0
  69. nlptools/arabert/aragpt2/gpt2/optimization.py +225 -0
  70. nlptools/arabert/aragpt2/gpt2/run_pretraining.py +397 -0
  71. nlptools/arabert/aragpt2/grover/__init__.py +0 -0
  72. nlptools/arabert/aragpt2/grover/dataloader.py +161 -0
  73. nlptools/arabert/aragpt2/grover/modeling.py +803 -0
  74. nlptools/arabert/aragpt2/grover/modeling_gpt2.py +1196 -0
  75. nlptools/arabert/aragpt2/grover/optimization_adafactor.py +234 -0
  76. nlptools/arabert/aragpt2/grover/train_tpu.py +187 -0
  77. nlptools/arabert/aragpt2/grover/utils.py +234 -0
  78. nlptools/arabert/aragpt2/train_bpe_tokenizer.py +59 -0
  79. nlptools/arabert/preprocess.py +818 -0
  80. nlptools/arabiner/__init__.py +0 -0
  81. nlptools/arabiner/bin/__init__.py +14 -0
  82. nlptools/arabiner/bin/eval.py +87 -0
  83. nlptools/arabiner/bin/infer.py +91 -0
  84. nlptools/arabiner/bin/process.py +140 -0
  85. nlptools/arabiner/bin/train.py +221 -0
  86. nlptools/arabiner/data/__init__.py +1 -0
  87. nlptools/arabiner/data/datasets.py +146 -0
  88. nlptools/arabiner/data/transforms.py +118 -0
  89. nlptools/arabiner/nn/BaseModel.py +22 -0
  90. nlptools/arabiner/nn/BertNestedTagger.py +34 -0
  91. nlptools/arabiner/nn/BertSeqTagger.py +17 -0
  92. nlptools/arabiner/nn/__init__.py +3 -0
  93. nlptools/arabiner/trainers/BaseTrainer.py +117 -0
  94. nlptools/arabiner/trainers/BertNestedTrainer.py +203 -0
  95. nlptools/arabiner/trainers/BertTrainer.py +163 -0
  96. nlptools/arabiner/trainers/__init__.py +3 -0
  97. nlptools/arabiner/utils/__init__.py +0 -0
  98. nlptools/arabiner/utils/data.py +124 -0
  99. nlptools/arabiner/utils/helpers.py +151 -0
  100. nlptools/arabiner/utils/metrics.py +69 -0
  101. nlptools/environment.yml +227 -0
  102. nlptools/install_env.py +13 -0
  103. nlptools/morphology/ALMA_multi_word.py +34 -0
  104. nlptools/morphology/__init__.py +52 -0
  105. nlptools/morphology/charsets.py +60 -0
  106. nlptools/morphology/morph_analyzer.py +170 -0
  107. nlptools/morphology/settings.py +8 -0
  108. nlptools/morphology/tokenizers_words.py +19 -0
  109. nlptools/nlptools.py +1 -0
  110. nlptools/salma/__init__.py +12 -0
  111. nlptools/salma/settings.py +31 -0
  112. nlptools/salma/views.py +459 -0
  113. nlptools/salma/wsd.py +126 -0
  114. nlptools/utils/__init__.py +0 -0
  115. nlptools/utils/corpus_tokenizer.py +73 -0
  116. nlptools/utils/implication.py +662 -0
  117. nlptools/utils/jaccard.py +247 -0
  118. nlptools/utils/parser.py +147 -0
  119. nlptools/utils/readfile.py +3 -0
  120. nlptools/utils/sentence_tokenizer.py +53 -0
  121. nlptools/utils/text_transliteration.py +232 -0
  122. nlptools/utils/utils.py +2 -0
@@ -0,0 +1,163 @@
1
+ import os
2
+ import logging
3
+ import torch
4
+ import numpy as np
5
+ from nlptools.arabiner.trainers import BaseTrainer
6
+ from nlptools.arabiner.utils.metrics import compute_single_label_metrics
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class BertTrainer(BaseTrainer):
    """
    Trainer for flat (single-label) NER with a BERT token classifier.

    Uses state presumably initialised by ``BaseTrainer`` from ``kwargs``
    (model, loss, optimizer, scheduler, dataloaders, summary writer, patience,
    clip, log_interval, output_path, vocab, ...) together with its ``tag``,
    ``save`` and ``segments_to_file`` helpers.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self):
        """
        Train the model with early stopping on the validation loss.

        After every epoch the validation set is evaluated. When the validation
        loss improves, the test set is evaluated, its predictions are written
        to ``<output_path>/predictions.txt`` and ``self.save()`` is called.
        Training stops after ``self.patience`` consecutive epochs without
        improvement or after ``self.max_epochs`` epochs.
        """
        best_val_loss, test_loss = np.inf, np.inf
        num_train_batch = len(self.train_dataloader)
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            # NOTE(review): gradients are not zeroed in this loop; presumably
            # BaseTrainer.tag does it when is_train=True — confirm.
            for batch_index, (_, gold_tags, _, _, logits) in enumerate(
                self.tag(self.train_dataloader, is_train=True), 1
            ):
                self.current_timestep += 1
                # Flatten (batch, seq, num_tags) logits and (batch, seq) gold tags
                batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
                batch_loss.backward()

                # Avoid exploding gradient by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                train_loss += batch_loss.item()

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss.item()
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_single_label_metrics(segments)

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_single_label_metrics(segments)

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # Record this epoch's scalars before a possible early stop, so the
            # last epoch is not silently missing from the summary writer.
            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

            # No improvements, terminating early
            if patience == 0:
                logger.info("Early termination triggered")
                break

    def eval(self, dataloader):
        """
        Evaluate the model on ``dataloader``.

        :param dataloader: torch.utils.data.DataLoader whose dataset exposes ``vocab``
        :return: tuple(preds, segments, valid_lens, loss)
                 preds - list of per-position argmax tag ids for each example
                 segments - tokens with ``pred_tag`` attached (see ``to_segments``)
                 valid_lens - per-example valid lengths as yielded by ``self.tag``
                 loss - float, mean loss per batch (0.0 for an empty dataloader)
        """
        preds, segments, valid_lens = list(), list(), list()
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(dataloader, is_train=False):
            loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        # Average over batches; guard against an empty dataloader, where the
        # division would raise and ``loss`` would still be a plain int.
        num_batches = len(dataloader)
        loss = float(loss / num_batches) if num_batches else 0.0

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss

    def infer(self, dataloader):
        """
        Tag every example in ``dataloader`` without computing a loss.

        :param dataloader: torch.utils.data.DataLoader whose dataset exposes ``vocab``
        :return: list of tagged segments (see ``to_segments``)
        """
        preds, segments, valid_lens = list(), list(), list()

        for _, _, tokens, valid_len, logits in self.tag(dataloader, is_train=False):
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        return self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

    def to_segments(self, segments, preds, valid_lens, vocab):
        """
        Attach predicted tags to tokens and drop special/sub-word positions.

        For each segment, the first and the last valid positions are skipped.
        The comment in the original code identifies them as [CLS] and [SEP].
        Tokens whose text maps to the "UNK" id (sub-word pieces) are also
        dropped. Each remaining token gets ``pred_tag = [{"tag": <tag string>}]``.

        :param segments: list of token lists (objects with a ``text`` attribute)
        :param preds: list of per-position predicted tag ids
        :param valid_lens: valid length of each segment
        :param vocab: object with ``tokens`` and ``tags`` vocabs; falls back to
                      ``self.vocab`` when None
        :return: list of token lists with ``pred_tag`` set
        """
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        tags_itos = vocab.tags[0].get_itos()
        unk_id = tokens_stoi["UNK"]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # Skip position 0 ([CLS]) and position valid_len-1 ([SEP]), then
            # ignore sub-tokens/subwords, which are identified by text "UNK".
            segment_pred = [
                (token, tag_id)
                for token, tag_id in zip(segment[1:valid_len - 1], pred[1:valid_len - 1])
                if tokens_stoi[token.text] != unk_id
            ]

            # Attach the predicted tags to each token
            for token, tag_id in segment_pred:
                token.pred_tag = [{"tag": tags_itos[tag_id]}]

            # Only the tagged tokens are kept; raw model predictions are no longer needed
            tagged_segments.append([token for token, _ in segment_pred])

        return tagged_segments
@@ -0,0 +1,3 @@
1
+ from nlptools.arabiner.trainers.BaseTrainer import BaseTrainer
2
+ from nlptools.arabiner.trainers.BertTrainer import BertTrainer
3
+ from nlptools.arabiner.trainers.BertNestedTrainer import BertNestedTrainer
File without changes
@@ -0,0 +1,124 @@
1
+ from torch.utils.data import DataLoader
2
+ from torchtext.vocab import vocab
3
+ from collections import Counter, namedtuple
4
+ import logging
5
+ import re
6
+ import itertools
7
+ from nlptools.arabiner.utils.helpers import load_object
8
+ from nlptools.arabiner.data.datasets import Token
9
+ from nlptools.morphology.tokenizers_words import simple_word_tokenize
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def conll_to_segments(filename):
    """
    Convert a CoNLL file to segments.

    Each non-blank line is ``<token> <tag> [<tag> ...]`` separated by
    whitespace; blank lines separate segments. Empty segments (from runs of
    blank lines or a trailing blank line) are not emitted.

    :param filename: Path - CoNLL file, read as UTF-8
    :return: list[list[Token]] - one list of ``Token(text, gold_tag)`` per
             segment, where ``gold_tag`` is the list of tags on that line
    """
    segments, segment = list(), list()

    # The data is Arabic text: decode as UTF-8 explicitly instead of relying
    # on the platform's default encoding.
    with open(filename, "r", encoding="utf-8") as fh:
        for line in fh.read().splitlines():
            if not line.strip():
                # A blank line closes the current segment; skip it if empty
                if segment:
                    segments.append(segment)
                segment = list()
            else:
                parts = line.split()
                segment.append(Token(text=parts[0], gold_tag=parts[1:]))

    # Flush the last segment when the file does not end with a blank line
    if segment:
        segments.append(segment)

    return segments
36
+
37
+
38
def parse_conll_files(data_paths):
    """
    Parse several CoNLL files and build shared token/tag vocabularies.

    :param data_paths: tuple(Path) - CoNLL filenames, e.g. (train, val, test)
    :return: (datasets, vocabs)
             datasets - tuple with one list of segments per path, in order;
                        each segment is a list of Token objects
             vocabs - namedtuple ``Vocab(tags, tokens)``: ``tokens`` indexes all
                      token texts (with special "UNK"); ``tags[0]`` indexes
                      every tag and ``tags[1:]`` are the per-entity-type tag
                      vocabs from ``tag_vocab_by_type``
    """
    Vocab = namedtuple("Vocab", ["tags", "tokens"])
    datasets, tokens, per_token_tags = [], [], []

    for data_path in data_paths:
        segments = conll_to_segments(data_path)
        datasets.append(segments)
        all_tokens = [token for segment in segments for token in segment]
        tokens.extend(token.text for token in all_tokens)
        per_token_tags.extend(token.gold_tag for token in all_tokens)

    # Each token carries a list of gold tags; flatten them into one list
    tags = list(itertools.chain.from_iterable(per_token_tags))

    # Per-type tag vocabs first, then the combined vocab over all tags in front
    tag_vocabs = tag_vocab_by_type(tags)
    tag_vocabs.insert(0, vocab(Counter(tags)))
    token_vocab = vocab(Counter(tokens), specials=["UNK"])
    return tuple(datasets), Vocab(tokens=token_vocab, tags=tag_vocabs)
66
+
67
+
68
def tag_vocab_by_type(tags):
    """
    Build one tag vocabulary per entity type.

    The entity type of a tag is everything after its first hyphen
    (``"B-ORG"`` -> ``"ORG"``); tags without a hyphen (e.g. ``"O"``) have no
    type. For each type, in sorted order, a vocab is built over that type's
    tags (in first-appearance order within ``tags``) plus ``"O"``, with the
    special ``"<pad>"``.

    Tags are grouped by exact type equality. The previous unescaped regex
    ``".*-" + type`` was not anchored at the end, so it also matched tags of
    other types (e.g. type ``GPE`` matched ``"B-LOC-GPE"``).

    :param tags: list[str] - flat list of tags (duplicates allowed)
    :return: list of torchtext vocabs, one per entity type
    """
    tags_by_type = dict()
    for tag in tags:
        if "-" in tag:
            tags_by_type.setdefault(tag.split("-", 1)[1], []).append(tag)

    return [
        vocab(Counter(tags_by_type[tag_type] + ["O"]), specials=["<pad>"])
        for tag_type in sorted(tags_by_type)
    ]
80
+
81
+
82
def text2segments(text):
    """
    Turn raw text into a one-segment dataset and index its tokens.

    The text is split with ``simple_word_tokenize``. Every token gets the
    placeholder gold tag ``["O"]``, since no real labels are available at
    inference time.

    :param text: str - raw input text
    :return: (dataset, segment_vocab)
             dataset - ``[[Token, ...]]``, a list holding a single segment
             segment_vocab - torchtext vocab over the token texts, with "UNK"
    """
    segment = [Token(text=word, gold_tag=["O"]) for word in simple_word_tokenize(text)]

    # Generate vocabs for the tokens
    segment_vocab = vocab(Counter(token.text for token in segment), specials=["UNK"])
    return [segment], segment_vocab
94
+
95
+
96
def get_dataloaders(
    datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
):
    """
    Build one DataLoader per dataset.

    :param datasets: list - list of the datasets, each a list of segments of tokens
    :param vocab: vocab object passed to every dataset as ``vocab=``
    :param data_config: dict - ``{"fn": <dataset class path below "nlptools.">,
                        "kwargs": {...}}``; not modified by this function
    :param batch_size: int
    :param num_workers: int
    :param shuffle: sequence of booleans, one per dataset (indexed in parallel)
    :return: List[torch.utils.data.DataLoader]
    :raises RuntimeError: if the dataset class cannot be instantiated
    """
    dataloaders = list()

    for i, examples in enumerate(datasets):
        # Build per-dataset kwargs instead of calling update() on the caller's
        # dict; that stored the examples and vocab inside data_config and kept
        # them referenced after this call returned.
        dataset_kwargs = {**data_config["kwargs"], "examples": examples, "vocab": vocab}
        dataset_fn = "nlptools." + data_config["fn"]
        dataset = load_object(dataset_fn, dataset_kwargs)

        # load_object reports errors and returns None; fail clearly here
        # rather than with an AttributeError on ``dataset.collate_fn``.
        if dataset is None:
            raise RuntimeError(f"Could not instantiate dataset '{dataset_fn}'")

        dataloader = DataLoader(
            dataset=dataset,
            shuffle=shuffle[i],
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
        )

        logger.info("%s batches found", len(dataloader))
        dataloaders.append(dataloader)

    return dataloaders
@@ -0,0 +1,151 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import importlib
5
+ import shutil
6
+ import torch
7
+ import pickle
8
+ import json
9
+ import random
10
+ import numpy as np
11
+ from argparse import Namespace
12
+
13
+
14
def logging_config(log_file=None):
    """
    Configure the root logger at INFO level.

    Records always go to stdout. When ``log_file`` is given, they are also
    written (UTF-8, truncating any existing file) to that path. Any handlers
    already on the root logger are replaced (``force=True``).

    :param log_file: str - full path of an optional log file
    :return: None
    """
    stdout_handler = logging.StreamHandler(sys.stdout)

    if not log_file:
        handlers = [stdout_handler]
    else:
        handlers = [stdout_handler, logging.FileHandler(log_file, "w", "utf-8")]
        print("Logging to {}".format(log_file))

    logging.basicConfig(
        level=logging.INFO,
        handlers=handlers,
        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        force=True,
    )
33
+
34
+
35
def load_object(name, kwargs):
    """
    Import an object by dotted path and call it with keyword arguments.

    ``name`` has the form ``"package.module.ObjectName"``. The module is
    imported and ``ObjectName`` is looked up in it. If that object is
    callable (class or function), ``obj(**kwargs)`` is returned.

    Errors are handled best-effort: a message is written to stderr and
    ``None`` is returned, so callers must check the result. Messages go to
    stderr so they do not mix with results printed on stdout.

    :param name: str - dotted path of a class or function
    :param kwargs: dict - keyword arguments for the call
    :return: the call result, or None on any error
    """
    # Resolve the object; failures here mean the path itself is wrong
    try:
        object_module, object_name = name.rsplit(".", 1)
        module = importlib.import_module(object_module)
        obj = getattr(module, object_name)
    except ImportError as e:
        print(f"Error importing module: {e}", file=sys.stderr)
        return None
    except AttributeError as e:
        # e.g. object not found in module
        print(f"Attribute error: {e}", file=sys.stderr)
        return None
    except Exception as e:
        # e.g. a name without a dot, or an empty module name
        print(f"An error occurred: {e}", file=sys.stderr)
        return None

    if not callable(obj):
        print(f"An error occurred: {name} is not a callable object.", file=sys.stderr)
        return None

    # Call separately so errors raised inside the constructor are not
    # mislabelled as import/attribute lookup errors.
    try:
        return obj(**kwargs)
    except Exception as e:
        print(f"An error occurred while calling {name}: {e!r}", file=sys.stderr)
        return None
76
+
77
def make_output_dirs(path, subdirs=None, overwrite=True):
    """
    Create the root directory and any sub-directories.

    :param path: str - root directory
    :param subdirs: iterable of str - sub-directories to create under ``path``
                    (default: none)
    :param overwrite: boolean - when True, delete ``path`` and its contents
                      first; when False, ``os.makedirs`` raises
                      FileExistsError if ``path`` already exists
    :return: None
    """
    if overwrite:
        shutil.rmtree(path, ignore_errors=True)

    os.makedirs(path)

    # ``None`` default instead of a shared mutable ``[]`` default argument
    for subdir in subdirs or ():
        os.makedirs(os.path.join(path, subdir))
92
+
93
+
94
def load_checkpoint(path):
    """
    Load a trained tagger from a checkpoint directory.

    Reads ``tag_vocab.pkl`` and ``args.json`` from ``path``. It then rebuilds
    the loss, network and trainer described in ``args.json`` via
    ``load_object`` and wraps the network in ``torch.nn.DataParallel`` (moved
    to GPU when CUDA is available). Finally it calls the trainer's ``load``
    with ``<path>/checkpoints``.

    :param path: str - path to the checkpoint directory
    :return: tagger - arabiner.trainers.BaseTrainer - the tagger model
             vocab - torchtext.vocab.Vocab - indexed tags
             train_config - argparse.Namespace - training configurations
    :raises RuntimeError: if the network or the trainer cannot be instantiated
    """
    # SECURITY: pickle.load can execute arbitrary code from the file; only
    # load checkpoints that come from a trusted source.
    with open(os.path.join(path, "tag_vocab.pkl"), "rb") as fh:
        tag_vocab = pickle.load(fh)

    # Load train configurations from checkpoint
    train_config = Namespace()
    with open(os.path.join(path, "args.json"), "r", encoding="utf-8") as fh:
        train_config.__dict__ = json.load(fh)

    # Initialize the loss function. It is not used for inference, only for
    # evaluation, so a failed load (None) is passed through unchanged.
    loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])

    # Load BERT tagger
    network_fn = "nlptools." + train_config.network_config["fn"]
    model = load_object(network_fn, train_config.network_config["kwargs"])
    if model is None:
        raise RuntimeError(f"Could not instantiate network '{network_fn}' from checkpoint {path}")
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Update arguments for the tagger
    # Attach the model, loss (used for evaluations cases)
    train_config.trainer_config["kwargs"]["model"] = model
    train_config.trainer_config["kwargs"]["loss"] = loss
    trainer_fn = "nlptools." + train_config.trainer_config["fn"]
    tagger = load_object(trainer_fn, train_config.trainer_config["kwargs"])
    if tagger is None:
        raise RuntimeError(f"Could not instantiate trainer '{trainer_fn}' from checkpoint {path}")
    tagger.load(os.path.join(path, "checkpoints"))
    return tagger, tag_vocab, train_config
133
+
134
+
135
def set_seed(seed):
    """
    Seed all random number generators and configure cuDNN for reproducibility.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, current
    CUDA device and all CUDA devices). It then sets cuDNN to deterministic,
    disables benchmark mode and disables cuDNN entirely, so runs with the
    same seed give deterministic results.

    :param seed: int
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False
@@ -0,0 +1,69 @@
1
+ from seqeval.metrics import (
2
+ classification_report,
3
+ precision_score,
4
+ recall_score,
5
+ f1_score,
6
+ accuracy_score,
7
+ )
8
+ from seqeval.scheme import IOB2
9
+ from types import SimpleNamespace
10
+ import logging
11
+ import re
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def compute_nested_metrics(segments, vocabs):
    """
    Compute metrics for nested NER.

    Each entity type is scored as its own flat labelling task. The dataset is
    conceptually duplicated once per vocab in ``vocabs``. For copy ``i``, the
    gold labels are restricted to the entity tags of ``vocabs[i]`` ("O" when a
    token has none of them). The predictions are ``token.pred_tag[i]["tag"]``.
    All copies are then scored together with seqeval (IOB2 scheme).

    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments;
                     tokens need ``gold_tag`` (list of str) and ``pred_tag``
                     (list of ``{"tag": str}``, one entry per vocab)
    :param vocabs: list of tag vocabs (exposing ``get_itos()``), one per entity
                   type, in the same order as ``pred_tag``
    :return: metrics - SimpleNamespace - F1 (micro/macro/weighted), recall, precision, accuracy
    """
    y, y_hat = list(), list()

    for i, vocab in enumerate(vocabs):
        # Entity tags of this vocab; specials such as "<pad>" and "O" have no hyphen
        entity_tags = {tag for tag in vocab.get_itos() if "-" in tag}

        # Gold label: the first of the token's gold tags belonging to this
        # vocab, else "O". Exact membership replaces an unescaped regex
        # prefix match, which also accepted longer tags (e.g. "B-PERSON" for
        # "B-PERS") and matched every tag when the vocab had no entity tags.
        y += [
            [next((tag for tag in token.gold_tag if tag in entity_tags), "O") for token in segment]
            for segment in segments
        ]
        y_hat += [[token.pred_tag[i]["tag"] for token in segment] for segment in segments]

    # Use the module logger rather than the root logger
    logger.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))

    metrics = {
        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
        "precision": precision_score(y, y_hat, scheme=IOB2),
        "recall": recall_score(y, y_hat, scheme=IOB2),
        "accuracy": accuracy_score(y, y_hat),
    }

    return SimpleNamespace(**metrics)
47
+
48
+
49
def compute_single_label_metrics(segments):
    """
    Compute metrics for flat NER.

    Gold labels are each token's first gold tag and predictions are the first
    entry of each token's ``pred_tag``; both are scored with seqeval using the
    IOB2 scheme.

    :param segments: List[List[arabiner.data.dataset.Token]] - list of segments
    :return: metrics - SimpleNamespace - F1 (micro/macro/weighted), recall, precision, accuracy
    """
    y = [[token.gold_tag[0] for token in segment] for segment in segments]
    y_hat = [[token.pred_tag[0]["tag"] for token in segment] for segment in segments]

    # Use the module logger rather than the root logger
    logger.info("\n" + classification_report(y, y_hat, scheme=IOB2))

    metrics = {
        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
        "precision": precision_score(y, y_hat, scheme=IOB2),
        "recall": recall_score(y, y_hat, scheme=IOB2),
        "accuracy": accuracy_score(y, y_hat),
    }

    return SimpleNamespace(**metrics)