SinaTools 0.1.35__py2.py3-none-any.whl → 0.1.37__py2.py3-none-any.whl

This diff shows the content of publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
@@ -1,163 +1,163 @@
- import os
- import logging
- import torch
- import numpy as np
- from sinatools.ner.trainers import BaseTrainer
- from sinatools.ner.metrics import compute_single_label_metrics
-
- logger = logging.getLogger(__name__)
-
-
- class BertTrainer(BaseTrainer):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- def train(self):
- best_val_loss, test_loss = np.inf, np.inf
- num_train_batch = len(self.train_dataloader)
- patience = self.patience
-
- for epoch_index in range(self.max_epochs):
- self.current_epoch = epoch_index
- train_loss = 0
-
- for batch_index, (_, gold_tags, _, _, logits) in enumerate(self.tag(
- self.train_dataloader, is_train=True
- ), 1):
- self.current_timestep += 1
- batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
- batch_loss.backward()
-
- # Avoid exploding gradient by doing gradient clipping
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
-
- self.optimizer.step()
- self.scheduler.step()
- train_loss += batch_loss.item()
-
- if self.current_timestep % self.log_interval == 0:
- logger.info(
- "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
- epoch_index,
- batch_index,
- num_train_batch,
- self.current_timestep,
- self.optimizer.param_groups[0]['lr'],
- batch_loss.item()
- )
-
- train_loss /= num_train_batch
-
- logger.info("** Evaluating on validation dataset **")
- val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
- val_metrics = compute_single_label_metrics(segments)
-
- epoch_summary_loss = {
- "train_loss": train_loss,
- "val_loss": val_loss
- }
- epoch_summary_metrics = {
- "val_micro_f1": val_metrics.micro_f1,
- "val_precision": val_metrics.precision,
- "val_recall": val_metrics.recall
- }
-
- logger.info(
- "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
- epoch_index,
- self.current_timestep,
- train_loss,
- val_loss,
- val_metrics.micro_f1
- )
-
- if val_loss < best_val_loss:
- patience = self.patience
- best_val_loss = val_loss
- logger.info("** Validation improved, evaluating test data **")
- test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
- self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
- test_metrics = compute_single_label_metrics(segments)
-
- epoch_summary_loss["test_loss"] = test_loss
- epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
- epoch_summary_metrics["test_precision"] = test_metrics.precision
- epoch_summary_metrics["test_recall"] = test_metrics.recall
-
- logger.info(
- f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
- epoch_index,
- self.current_timestep,
- test_loss,
- test_metrics.micro_f1
- )
-
- self.save()
- else:
- patience -= 1
-
- # No improvements, terminating early
- if patience == 0:
- logger.info("Early termination triggered")
- break
-
- self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
- self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
-
- def eval(self, dataloader):
- golds, preds, segments, valid_lens = list(), list(), list(), list()
- loss = 0
-
- for _, gold_tags, tokens, valid_len, logits in self.tag(
- dataloader, is_train=False
- ):
- loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
- preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
- segments += tokens
- valid_lens += list(valid_len)
-
- loss /= len(dataloader)
-
- # Update segments, attach predicted tags to each token
- segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
-
- return preds, segments, valid_lens, loss.item()
-
- def infer(self, dataloader):
- golds, preds, segments, valid_lens = list(), list(), list(), list()
-
- for _, gold_tags, tokens, valid_len, logits in self.tag(
- dataloader, is_train=False
- ):
- preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
- segments += tokens
- valid_lens += list(valid_len)
-
- segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
- return segments
-
- def to_segments(self, segments, preds, valid_lens, vocab):
- if vocab is None:
- vocab = self.vocab
-
- tagged_segments = list()
- tokens_stoi = vocab.tokens.get_stoi()
- tags_itos = vocab.tags[0].get_itos()
- unk_id = tokens_stoi["UNK"]
-
- for segment, pred, valid_len in zip(segments, preds, valid_lens):
- # First, the token at 0th index [CLS] and token at nth index [SEP]
- # Combine the tokens with their corresponding predictions
- segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
-
- # Ignore the sub-tokens/subwords, which are identified with text being UNK
- segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
-
- # Attach the predicted tags to each token
- list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": tags_itos[t[1]]}]), segment_pred))
-
- # We are only interested in the tagged tokens, we do no longer need raw model predictions
- tagged_segment = [t for t, _ in segment_pred]
- tagged_segments.append(tagged_segment)
-
- return tagged_segments
+ import os
+ import logging
+ import torch
+ import numpy as np
+ from sinatools.ner.trainers import BaseTrainer
+ from sinatools.ner.metrics import compute_single_label_metrics
+
+ logger = logging.getLogger(__name__)
+
+
+ class BertTrainer(BaseTrainer):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def train(self):
+ best_val_loss, test_loss = np.inf, np.inf
+ num_train_batch = len(self.train_dataloader)
+ patience = self.patience
+
+ for epoch_index in range(self.max_epochs):
+ self.current_epoch = epoch_index
+ train_loss = 0
+
+ for batch_index, (_, gold_tags, _, _, logits) in enumerate(self.tag(
+ self.train_dataloader, is_train=True
+ ), 1):
+ self.current_timestep += 1
+ batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
+ batch_loss.backward()
+
+ # Avoid exploding gradient by doing gradient clipping
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+ self.optimizer.step()
+ self.scheduler.step()
+ train_loss += batch_loss.item()
+
+ if self.current_timestep % self.log_interval == 0:
+ logger.info(
+ "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
+ epoch_index,
+ batch_index,
+ num_train_batch,
+ self.current_timestep,
+ self.optimizer.param_groups[0]['lr'],
+ batch_loss.item()
+ )
+
+ train_loss /= num_train_batch
+
+ logger.info("** Evaluating on validation dataset **")
+ val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
+ val_metrics = compute_single_label_metrics(segments)
+
+ epoch_summary_loss = {
+ "train_loss": train_loss,
+ "val_loss": val_loss
+ }
+ epoch_summary_metrics = {
+ "val_micro_f1": val_metrics.micro_f1,
+ "val_precision": val_metrics.precision,
+ "val_recall": val_metrics.recall
+ }
+
+ logger.info(
+ "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
+ epoch_index,
+ self.current_timestep,
+ train_loss,
+ val_loss,
+ val_metrics.micro_f1
+ )
+
+ if val_loss < best_val_loss:
+ patience = self.patience
+ best_val_loss = val_loss
+ logger.info("** Validation improved, evaluating test data **")
+ test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
+ self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
+ test_metrics = compute_single_label_metrics(segments)
+
+ epoch_summary_loss["test_loss"] = test_loss
+ epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
+ epoch_summary_metrics["test_precision"] = test_metrics.precision
+ epoch_summary_metrics["test_recall"] = test_metrics.recall
+
+ logger.info(
+ f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
+ epoch_index,
+ self.current_timestep,
+ test_loss,
+ test_metrics.micro_f1
+ )
+
+ self.save()
+ else:
+ patience -= 1
+
+ # No improvements, terminating early
+ if patience == 0:
+ logger.info("Early termination triggered")
+ break
+
+ self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
+ self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
+
+ def eval(self, dataloader):
+ golds, preds, segments, valid_lens = list(), list(), list(), list()
+ loss = 0
+
+ for _, gold_tags, tokens, valid_len, logits in self.tag(
+ dataloader, is_train=False
+ ):
+ loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
+ preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
+ segments += tokens
+ valid_lens += list(valid_len)
+
+ loss /= len(dataloader)
+
+ # Update segments, attach predicted tags to each token
+ segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+
+ return preds, segments, valid_lens, loss.item()
+
+ def infer(self, dataloader):
+ golds, preds, segments, valid_lens = list(), list(), list(), list()
+
+ for _, gold_tags, tokens, valid_len, logits in self.tag(
+ dataloader, is_train=False
+ ):
+ preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
+ segments += tokens
+ valid_lens += list(valid_len)
+
+ segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
+ return segments
+
+ def to_segments(self, segments, preds, valid_lens, vocab):
+ if vocab is None:
+ vocab = self.vocab
+
+ tagged_segments = list()
+ tokens_stoi = vocab.tokens.get_stoi()
+ tags_itos = vocab.tags[0].get_itos()
+ unk_id = tokens_stoi["UNK"]
+
+ for segment, pred, valid_len in zip(segments, preds, valid_lens):
+ # First, the token at 0th index [CLS] and token at nth index [SEP]
+ # Combine the tokens with their corresponding predictions
+ segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
+
+ # Ignore the sub-tokens/subwords, which are identified with text being UNK
+ segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
+
+ # Attach the predicted tags to each token
+ list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": tags_itos[t[1]]}]), segment_pred))
+
+ # We are only interested in the tagged tokens, we do no longer need raw model predictions
+ tagged_segment = [t for t, _ in segment_pred]
+ tagged_segments.append(tagged_segment)
+
+ return tagged_segments
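
The training step in BertTrainer.train() above follows a standard clip-then-step pattern: backpropagate the batch loss, clip the gradient norm, then advance the optimizer and the LR scheduler. A minimal, self-contained sketch of that pattern with toy PyTorch objects (not SinaTools' own API, and not part of this diff):

    import torch

    # Toy stand-ins; BertTrainer receives its model, optimizer, scheduler and
    # loss from BaseTrainer, which this diff does not show.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    loss_fn = torch.nn.CrossEntropyLoss()

    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,))

    logits = model(x)
    loss = loss_fn(logits.view(-1, logits.shape[-1]), y.view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # avoid exploding gradients
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
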
@@ -1,3 +1,3 @@
- from sinatools.ner.trainers.BaseTrainer import BaseTrainer
- from sinatools.ner.trainers.BertTrainer import BertTrainer
+ from sinatools.ner.trainers.BaseTrainer import BaseTrainer
+ from sinatools.ner.trainers.BertTrainer import BertTrainer
  from sinatools.ner.trainers.BertNestedTrainer import BertNestedTrainer
@@ -101,56 +101,91 @@ def get_intersection(list1, list2, ignore_all_diacritics_but_not_shadda=False, i



- def get_union(list1, list2, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic):
- """
- Computes the union of two sets of Arabic words, considering the differences in their diacritization. The method provides two options for handling diacritics: (i) ignore all diacritics except for shadda, and (ii) ignore the shadda diacritic as well. You can try the demo online.
+ # def get_union(list1, list2, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic):
+ # """
+ # Computes the union of two sets of Arabic words, considering the differences in their diacritization. The method provides two options for handling diacritics: (i) ignore all diacritics except for shadda, and (ii) ignore the shadda diacritic as well. You can try the demo online.

- Args:
- list1 (:obj:`list`): The first list.
- list2 (:obj:`bool`): The second list.
- ignore_all_diacratics_but_not_shadda (:obj:`bool`, optional) – A flag to ignore all diacratics except for the shadda. Defaults to False.
- ignore_shadda_diacritic (:obj:`bool`, optional) – A flag to ignore the shadda diacritic. Defaults to False.
+ # Args:
+ # list1 (:obj:`list`): The first list.
+ # list2 (:obj:`bool`): The second list.
+ # ignore_all_diacratics_but_not_shadda (:obj:`bool`, optional) – A flag to ignore all diacratics except for the shadda. Defaults to False.
+ # ignore_shadda_diacritic (:obj:`bool`, optional) – A flag to ignore the shadda diacritic. Defaults to False.

- Returns:
- :obj:`list`: The union of the two lists, ignoring diacritics if flags are true.
+ # Returns:
+ # :obj:`list`: The union of the two lists, ignoring diacritics if flags are true.

- **Example:**
+ # **Example:**

- .. highlight:: python
- .. code-block:: python
+ # .. highlight:: python
+ # .. code-block:: python

- from sinatools.utils.similarity import get_union
- list1 = ["كتب","فَعل","فَعَلَ"]
- list2 = ["كتب","فَعّل"]
- print(get_union(list1, list2, False, True))
- #output: ["كتب" ,"فَعل" ,"فَعَلَ"]
- """
- list1 = [str(i) for i in list1 if i not in (None, ' ', '')]
+ # from sinatools.utils.similarity import get_union
+ # list1 = ["كتب","فَعل","فَعَلَ"]
+ # list2 = ["كتب","فَعّل"]
+ # print(get_union(list1, list2, False, True))
+ # #output: ["كتب" ,"فَعل" ,"فَعَلَ"]
+ # """
+ # list1 = [str(i) for i in list1 if i not in (None, ' ', '')]

+ # list2 = [str(i) for i in list2 if i not in (None, ' ', '')]
+
+ # union_list = []
+
+ # for list1_word in list1:
+ # word1 = normalize_word(list1_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
+ # union_list.append(word1)
+
+ # for list2_word in list2:
+ # word2 = normalize_word(list2_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
+ # union_list.append(word2)
+
+ # i = 0
+ # while i < len(union_list):
+ # j = i + 1
+ # while j < len(union_list):
+ # non_preferred_word = get_non_preferred_word(union_list[i], union_list[j])
+ # if (non_preferred_word != "#"):
+ # union_list.remove(non_preferred_word)
+ # j = j + 1
+ # i = i + 1
+
+ # return union_list
+ def get_union(list1, list2, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic):
+
+
+ list1 = [str(i) for i in list1 if i not in (None, ' ', '')]
  list2 = [str(i) for i in list2 if i not in (None, ' ', '')]

+
  union_list = []

+ # Normalize and add words from list1
  for list1_word in list1:
  word1 = normalize_word(list1_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
- union_list.append(word1)
+ if word1 not in union_list:
+ union_list.append(word1)

+ # Normalize and add words from list2
  for list2_word in list2:
  word2 = normalize_word(list2_word, ignore_all_diacritics_but_not_shadda, ignore_shadda_diacritic)
- union_list.append(word2)
+ if word2 not in union_list:
+ union_list.append(word2)

+
  i = 0
  while i < len(union_list):
  j = i + 1
  while j < len(union_list):
  non_preferred_word = get_non_preferred_word(union_list[i], union_list[j])
- if (non_preferred_word != "#"):
+ if non_preferred_word != "#":
  union_list.remove(non_preferred_word)
- j = j + 1
- i = i + 1
+ j -= 1
+ j += 1
+ i += 1

  return union_list
-
+
+


  def get_jaccard_similarity(list1: list, list2: list, ignore_all_diacritics_but_not_shadda: bool, ignore_shadda_diacritic: bool) -> float:
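
The rewritten get_union above keeps the old normalize-and-merge logic but now skips words whose normalized form is already in the union. A usage sketch mirroring the example from the removed docstring (the exact output depends on normalize_word() and get_non_preferred_word(), so none is claimed here):

    from sinatools.utils.similarity import get_union

    list1 = ["كتب", "فَعل", "فَعَلَ"]
    list2 = ["كتب", "فَعّل"]

    # ignore_all_diacritics_but_not_shadda=False, ignore_shadda_diacritic=True,
    # as in the removed docstring example
    union = get_union(list1, list2, False, True)
    print(union)  # a repeated normalized word now appears only once
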
@@ -184,7 +219,7 @@ def get_jaccard_similarity(list1: list, list2: list, ignore_all_diacritics_but_n

  return float(len(intersection_list)) / float(len(union_list))

- def get_jaccard(delimiter, str1, str2, selection, ignoreAllDiacriticsButNotShadda=True, ignoreShaddaDiacritic=True):
+ def get_jaccard(delimiter, selection, str1, str2, ignoreAllDiacriticsButNotShadda=True, ignoreShaddaDiacritic=True):
  """
  Calculates and returns the Jaccard similarity values (union, intersection, or Jaccard similarity) between two lists of Arabic words, considering the differences in their diacritization. The method provides two options for handling diacritics: (i) ignore all diacritics except for shadda, and (ii) ignore the shadda diacritic as well. You can try the demo online.

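Note the reordered get_jaccard signature: selection now comes second, before the two input strings. A call sketch for the new order; the delimiter is assumed to split each string into a word list, and "intersection" is an assumed value for selection based on the docstring's wording (union, intersection, or Jaccard similarity), neither detail is confirmed by this hunk:

    from sinatools.utils.similarity import get_jaccard

    # get_jaccard(delimiter, selection, str1, str2, ...) -- new argument order
    score = get_jaccard("#", "intersection", "ذهب#ذهبت", "ذهب#ذهبوا")  # assumed argument values
    print(score)
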
@@ -8,10 +8,6 @@ from sinatools.morphology.ALMA_multi_word import ALMA_multi_word
  from sinatools.morphology.morph_analyzer import analyze
  from sinatools.ner.entity_extractor import extract
  from . import glosses_dic
- import time
- #import concurrent
- #import threading
- import multiprocessing


  def distill_entities(entities):
@@ -260,7 +256,7 @@ def find_named_entities(string):
  return found_entities


- def find_glosses_using_ALMA(word, glosses_dic):
+ def find_glosses_using_ALMA(word):

  data = analyze(word, language ='MSA', task ='full', flag="1")
  Diac_lemma = ""
@@ -306,7 +302,7 @@ def disambiguate_glosses_using_SALMA(glosses, Diac_lemma, Undiac_lemma, word, se
  return my_json


- def find_glosses(input_sentence, two_word_lemma, three_word_lemma,four_word_lemma, five_word_lemma, ner, glosses_dic):
+ def find_glosses(input_sentence, two_word_lemma, three_word_lemma,four_word_lemma, five_word_lemma, ner):
  output_list = []
  position = 0
  while position < len(input_sentence):
@@ -393,7 +389,7 @@ def find_glosses(input_sentence, two_word_lemma, three_word_lemma,four_word_lemm

  if flag == "False": # Not found in ner or in multi_word_dictionary, ASK ALMA
  word = input_sentence[position]
- word, Undiac_lemma, Diac_lemma, pos , concept_count, glosses = find_glosses_using_ALMA(word, glosses_dic)
+ word, Undiac_lemma, Diac_lemma, pos , concept_count, glosses = find_glosses_using_ALMA(word)
  my_json = {}
  my_json['word'] = word
  my_json['concept_count'] = concept_count
@@ -436,95 +432,26 @@ def disambiguate_glosses_main(word, sentence):
  glosses = word['glosses']
  Diac_lemma = word['Diac_lemma']
  Undiac_lemma = word['Undiac_lemma']
- start = time.time()
- x = disambiguate_glosses_using_SALMA(glosses, Diac_lemma, Undiac_lemma, input_word, sentence)
- end = time.time()
- print(f"disambiguate time: {end - start}")
- return x
-
-
- def init_resources():
- global glosses_dic
-
-
- # Wrapper function for multiprocessing
- def disambiguate_glosses_in_parallel(word_and_sentence):
- word, sentence = word_and_sentence
- return disambiguate_glosses_main(word, sentence)
+ return disambiguate_glosses_using_SALMA(glosses, Diac_lemma, Undiac_lemma, input_word, sentence)

  def WSD(sentence):
- start = time.time()
+
  input_sentence = simple_word_tokenize(sentence)
- end = time.time()
- print(f"tokenizer time: {end - start}")
-
- start = time.time()
+
  five_word_lemma = find_five_word_lemma(input_sentence)
- end = time.time()
- print(f"5grams time: {end - start}")

- start = time.time()
  four_word_lemma = find_four_word_lemma(input_sentence)
- end = time.time()
- print(f"4grams time: {end - start}")
-
- start = time.time()
+
  three_word_lemma = find_three_word_lemma(input_sentence)
- end = time.time()
- print(f"3grams time: {end - start}")
-
- start = time.time()
+
  two_word_lemma = find_two_word_lemma(input_sentence)
- end = time.time()
- print(f"2grams time: {end - start}")
-
- start = time.time()
+
  ner = find_named_entities(" ".join(input_sentence))
- end = time.time()
- print(f"ner time: {end - start}")
-
-
- start = time.time()
- output_list = find_glosses(input_sentence, two_word_lemma, three_word_lemma, four_word_lemma, five_word_lemma, ner, glosses_dic_shared)
- end = time.time()
- print(f"lookup time: {end - start}")
-
- # for word in output_list:
- # start = time.time()
- # results.append(disambiguate_glosses_main(word, sentence))
- # end = time.time()
- # print(f"disambiguate time: {end - start}")
- # return results
-
- # with concurrent.futures.ProcessPoolExecutor() as executor:
- # results = list(executor.map(lambda word: disambiguate_glosses_main(word, sentence), output_list))
- # return results
-
- # Create and start threads
- # for word in output_list:
- # thread = threading.Thread(target=worker, args=(word, sentence))
- # threads.append(thread)
- # thread.start()
- #
- # for thread in threads:
- # thread.join()
- #
- # return threading_results
-
- # Number of CPUs
- num_cpus = multiprocessing.cpu_count()
- print("num_cpus : ", num_cpus)
-
- # Create a manager to hold shared data
- # with multiprocessing.Manager() as manager:
- # glosses_dic_shared = manager.dict(glosses_dic)
- # with multiprocessing.Pool(num_cpus) as pool:
- # arguments = [(word, sentence) for word in output_list]
- # results = pool.starmap(disambiguate_glosses_main, arguments)
-
- with multiprocessing.Pool(initializer=init_resources) as pool:
- # Map the list of words to the disambiguation function in parallel
- results = pool.map(disambiguate_glosses_in_parallel, [(word, sentence) for word in output_list])
+
+ output_list = find_glosses(input_sentence, two_word_lemma, three_word_lemma, four_word_lemma, five_word_lemma, ner)
+ results = []
+ for word in output_list:
+ results.append(disambiguate_glosses_main(word, sentence))
  return results


@@ -570,8 +497,5 @@ def disambiguate(sentence):
  content = ["Input is too long"]
  return content
  else:
- start = time.time()
  results = WSD(sentence)
- end = time.time()
- print(f"WSD total time: {end - start}")
  return results
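
With the timing and multiprocessing code removed, disambiguate() now just validates the input length and runs WSD() sequentially over the tokenized sentence. An end-to-end usage sketch; the import path is an assumption, since this diff does not show the name of the module that defines disambiguate():

    # Hypothetical import path -- the diff does not name this module.
    from sinatools.wsd.disambiguator import disambiguate

    results = disambiguate("ذهب الطالب إلى الجامعة")
    for entry in results:
        # each entry is expected to be a dict with keys such as 'word' and 'concept_count'
        print(entry)
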