SinaTools 0.1.24__py2.py3-none-any.whl → 0.1.26__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: SinaTools
- Version: 0.1.24
+ Version: 0.1.26
  Summary: Open-source Python toolkit for Arabic Natural Understanding, allowing people to integrate it in their system workflow.
  Home-page: https://github.com/SinaLab/sinatools
  License: MIT license
@@ -1,14 +1,14 @@
- SinaTools-0.1.24.data/data/sinatools/environment.yml,sha256=OzilhLjZbo_3nU93EQNUFX-6G5O3newiSWrwxvMH2Os,7231
- sinatools/VERSION,sha256=S6iCAVLWhyRA7MIGZk5mjvtI6v6w1_bWDqhs5ui2fDk,6
+ SinaTools-0.1.26.data/data/sinatools/environment.yml,sha256=OzilhLjZbo_3nU93EQNUFX-6G5O3newiSWrwxvMH2Os,7231
+ sinatools/VERSION,sha256=5E6i4X07Go6cKsVD3uEZkX9jXfyE05s7HlzVXSisTX8,6
  sinatools/__init__.py,sha256=bEosTU1o-FSpyytS6iVP_82BXHF2yHnzpJxPLYRbeII,135
  sinatools/environment.yml,sha256=OzilhLjZbo_3nU93EQNUFX-6G5O3newiSWrwxvMH2Os,7231
  sinatools/install_env.py,sha256=EODeeE0ZzfM_rz33_JSIruX03Nc4ghyVOM5BHVhsZaQ,404
  sinatools/sinatools.py,sha256=vR5AaF0iel21LvsdcqwheoBz0SIj9K9I_Ub8M8oA98Y,20
- sinatools/CLI/DataDownload/download_files.py,sha256=KG9W-Y5kJG_9yLUyo-cA33B5uO3avdZ5sSYUeW3wM6s,1960
+ sinatools/CLI/DataDownload/download_files.py,sha256=VunXU_vAweKs7aS0FNM84N_2lhYT5T94Y8B3NWmGksg,2630
  sinatools/CLI/morphology/ALMA_multi_word.py,sha256=ZImJ1vtcpSHydI1BjJmK3KcMJbGBZX16kO4L6rxvBvA,2086
  sinatools/CLI/morphology/morph_analyzer.py,sha256=ieIM47QK9Nct3MtCS9uq3h2rZN5r4qNhsLmlVeE6wiE,3503
- sinatools/CLI/ner/corpus_entity_extractor.py,sha256=_o0frMSgpsFVXPoztS3mQTK7LjHsgzUv9gfs6iJL424,4024
- sinatools/CLI/ner/entity_extractor.py,sha256=zn0Jd37BEDE1wHE5HOAK0_N2tURAznFNj7WDd6WGLIw,2932
+ sinatools/CLI/ner/corpus_entity_extractor.py,sha256=Da-DHFrqT6if7w6WnodB4TBE5ze3DJYjb2Mmju_Qd7g,4034
+ sinatools/CLI/ner/entity_extractor.py,sha256=IiTioe0px0aJ1E58FrDVa2yNgM8Ie4uS2LZKK_z2Qn4,2942
  sinatools/CLI/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sinatools/CLI/utils/arStrip.py,sha256=NLyp8vOu2xv80tL9jiKRvyptmbkRZVg-wcAr-9YyvNY,3264
  sinatools/CLI/utils/corpus_tokenizer.py,sha256=nH0T4h6urr_0Qy6-wN3PquOtnwybj0REde5Ts_OE4U8,1650
@@ -20,7 +20,7 @@ sinatools/CLI/utils/sentence_tokenizer.py,sha256=Wli8eiDbWSd_Z8UKpu_JkaS8jImowa1
  sinatools/CLI/utils/text_dublication_detector.py,sha256=dW70O5O20GxeUDDF6zVYn52wWLmJF-HBZgvqIeVL2rQ,1661
  sinatools/CLI/utils/text_transliteration.py,sha256=vz-3kxWf8pNYVCqNAtBAiA6u_efrS5NtWT-ofN1NX6I,2014
  sinatools/DataDownload/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sinatools/DataDownload/downloader.py,sha256=sLmVvnr3mG_tqvGCggzxwsi2sixlKlgCbMnZhCclSpg,6390
+ sinatools/DataDownload/downloader.py,sha256=F-SV-0mbYMYFSNCx8FoAYXhn0X1j0dF37PTLU0nUBVg,6482
  sinatools/arabert/__init__.py,sha256=ely2PttjgSv7vKdzskuD1rtK_l_UOpmxJSz8isrveD0,16
  sinatools/arabert/preprocess.py,sha256=qI0FsuMTOzdRlYGCtLrjpXgikNElUZPv9bnjaKDZKJ4,33024
  sinatools/arabert/arabert/__init__.py,sha256=KbSAH-XqbRygn0y59m5-ZYOLXgpT1gSgE3F-qd4rKEc,627
@@ -80,7 +80,7 @@ sinatools/ner/__init__.py,sha256=gSs0x6veWJ8j3_iOs79tynBd_hJP0t44CGpJ0xzoiW4,104
  sinatools/ner/data.py,sha256=lvOW86dXse8SC75Q0supQaE0rrRffoxNjIA0Qbv5WZY,4354
  sinatools/ner/data_format.py,sha256=7Yt0aOicOn9_YuuyCkM_IYi_rgjGYxR9bCuUaNGM73o,4341
  sinatools/ner/datasets.py,sha256=mG1iwqSm3lXCFHLqE-b4wNi176cpuzNBz8tKaBU6z6M,5059
- sinatools/ner/entity_extractor.py,sha256=k0Yvvg_aknINkFSdqOgG1KulS0UIo-W0qycv9J2MtNo,2273
+ sinatools/ner/entity_extractor.py,sha256=yQnfayT03qAnQ4FBdBFhvl8M2pgIttrdWSWE9wgO2LI,1876
  sinatools/ner/helpers.py,sha256=dnOoDY5JMyOLTUWVIZLMt8mBn2IbWlVaqHhQyjs1voo,2343
  sinatools/ner/metrics.py,sha256=Irz6SsIvpOzGIA2lWxrEV86xnTnm0TzKm9SUVT4SXUU,2734
  sinatools/ner/transforms.py,sha256=vti3mDdi-IRP8i0aTQ37QqpPlP9hdMmJ6_bAMa0uL-s,4871
@@ -91,7 +91,7 @@ sinatools/ner/nn/BaseModel.py,sha256=3GmujQasTZZunOBuFXpY2p1W8W256iI_Uu4hxhOY2Z0
  sinatools/ner/nn/BertNestedTagger.py,sha256=_fwAn1kiKmXe6m5y16Ipty3kvXIEFEmiUq74Ad1818U,1219
  sinatools/ner/nn/BertSeqTagger.py,sha256=dFcBBiMw2QCWsyy7aQDe_PS3aRuNn4DOxKIHgTblFvc,504
  sinatools/ner/nn/__init__.py,sha256=UgQD_XLNzQGBNSYc_Bw1aRJZjq4PJsnMT1iZwnJemqE,170
- sinatools/ner/trainers/BaseTrainer.py,sha256=oZgFJW-CawfCKT5gtaBHA7Q7XjNfiyqM62KnFsgVzPU,3919
+ sinatools/ner/trainers/BaseTrainer.py,sha256=Ifz4SeTxJwVn1_uWZ3I9KbcSo2hLPN3ojsIYuoKE9wE,4050
  sinatools/ner/trainers/BertNestedTrainer.py,sha256=Pb4O2WeBmTvV3hHMT6DXjxrTzgtuh3OrKQZnogYy8RQ,8429
  sinatools/ner/trainers/BertTrainer.py,sha256=B_uVtUwfv_eFwMMPsKQvZgW_ZNLy6XEsX5ePR0s8d-k,6433
  sinatools/ner/trainers/__init__.py,sha256=UDok8pDDpYOpwRBBKVLKaOgSUlmqqb-zHZI1p0xPxzI,188
@@ -110,13 +110,13 @@ sinatools/utils/text_transliteration.py,sha256=NQoXrxI-h0UXnvVtDA3skNJduxIy0IW26
  sinatools/utils/tokenizer.py,sha256=QHyrVqJA_On4rKxexiWR2ovq4pI1-u6iZkdhRbK9tew,6676
  sinatools/utils/tokenizers_words.py,sha256=efNfOil9qDNVJ9yynk_8sqf65PsL-xtsHG7y2SZCkjQ,656
  sinatools/wsd/__init__.py,sha256=yV-SQSCzSrjbNkciMbDCqzGZ_EESchL7rlJk56uibVI,309
- sinatools/wsd/disambiguator.py,sha256=8HrVAGpEQyrzwiuEreLX9X82WSL-U2Aeca0ttrtIw2Y,19998
+ sinatools/wsd/disambiguator.py,sha256=43Iq7NTZsiYWGFg-NUDrQuJKO1NT9QOnfBPB10IOJNs,19828
  sinatools/wsd/settings.py,sha256=6XflVTFKD8SVySX9Wj7zYQtV26WDTcQ2-uW8-gDNHKE,747
  sinatools/wsd/wsd.py,sha256=gHIBUFXegoY1z3rRnIlK6TduhYq2BTa_dHakOjOlT4k,4434
- SinaTools-0.1.24.dist-info/AUTHORS.rst,sha256=aTWeWlIdfLi56iLJfIUAwIrmqDcgxXKLji75_Fjzjyg,174
- SinaTools-0.1.24.dist-info/LICENSE,sha256=uwsKYG4TayHXNANWdpfMN2lVW4dimxQjA_7vuCVhD70,1088
- SinaTools-0.1.24.dist-info/METADATA,sha256=TS_IfzeMqZsoClo4KPnnhsTHbuo8sWNBXB2ByHkrY_M,953
- SinaTools-0.1.24.dist-info/WHEEL,sha256=6T3TYZE4YFi2HTS1BeZHNXAi8N52OZT4O-dJ6-ome_4,116
- SinaTools-0.1.24.dist-info/entry_points.txt,sha256=ZwZLolnWog2fjdDrfaHNHob8SE_YtMbD6ayzsOzItxs,1234
- SinaTools-0.1.24.dist-info/top_level.txt,sha256=8tNdPTeJKw3TQCaua8IJIx6N6WpgZZmVekf1OdBNJpE,10
- SinaTools-0.1.24.dist-info/RECORD,,
+ SinaTools-0.1.26.dist-info/AUTHORS.rst,sha256=aTWeWlIdfLi56iLJfIUAwIrmqDcgxXKLji75_Fjzjyg,174
+ SinaTools-0.1.26.dist-info/LICENSE,sha256=uwsKYG4TayHXNANWdpfMN2lVW4dimxQjA_7vuCVhD70,1088
+ SinaTools-0.1.26.dist-info/METADATA,sha256=jqsARSXI1Z0hT9-ev6ewzZeNH_H350lv_c2oav_SKWg,953
+ SinaTools-0.1.26.dist-info/WHEEL,sha256=6T3TYZE4YFi2HTS1BeZHNXAi8N52OZT4O-dJ6-ome_4,116
+ SinaTools-0.1.26.dist-info/entry_points.txt,sha256=ZwZLolnWog2fjdDrfaHNHob8SE_YtMbD6ayzsOzItxs,1234
+ SinaTools-0.1.26.dist-info/top_level.txt,sha256=8tNdPTeJKw3TQCaua8IJIx6N6WpgZZmVekf1OdBNJpE,10
+ SinaTools-0.1.26.dist-info/RECORD,,
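
For readers unfamiliar with the RECORD format above: each entry is the file path, the SHA-256 digest encoded as unpadded URL-safe base64, and the size in bytes. A minimal sketch for checking one of the listed entries against an unpacked wheel (the path and expected values are taken from the listing; nothing here is part of SinaTools itself):

    import base64
    import hashlib

    def record_digest(path):
        # RECORD stores sha256 digests as unpadded URL-safe base64, plus the byte size.
        with open(path, "rb") as fh:
            data = fh.read()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
        return digest.decode("ascii"), len(data)

    # Expected per the 0.1.26 RECORD: ('5E6i4X07Go6cKsVD3uEZkX9jXfyE05s7HlzVXSisTX8', 6)
    print(record_digest("sinatools/VERSION"))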
@@ -40,7 +40,7 @@ from sinatools.DataDownload.downloader import urls

  def main():
  parser = argparse.ArgumentParser(description="Download files from specified URLs.")
- parser.add_argument('-f', '--files', nargs="*", choices=urls.keys(),
+ parser.add_argument('-f', '--files', nargs="*",
  help="Names of the files to download. Available files are: "
  f"{', '.join(urls.keys())}. If no file is specified, all files will be downloaded.")

@@ -50,8 +50,23 @@ def main():

  if args.files:
  for file in args.files:
- url = urls[file]
- download_file(url)
+ print("file: ", file)
+ if file == "wsd":
+ download_file(urls["morph"])
+ download_file(urls["ner"])
+ download_file(urls["wsd_model"])
+ download_file(urls["wsd_tokenizer"])
+ download_file(urls["glosses_dic"])
+ download_file(urls["five_grams"])
+ download_file(urls["four_grams"])
+ download_file(urls["three_grams"])
+ download_file(urls["two_grams"])
+ elif file == "synonyms":
+ download_file(urls["synonyms_level2"])
+ download_file(urls["synonyms_level3"])
+ else:
+ url = urls[file]
+ download_file(url)
  else:
  download_files()

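The hunk above makes "wsd" and "synonyms" act as bundle names that expand to several downloads. The same grouping can be read as a lookup table; the sketch below is illustrative only (the bundle names and url keys come from the diff, while the helper name download_bundle is hypothetical and not part of the package):

    from sinatools.DataDownload.downloader import download_file, urls

    BUNDLES = {
        "wsd": ["morph", "ner", "wsd_model", "wsd_tokenizer", "glosses_dic",
                "five_grams", "four_grams", "three_grams", "two_grams"],
        "synonyms": ["synonyms_level2", "synonyms_level3"],
    }

    def download_bundle(name):
        # Resolve a bundle to its member files, or fall back to a single file name.
        for key in BUNDLES.get(name, [name]):
            download_file(urls[key])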
@@ -20,7 +20,7 @@ def jsons_to_list_of_lists(json_list):
  return [[d['token'], d['tags']] for d in json_list]

  def combine_tags(sentence):
- output = jsons_to_list_of_lists(extract(sentence))
+ output = jsons_to_list_of_lists(extract(sentence, "nested"))
  return [word[1] for word in output]

@@ -46,7 +46,7 @@ def jsons_to_list_of_lists(json_list):
  return [[d['token'], d['tags']] for d in json_list]

  def combine_tags(sentence):
- output = jsons_to_list_of_lists(extract(sentence))
+ output = jsons_to_list_of_lists(extract(sentence, "nested"))
  return [word[1] for word in output]

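Both CLI NER scripts now pass an explicit "nested" method to extract, so combine_tags still receives one (possibly multi-label) tag string per token. A small illustration of what this reduction does, using made-up values shaped like the extract output in this diff:

    # Illustrative only; the sample tokens and tags are invented for the example.
    sample = [
        {"token": "ذهب", "tags": "O"},
        {"token": "محمد", "tags": "B-PERS"},
        {"token": "بيرزيت", "tags": "B-ORG I-GPE"},  # nested: several labels on one token
    ]
    pairs = [[d["token"], d["tags"]] for d in sample]  # jsons_to_list_of_lists
    tags_only = [word[1] for word in pairs]            # what combine_tags returns
    print(tags_only)  # ['O', 'B-PERS', 'B-ORG I-GPE']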
@@ -95,37 +95,41 @@ def download_file(url, dest_path=get_appdatadir()):
  print(filename)
  headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}

- # try:
- with requests.get(url, headers=headers, stream=True) as r:
- r.raise_for_status()
- with open(file_path, 'wb') as f:
- total_size = int(r.headers.get('content-length', 0))
- block_size = 8192
- progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
- for chunk in r.iter_content(chunk_size=block_size):
- if chunk:
- f.write(chunk)
- progress_bar.update(len(chunk))
- progress_bar.close()
- # Check the file type and extract accordingly
- file_extension = os.path.splitext(file_path)[1]
- extracted_folder_name = os.path.splitext(file_path)[0]
-
- if file_extension == '.zip':
- extract_zip(file_path, extracted_folder_name)
- elif file_extension == '.gz':
- extract_tar(file_path, extracted_folder_name)
- elif file_extension =='.pickle':
- print(f'Done: {file_extension}')
- else:
- print(f'Unsupported file type for extraction: {file_extension}')
- return file_path
-
- # except requests.exceptions.HTTPError as e:
- # if e.response.status_code == 403:
- # print(f'Error 403: Forbidden. The requested file URL {url} could not be downloaded due to insufficient permissions. Please check the URL and try again.')
- # else:
- # print('An error occurred while downloading the file:', e)
+ try:
+ with requests.get(url, headers=headers, stream=True) as r:
+ r.raise_for_status()
+ with open(file_path, 'wb') as f:
+ total_size = int(r.headers.get('content-length', 0))
+ block_size = 8192
+ progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
+ for chunk in r.iter_content(chunk_size=block_size):
+ if chunk:
+ f.write(chunk)
+ progress_bar.update(len(chunk))
+ progress_bar.close()
+
+ # Check the file type and extract accordingly
+ file_extension = os.path.splitext(file_path)[1]
+ extracted_folder_name = os.path.splitext(file_path)[0]
+
+ if file_extension == '.zip':
+ extract_zip(file_path, extracted_folder_name)
+ elif file_extension == '.gz':
+
+ extract_tar(file_path, extracted_folder_name)
+ elif file_extension =='.pickle':
+ print(f'Done: {file_extension}')
+
+ else:
+ print(f'Unsupported file type for extraction: {file_extension}')
+
+ return file_path
+
+ except requests.exceptions.HTTPError as e:
+ if e.response.status_code == 403:
+ print(f'Error 403: Forbidden. The requested file URL {url} could not be downloaded due to insufficient permissions. Please check the URL and try again.')
+ else:
+ print('An error occurred while downloading the file:', e)

  def extract_zip(file_path, extracted_folder_name):
  """
sinatools/VERSION CHANGED
@@ -1 +1 @@
- 0.1.24
+ 0.1.26
@@ -3,43 +3,31 @@ from collections import namedtuple
  from sinatools.ner.data_format import get_dataloaders, text2segments
  from . import tagger, tag_vocab, train_config

- def extract(text, batch_size=32):
- """
- This method processes an input text and returns named entites for each token within the text, based on the specified batch size. As follows:

- Args:
- text (:obj:`str`): The Arabic text to be tagged.
- batch_size (int, optional): Batch size for inference. Default is 32.
-
- Returns:
- list (:obj:`list`): A list of JSON objects, where each JSON could be contains:
- token: The token from the original text.
- NER tag: The label pairs for each segment.
-
- **Example:**
-
- .. highlight:: python
- .. code-block:: python
+ def convert_nested_to_flat(nested_tags):
+ flat_tags = []
+
+ for entry in nested_tags:
+ word = entry['token']
+ tags = entry['tags'].split()
+
+ # Initialize with the first tag in the sequence
+ flat_tag = tags[0]
+
+ for tag in tags[1:]:
+ # Check if the tag is an "I-" tag, indicating continuation of an entity
+ if tag.startswith('I-'):
+ flat_tag = tag
+ break
+
+ flat_tags.append({
+ 'token': word,
+ 'tags': flat_tag
+ })
+
+ return flat_tags

- from sinatools.ner.entity_extractor import extract
- extract('ذهب محمد إلى جامعة بيرزيت')
- [{
- "word":"ذهب",
- "tags":"O"
- },{
- "word":"محمد",
- "tags":"B-PERS"
- },{
- "word":"إلى",
- "tags":"O"
- },{
- "word":"جامعة",
- "tags":"B-ORG"
- },{
- "word":"بيرزيت",
- "tags":"B-GPE I-ORG"
- }]
- """
+ def extract(text, ner_method):

  dataset, token_vocab = text2segments(text)

@@ -50,7 +38,7 @@ def extract(text, batch_size=32):
  (dataset,),
  vocab,
  train_config.data_config,
- batch_size=batch_size,
+ batch_size=32,
  shuffle=(False,),
  )[0]

@@ -69,4 +57,7 @@
  else:
  segments_list["tags"] = ' '.join(list_of_tags)
  segments_lists.append(segments_list)
+
+ if ner_method == "flat":
+ segments_lists = convert_nested_to_flat(segments_lists)
  return segments_lists
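
The extract signature changes from extract(text, batch_size=32) to extract(text, ner_method): batch size is now fixed at 32, "nested" returns the space-separated multi-label tags, and "flat" post-processes them with convert_nested_to_flat. A usage sketch, assuming the NER model data has been downloaded (the example sentence mirrors the docstring removed above):

    from sinatools.ner.entity_extractor import extract

    nested = extract('ذهب محمد إلى جامعة بيرزيت', "nested")
    # each item looks like {"token": ..., "tags": "B-GPE I-ORG"} (possibly several labels)

    flat = extract('ذهب محمد إلى جامعة بيرزيت', "flat")
    # convert_nested_to_flat keeps one label per token: the first tag, unless a
    # later I- continuation tag is present, in which case that tag wins.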
@@ -1,117 +1,117 @@
  import os
  import torch
  import logging
  import natsort
  import glob

  logger = logging.getLogger(__name__)


  class BaseTrainer:
  def __init__(
  self,
  model=None,
  max_epochs=50,
  optimizer=None,
  scheduler=None,
  loss=None,
  train_dataloader=None,
  val_dataloader=None,
  test_dataloader=None,
  log_interval=10,
  summary_writer=None,
  output_path=None,
  clip=5,
  patience=5
  ):
  self.model = model
  self.max_epochs = max_epochs
  self.train_dataloader = train_dataloader
  self.val_dataloader = val_dataloader
  self.test_dataloader = test_dataloader
  self.optimizer = optimizer
  self.scheduler = scheduler
  self.loss = loss
  self.log_interval = log_interval
  self.summary_writer = summary_writer
  self.output_path = output_path
  self.current_timestep = 0
  self.current_epoch = 0
  self.clip = clip
  self.patience = patience

  def tag(self, dataloader, is_train=True):
  """
  Given a dataloader containing segments, predict the tags
  :param dataloader: torch.utils.data.DataLoader
  :param is_train: boolean - True for training model, False for evaluation
  :return: Iterator
  subwords (B x T x NUM_LABELS)- torch.Tensor - BERT subword ID
  gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tags IDs
  tokens - List[arabiner.data.dataset.Token] - list of tokens
  valid_len (B x 1) - int - valiud length of each sequence
  logits (B x T x NUM_LABELS) - logits for each token and each tag
  """
  for subwords, gold_tags, tokens, valid_len in dataloader:
  self.model.train(is_train)

  if torch.cuda.is_available():
  subwords = subwords.cuda()
  gold_tags = gold_tags.cuda()

  if is_train:
  self.optimizer.zero_grad()
  logits = self.model(subwords)
  else:
  with torch.no_grad():
  logits = self.model(subwords)

  yield subwords, gold_tags, tokens, valid_len, logits

  def segments_to_file(self, segments, filename):
  """
  Write segments to file
  :param segments: [List[arabiner.data.dataset.Token]] - list of list of tokens
  :param filename: str - output filename
  :return: None
  """
  with open(filename, "w") as fh:
  results = "\n\n".join(["\n".join([t.__str__() for t in segment]) for segment in segments])
  fh.write("Token\tGold Tag\tPredicted Tag\n")
  fh.write(results)
  logging.info("Predictions written to %s", filename)

  def save(self):
  """
  Save model checkpoint
  :return:
  """
  filename = os.path.join(
  self.output_path,
  "checkpoints",
  "checkpoint_{}.pt".format(self.current_epoch),
  )

  checkpoint = {
  "model": self.model.state_dict(),
  "optimizer": self.optimizer.state_dict(),
  "epoch": self.current_epoch
  }

  logger.info("Saving checkpoint to %s", filename)
  torch.save(checkpoint, filename)

  def load(self, checkpoint_path):
  """
  Load model checkpoint
  :param checkpoint_path: str - path/to/checkpoints
  :return: None
  """
  checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
  checkpoint_path = checkpoint_path[-1]

  logger.info("Loading checkpoint %s", checkpoint_path)

  device = None if torch.cuda.is_available() else torch.device('cpu')
  checkpoint = torch.load(checkpoint_path, map_location=device)
- self.model.load_state_dict(checkpoint["model"])
+ self.model.load_state_dict(checkpoint["model"], strict=False)
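
The only functional change in BaseTrainer is strict=False on load_state_dict: missing or unexpected state-dict keys are now tolerated instead of raising. A toy PyTorch illustration of that standard behavior (not SinaTools code):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    state = {"weight": torch.zeros(2, 4)}            # "bias" left out on purpose
    result = model.load_state_dict(state, strict=False)
    print(result.missing_keys)                       # ['bias'] instead of a RuntimeError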
@@ -217,7 +217,7 @@ def jsons_to_list_of_lists(json_list):
  def find_named_entities(string):
  found_entities = []

- ner_entites = extract(string)
+ ner_entites = extract(string, "nested")
  list_of_entites = jsons_to_list_of_lists(ner_entites)
  entites = distill_entities(list_of_entites)

@@ -288,17 +288,17 @@ def disambiguate_glosses_using_SALMA(glosses, Diac_lemma, Undiac_lemma, word, se
  concept_id, gloss = GlossPredictor(Diac_lemma, Undiac_lemma,word,sentence,glosses_dictionary)

  my_json = {}
- my_json['Concept_id'] = concept_id
+ my_json['concept_id'] = concept_id
  # my_json['Gloss'] = gloss
  my_json['word'] = word
- my_json['Undiac_lemma'] = Undiac_lemma
- my_json['Diac_lemma'] = Diac_lemma
+ #my_json['Undiac_lemma'] = Undiac_lemma
+ my_json['lemma'] = Diac_lemma
  return my_json
  else:
  my_json = {}
  my_json['word'] = word
- my_json['Undiac_lemma'] = Undiac_lemma
- my_json['Diac_lemma'] = Diac_lemma
+ #my_json['Undiac_lemma'] = Undiac_lemma
+ my_json['lemma'] = Diac_lemma
  return my_json


@@ -405,26 +405,26 @@ def disambiguate_glosses_main(word, sentence):
  if concept_count == 0:
  my_json = {}
  my_json['word'] = word['word']
- my_json['Diac_lemma'] = word['Diac_lemma']
- my_json['Undiac_lemma'] = word['Undiac_lemma']
+ my_json['lemma'] = word['Diac_lemma']
+ #my_json['Undiac_lemma'] = word['Undiac_lemma']
  return my_json
  elif concept_count == 1:
  my_json = {}
  my_json['word'] = word['word']
  glosses = word['glosses'][0]
  # my_json['Gloss'] = glosses['gloss']
- my_json['Concept_id'] = glosses['concept_id']
- my_json['Diac_lemma'] = word['Diac_lemma']
- my_json['Undiac_lemma'] = word['Undiac_lemma']
+ my_json['concept_id'] = glosses['concept_id']
+ my_json['lemma'] = word['Diac_lemma']
+ #my_json['Undiac_lemma'] = word['Undiac_lemma']
  return my_json
  elif concept_count == '*':
  my_json = {}
  my_json['word'] = word['word']
  glosses = word['glosses'][0]
  my_json['Gloss'] = glosses['gloss']
- my_json['Concept_id'] = glosses['concept_id']
- my_json['Diac_lemma'] = word['Diac_lemma']
- my_json['Undiac_lemma'] = word['Undiac_lemma']
+ my_json['concept_id'] = glosses['concept_id']
+ my_json['lemma'] = word['Diac_lemma']
+ #my_json['Undiac_lemma'] = word['Undiac_lemma']
  return my_json
  else:
  input_word = word['word']
@@ -477,21 +477,18 @@ def disambiguate(sentence):
  #output
  [
  {
- "Concept_id": "303019218",
+ "concept_id": "303019218",
  "word": "ذهبت",
- "Undiac_lemma": "ذهب",
- "Diac_lemma": "ذَهَبَ۪ 1"
+ "lemma": "ذَهَبَ۪ 1"
  },
  {
  "word": "إلى",
- "Diac_lemma": إِلَى 1,
- "Undiac_lemma": "الى"
+ "lemma": "إِلَى 1"
  },
  {
  "word": "جامعة بيرزيت",
- "Concept_id": "334000099",
- "Diac_lemma": جامِعَة بيرزَيت,
- "Undiac_lemma": "جامعة بيرزيت"
+ "concept_id": "334000099",
+ "lemma": "جامِعَة بيرزَيت"
  }
  ]
  """