SinaTools 0.1.4__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. {SinaTools-0.1.4.dist-info → SinaTools-0.1.7.dist-info}/METADATA +14 -20
  2. SinaTools-0.1.7.dist-info/RECORD +101 -0
  3. SinaTools-0.1.7.dist-info/entry_points.txt +18 -0
  4. SinaTools-0.1.7.dist-info/top_level.txt +1 -0
  5. {nlptools → sinatools}/CLI/DataDownload/download_files.py +9 -9
  6. {nlptools → sinatools}/CLI/morphology/ALMA_multi_word.py +10 -20
  7. sinatools/CLI/morphology/morph_analyzer.py +80 -0
  8. nlptools/CLI/arabiner/bin/infer2.py → sinatools/CLI/ner/corpus_entity_extractor.py +5 -9
  9. nlptools/CLI/arabiner/bin/infer.py → sinatools/CLI/ner/entity_extractor.py +4 -8
  10. {nlptools → sinatools}/CLI/salma/salma_tools.py +8 -8
  11. {nlptools → sinatools}/CLI/utils/arStrip.py +10 -21
  12. sinatools/CLI/utils/corpus_tokenizer.py +50 -0
  13. {nlptools → sinatools}/CLI/utils/implication.py +9 -9
  14. {nlptools → sinatools}/CLI/utils/jaccard.py +10 -10
  15. sinatools/CLI/utils/remove_latin.py +34 -0
  16. sinatools/CLI/utils/remove_punctuation.py +42 -0
  17. {nlptools → sinatools}/CLI/utils/sentence_tokenizer.py +9 -22
  18. {nlptools → sinatools}/CLI/utils/text_transliteration.py +10 -17
  19. {nlptools → sinatools}/DataDownload/downloader.py +9 -9
  20. sinatools/VERSION +1 -0
  21. {nlptools → sinatools}/__init__.py +1 -1
  22. {nlptools → sinatools}/morphology/ALMA_multi_word.py +4 -5
  23. {nlptools → sinatools}/morphology/__init__.py +4 -14
  24. sinatools/morphology/morph_analyzer.py +172 -0
  25. sinatools/ner/__init__.py +12 -0
  26. nlptools/arabiner/bin/infer.py → sinatools/ner/entity_extractor.py +9 -8
  27. {nlptools → sinatools}/salma/__init__.py +2 -2
  28. {nlptools → sinatools}/salma/settings.py +1 -1
  29. {nlptools → sinatools}/salma/views.py +9 -9
  30. {nlptools → sinatools}/salma/wsd.py +2 -2
  31. {nlptools/morphology → sinatools/utils}/charsets.py +1 -3
  32. {nlptools → sinatools}/utils/implication.py +10 -10
  33. {nlptools → sinatools}/utils/jaccard.py +2 -2
  34. {nlptools → sinatools}/utils/parser.py +18 -21
  35. {nlptools → sinatools}/utils/text_transliteration.py +1 -1
  36. nlptools/utils/corpus_tokenizer.py → sinatools/utils/tokenizer.py +58 -5
  37. {nlptools/morphology → sinatools/utils}/tokenizers_words.py +3 -6
  38. SinaTools-0.1.4.dist-info/RECORD +0 -122
  39. SinaTools-0.1.4.dist-info/entry_points.txt +0 -18
  40. SinaTools-0.1.4.dist-info/top_level.txt +0 -1
  41. nlptools/CLI/morphology/morph_analyzer.py +0 -91
  42. nlptools/CLI/utils/corpus_tokenizer.py +0 -74
  43. nlptools/CLI/utils/latin_remove.py +0 -51
  44. nlptools/CLI/utils/remove_Punc.py +0 -53
  45. nlptools/VERSION +0 -1
  46. nlptools/arabiner/bin/__init__.py +0 -14
  47. nlptools/arabiner/bin/eval.py +0 -87
  48. nlptools/arabiner/bin/process.py +0 -140
  49. nlptools/arabiner/bin/train.py +0 -221
  50. nlptools/arabiner/data/__init__.py +0 -1
  51. nlptools/arabiner/data/datasets.py +0 -146
  52. nlptools/arabiner/data/transforms.py +0 -118
  53. nlptools/arabiner/nn/BaseModel.py +0 -22
  54. nlptools/arabiner/nn/BertNestedTagger.py +0 -34
  55. nlptools/arabiner/nn/BertSeqTagger.py +0 -17
  56. nlptools/arabiner/nn/__init__.py +0 -3
  57. nlptools/arabiner/trainers/BaseTrainer.py +0 -117
  58. nlptools/arabiner/trainers/BertNestedTrainer.py +0 -203
  59. nlptools/arabiner/trainers/BertTrainer.py +0 -163
  60. nlptools/arabiner/trainers/__init__.py +0 -3
  61. nlptools/arabiner/utils/__init__.py +0 -0
  62. nlptools/arabiner/utils/data.py +0 -124
  63. nlptools/arabiner/utils/helpers.py +0 -151
  64. nlptools/arabiner/utils/metrics.py +0 -69
  65. nlptools/morphology/morph_analyzer.py +0 -171
  66. nlptools/morphology/settings.py +0 -8
  67. nlptools/utils/__init__.py +0 -0
  68. nlptools/utils/sentence_tokenizer.py +0 -53
  69. {SinaTools-0.1.4.data/data/nlptools → SinaTools-0.1.7.data/data/sinatools}/environment.yml +0 -0
  70. {SinaTools-0.1.4.dist-info → SinaTools-0.1.7.dist-info}/AUTHORS.rst +0 -0
  71. {SinaTools-0.1.4.dist-info → SinaTools-0.1.7.dist-info}/LICENSE +0 -0
  72. {SinaTools-0.1.4.dist-info → SinaTools-0.1.7.dist-info}/WHEEL +0 -0
  73. {nlptools → sinatools}/CLI/utils/__init__.py +0 -0
  74. {nlptools → sinatools}/DataDownload/__init__.py +0 -0
  75. {nlptools → sinatools}/arabert/__init__.py +0 -0
  76. {nlptools → sinatools}/arabert/arabert/__init__.py +0 -0
  77. {nlptools → sinatools}/arabert/arabert/create_classification_data.py +0 -0
  78. {nlptools → sinatools}/arabert/arabert/create_pretraining_data.py +0 -0
  79. {nlptools → sinatools}/arabert/arabert/extract_features.py +0 -0
  80. {nlptools → sinatools}/arabert/arabert/lamb_optimizer.py +0 -0
  81. {nlptools → sinatools}/arabert/arabert/modeling.py +0 -0
  82. {nlptools → sinatools}/arabert/arabert/optimization.py +0 -0
  83. {nlptools → sinatools}/arabert/arabert/run_classifier.py +0 -0
  84. {nlptools → sinatools}/arabert/arabert/run_pretraining.py +0 -0
  85. {nlptools → sinatools}/arabert/arabert/run_squad.py +0 -0
  86. {nlptools → sinatools}/arabert/arabert/tokenization.py +0 -0
  87. {nlptools → sinatools}/arabert/araelectra/__init__.py +0 -0
  88. {nlptools → sinatools}/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -0
  89. {nlptools → sinatools}/arabert/araelectra/build_pretraining_dataset.py +0 -0
  90. {nlptools → sinatools}/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -0
  91. {nlptools → sinatools}/arabert/araelectra/configure_finetuning.py +0 -0
  92. {nlptools → sinatools}/arabert/araelectra/configure_pretraining.py +0 -0
  93. {nlptools → sinatools}/arabert/araelectra/finetune/__init__.py +0 -0
  94. {nlptools → sinatools}/arabert/araelectra/finetune/feature_spec.py +0 -0
  95. {nlptools → sinatools}/arabert/araelectra/finetune/preprocessing.py +0 -0
  96. {nlptools → sinatools}/arabert/araelectra/finetune/scorer.py +0 -0
  97. {nlptools → sinatools}/arabert/araelectra/finetune/task.py +0 -0
  98. {nlptools → sinatools}/arabert/araelectra/finetune/task_builder.py +0 -0
  99. {nlptools → sinatools}/arabert/araelectra/flops_computation.py +0 -0
  100. {nlptools → sinatools}/arabert/araelectra/model/__init__.py +0 -0
  101. {nlptools → sinatools}/arabert/araelectra/model/modeling.py +0 -0
  102. {nlptools → sinatools}/arabert/araelectra/model/optimization.py +0 -0
  103. {nlptools → sinatools}/arabert/araelectra/model/tokenization.py +0 -0
  104. {nlptools → sinatools}/arabert/araelectra/pretrain/__init__.py +0 -0
  105. {nlptools → sinatools}/arabert/araelectra/pretrain/pretrain_data.py +0 -0
  106. {nlptools → sinatools}/arabert/araelectra/pretrain/pretrain_helpers.py +0 -0
  107. {nlptools → sinatools}/arabert/araelectra/run_finetuning.py +0 -0
  108. {nlptools → sinatools}/arabert/araelectra/run_pretraining.py +0 -0
  109. {nlptools → sinatools}/arabert/araelectra/util/__init__.py +0 -0
  110. {nlptools → sinatools}/arabert/araelectra/util/training_utils.py +0 -0
  111. {nlptools → sinatools}/arabert/araelectra/util/utils.py +0 -0
  112. {nlptools → sinatools}/arabert/aragpt2/__init__.py +0 -0
  113. {nlptools → sinatools}/arabert/aragpt2/create_pretraining_data.py +0 -0
  114. {nlptools → sinatools}/arabert/aragpt2/gpt2/__init__.py +0 -0
  115. {nlptools → sinatools}/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -0
  116. {nlptools → sinatools}/arabert/aragpt2/gpt2/optimization.py +0 -0
  117. {nlptools → sinatools}/arabert/aragpt2/gpt2/run_pretraining.py +0 -0
  118. {nlptools → sinatools}/arabert/aragpt2/grover/__init__.py +0 -0
  119. {nlptools → sinatools}/arabert/aragpt2/grover/dataloader.py +0 -0
  120. {nlptools → sinatools}/arabert/aragpt2/grover/modeling.py +0 -0
  121. {nlptools → sinatools}/arabert/aragpt2/grover/modeling_gpt2.py +0 -0
  122. {nlptools → sinatools}/arabert/aragpt2/grover/optimization_adafactor.py +0 -0
  123. {nlptools → sinatools}/arabert/aragpt2/grover/train_tpu.py +0 -0
  124. {nlptools → sinatools}/arabert/aragpt2/grover/utils.py +0 -0
  125. {nlptools → sinatools}/arabert/aragpt2/train_bpe_tokenizer.py +0 -0
  126. {nlptools → sinatools}/arabert/preprocess.py +0 -0
  127. {nlptools → sinatools}/environment.yml +0 -0
  128. {nlptools → sinatools}/install_env.py +0 -0
  129. /nlptools/nlptools.py → /sinatools/sinatools.py +0 -0
  130. {nlptools/arabiner → sinatools/utils}/__init__.py +0 -0
  131. {nlptools → sinatools}/utils/readfile.py +0 -0
  132. {nlptools → sinatools}/utils/utils.py +0 -0
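The headline change in this release is the rename of the top-level package from nlptools to sinatools, visible in nearly every move above, along with a few module reshuffles (for example, nlptools/utils/corpus_tokenizer.py becomes sinatools/utils/tokenizer.py, and nlptools/morphology/charsets.py moves to sinatools/utils/charsets.py). A minimal migration sketch follows; the new import path is inferred from the {nlptools → sinatools}/utils/parser.py move above and has not been checked against the 0.1.7 documentation, so treat it as an assumption:

    # 0.1.4 and earlier: helpers lived under the "nlptools" package
    from nlptools.utils.parser import remove_punctuation, remove_latin

    # 0.1.7: the same helpers appear to live under "sinatools" (path inferred from the file moves above)
    from sinatools.utils.parser import remove_punctuation, remove_latin

    print(remove_punctuation("te%s@t...!!?"))  # strip punctuation marks, per the sina_remove_punctuation example below
    print(remove_latin("123test"))             # drop Latin letters; digits and Arabic text are documented to pass through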
nlptools/CLI/utils/latin_remove.py DELETED
@@ -1,51 +0,0 @@
- """
- About:
- ------
- The sina_remove_latin tool performs delete latin characters from the input text.
-
- Usage:
- ------
- Below is the usage information that can be generated by running sina_remove_latin --help.
-
- .. code-block:: none
-
-     Usage:
-         sina_remove_latin --text=TEXT
-         sina_remove_latin --file "path/to/your/file.txt"
-
- Examples:
- ---------
-
- .. code-block:: none
-
-     sina_remove_punctuation --text "123test"
-
-     sina_remove_punctuation --file "path/to/your/file.txt"
-
- Note:
- -----
-
- .. code-block:: none
-
-     - This tool is specific to Arabic text, as it focuses on Arabic linguistic elements.
-     - Ensure that the text input is appropriately encoded in UTF-8 or compatible formats.
-     - This tool for latin characters, if the input text is an Arabic characters or numbers the output will be the same input
-
- """
-
- import argparse
- from nlptools.utils.parser import remove_latin
-
-
- def main():
-     parser = argparse.ArgumentParser(description='remove latin characters from the text')
-
-     parser.add_argument('--text', type=str, required=True, help='The input text')
-     args = parser.parse_args()
-     result = remove_latin(args.text)
-
-     print(result)
- if __name__ == '__main__':
-     main()
-
- #sina_remove_latin --text "123test"
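The Note section above pins down the intended behavior of the deleted wrapper: Latin letters are removed while Arabic characters and digits pass through unchanged, and the file list shows the CLI resurfacing as sinatools/CLI/utils/remove_latin.py (item 15). A minimal programmatic check of that contract, using the exact import the wrapper used; the expected outputs are inferred from the docstring rather than from running the function:

    from nlptools.utils.parser import remove_latin  # the 0.1.4 path shown in the diff above

    print(remove_latin("123test"))    # docstring suggests the Latin letters go and the digits stay
    print(remove_latin("مرحبا 456"))  # Arabic text and numbers are documented to come back unchanged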
nlptools/CLI/utils/remove_Punc.py DELETED
@@ -1,53 +0,0 @@
- """
- About:
- ------
- The sina_remove_punctuation tool performs delete punctuation marks from the input text.
-
- Usage:
- ------
- Below is the usage information that can be generated by running sina_remove_punctuation --help.
-
- .. code-block:: none
-
-     Usage:
-         sina_remove_punctuation --text=TEXT
-         sina_remove_punctuation --file "path/to/your/file.txt"
-
- Examples:
- ---------
-
- .. code-block:: none
-
-     sina_remove_punctuation --text "te%s@t...!!?"
-
-     sina_remove_punctuation --file "path/to/your/file.txt"
-
- Note:
- -----
-
- .. code-block:: none
-
-     - This tool is specific to Arabic text, as it focuses on Arabic linguistic elements.
-     - Ensure that the text input is appropriately encoded in UTF-8 or compatible formats.
- """
-
- import argparse
- from nlptools.utils.parser import remove_punctuation
- #from nlptools.utils.parser import read_file
- #from nlptools.utils.parser import write_file
-
-
- def main():
-     parser = argparse.ArgumentParser(description='remove punctuation marks from the text')
-
-     parser.add_argument('--text',required=True,help="input text")
-     # parser.add_argument('myFile', type=argparse.FileType('r'),help='Input file csv')
-     args = parser.parse_args()
-     result = remove_punctuation(args.text)
-
-     print(result)
- if __name__ == '__main__':
-     main()
-
- #sina_remove_punctuation --text "your text"
-
nlptools/VERSION DELETED
@@ -1 +0,0 @@
- 0.1.4
nlptools/arabiner/bin/__init__.py DELETED
@@ -1,14 +0,0 @@
- from nlptools.DataDownload import downloader
- import os
- from nlptools.arabiner.utils.helpers import load_checkpoint
- import nlptools
-
- nlptools.tagger = None
- nlptools.tag_vocab = None
- nlptools.train_config = None
-
- filename = 'Wj27012000.tar'
- path =downloader.get_appdatadir()
- model_path = os.path.join(path, filename)
- print('1',model_path)
- nlptools.tagger, nlptools.tag_vocab, nlptools.train_config = load_checkpoint(model_path)
nlptools/arabiner/bin/eval.py DELETED
@@ -1,87 +0,0 @@
- import os
- import logging
- import argparse
- from collections import namedtuple
- from nlptools.arabiner.utils.helpers import load_checkpoint, make_output_dirs, logging_config
- from nlptools.arabiner.utils.data import get_dataloaders, parse_conll_files
- from nlptools.arabiner.utils.metrics import compute_single_label_metrics, compute_nested_metrics
-
- logger = logging.getLogger(__name__)
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-     )
-
-     parser.add_argument(
-         "--output_path",
-         type=str,
-         required=True,
-         help="Path to save results",
-     )
-
-     parser.add_argument(
-         "--model_path",
-         type=str,
-         required=True,
-         help="Model path",
-     )
-
-     parser.add_argument(
-         "--data_paths",
-         nargs="+",
-         type=str,
-         required=True,
-         help="Text or sequence to tag, this is in same format as training data with 'O' tag for all tokens",
-     )
-
-     parser.add_argument(
-         "--batch_size",
-         type=int,
-         default=32,
-         help="Batch size",
-     )
-
-     args = parser.parse_args()
-
-     return args
-
-
- def main(args):
-     # Create directory to save predictions
-     make_output_dirs(args.output_path, overwrite=True)
-     logging_config(log_file=os.path.join(args.output_path, "eval.log"))
-
-     # Load tagger
-     tagger, tag_vocab, train_config = load_checkpoint(args.model_path)
-
-     # Convert text to a tagger dataset and index the tokens in args.text
-     datasets, vocab = parse_conll_files(args.data_paths)
-
-     vocabs = namedtuple("Vocab", ["tags", "tokens"])
-     vocab = vocabs(tokens=vocab.tokens, tags=tag_vocab)
-
-     # From the datasets generate the dataloaders
-     dataloaders = get_dataloaders(
-         datasets, vocab,
-         train_config.data_config,
-         batch_size=args.batch_size,
-         shuffle=[False] * len(datasets)
-     )
-
-     # Evaluate the model on each dataloader
-     for dataloader, input_file in zip(dataloaders, args.data_paths):
-         filename = os.path.basename(input_file)
-         predictions_file = os.path.join(args.output_path, f"predictions_{filename}")
-         _, segments, _, _ = tagger.eval(dataloader)
-         tagger.segments_to_file(segments, predictions_file)
-
-         if "Nested" in train_config.trainer_config["fn"]:
-             compute_nested_metrics(segments, vocab.tags[1:])
-         else:
-             compute_single_label_metrics(segments)
-
-
- if __name__ == "__main__":
-     main(parse_args())
nlptools/arabiner/bin/process.py DELETED
@@ -1,140 +0,0 @@
- import os
- import argparse
- import csv
- import logging
- import numpy as np
- from nlptools.arabiner.utils.helpers import logging_config
- from nlptools.arabiner.utils.data import conll_to_segments
-
- logger = logging.getLogger(__name__)
-
-
- def to_conll_format(input_files, output_path, multi_label=False):
-     """
-     Parse data files and convert them into CoNLL format
-     :param input_files: List[str] - list of filenames
-     :param output_path: str - output path
-     :param multi_label: boolean - True to process data with mutli-class/multi-label
-     :return:
-     """
-     for input_file in input_files:
-         tokens = list()
-         prev_sent_id = None
-
-         with open(input_file, "r") as fh:
-             r = csv.reader(fh, delimiter="\t", quotechar=" ")
-             next(r)
-
-             for row in r:
-                 sent_id, token, labels = row[1], row[3], row[4].split()
-                 valid_labels = sum([1 for l in labels if "-" in l or l == "O"]) == len(labels)
-
-                 if not valid_labels:
-                     logging.warning("Invalid labels found %s", str(row))
-                     continue
-                 if not labels:
-                     logging.warning("Token %s has no label", str(row))
-                     continue
-                 if not token:
-                     logging.warning("Token %s is missing", str(row))
-                     continue
-                 if len(token.split()) > 1:
-                     logging.warning("Token %s has multiple tokens", str(row))
-                     continue
-
-                 if prev_sent_id is not None and sent_id != prev_sent_id:
-                     tokens.append([])
-
-                 if multi_label:
-                     tokens.append([token] + labels)
-                 else:
-                     tokens.append([token, labels[0]])
-
-                 prev_sent_id = sent_id
-
-         num_segments = sum([1 for token in tokens if not token])
-         logging.info("Found %d segments and %d tokens in %s", num_segments + 1, len(tokens) - num_segments, input_file)
-
-         filename = os.path.basename(input_file)
-         output_file = os.path.join(output_path, filename)
-
-         with open(output_file, "w") as fh:
-             fh.write("\n".join(" ".join(token) for token in tokens))
-         logging.info("Output file %s", output_file)
-
-
- def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
-     segments = list()
-     filenames = ["train.txt", "val.txt", "test.txt"]
-
-     for input_file in input_files:
-         segments += conll_to_segments(input_file)
-
-     n = len(segments)
-     np.random.shuffle(segments)
-     datasets = np.split(segments, [int(train_ratio*n), int((train_ratio+dev_ratio)*n)])
-
-     # write data to files
-     for i in range(len(datasets)):
-         filename = os.path.join(output_path, filenames[i])
-
-         with open(filename, "w") as fh:
-             text = "\n\n".join(["\n".join([f"{token.text} {' '.join(token.gold_tag)}" for token in segment]) for segment in datasets[i]])
-             fh.write(text)
-             logging.info("Output file %s", filename)
-
-
- def main(args):
-     if args.task == "to_conll_format":
-         to_conll_format(args.input_files, args.output_path, multi_label=args.multi_label)
-     if args.task == "train_dev_test_split":
-         train_dev_test_split(args.input_files, args.output_path, args.train_ratio, args.dev_ratio)
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-     )
-
-     parser.add_argument(
-         "--input_files",
-         type=str,
-         nargs="+",
-         required=True,
-         help="List of input files",
-     )
-
-     parser.add_argument(
-         "--output_path",
-         type=str,
-         required=True,
-         help="Output path",
-     )
-
-     parser.add_argument(
-         "--train_ratio",
-         type=float,
-         required=False,
-         help="Training data ratio (percent of segments). Required with the task train_dev_test_split. "
-              "Files must in ConLL format",
-     )
-
-     parser.add_argument(
-         "--dev_ratio",
-         type=float,
-         required=False,
-         help="Dev/val data ratio (percent of segments). Required with the task train_dev_test_split. "
-              "Files must in ConLL format",
-     )
-
-     parser.add_argument(
-         "--task", required=True, choices=["to_conll_format", "train_dev_test_split"]
-     )
-
-     parser.add_argument(
-         "--multi_label", action='store_true'
-     )
-
-     args = parser.parse_args()
-     logging_config(os.path.join(args.output_path, "process.log"))
-     main(args)
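A quick sanity check of the split arithmetic in train_dev_test_split above: np.split cuts the shuffled segment list at int(train_ratio*n) and int((train_ratio+dev_ratio)*n), and whatever remains becomes the test set. The ratios below are illustrative, not defaults taken from the script:

    import numpy as np

    segments = np.arange(1000)              # stand-in for parsed CoNLL segments
    train_ratio, dev_ratio = 0.7, 0.1       # illustrative values; the script reads them from CLI flags
    n = len(segments)
    train, dev, test = np.split(segments, [int(train_ratio * n), int((train_ratio + dev_ratio) * n)])
    print(len(train), len(dev), len(test))  # 700 100 200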
nlptools/arabiner/bin/train.py DELETED
@@ -1,221 +0,0 @@
- import os
- import logging
- import json
- import argparse
- import torch.utils.tensorboard
- from torchvision import *
- import pickle
- from nlptools.arabiner.utils.data import get_dataloaders, parse_conll_files
- from nlptools.arabiner.utils.helpers import logging_config, load_object, make_output_dirs, set_seed
-
- logger = logging.getLogger(__name__)
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-     )
-
-     parser.add_argument(
-         "--output_path",
-         type=str,
-         required=True,
-         help="Output path",
-     )
-
-     parser.add_argument(
-         "--train_path",
-         type=str,
-         required=True,
-         help="Path to training data",
-     )
-
-     parser.add_argument(
-         "--val_path",
-         type=str,
-         required=True,
-         help="Path to training data",
-     )
-
-     parser.add_argument(
-         "--test_path",
-         type=str,
-         required=True,
-         help="Path to training data",
-     )
-
-     parser.add_argument(
-         "--bert_model",
-         type=str,
-         default="aubmindlab/bert-base-arabertv2",
-         help="BERT model",
-     )
-
-     parser.add_argument(
-         "--gpus",
-         type=int,
-         nargs="+",
-         default=[0],
-         help="GPU IDs to train on",
-     )
-
-     parser.add_argument(
-         "--log_interval",
-         type=int,
-         default=10,
-         help="Log results every that many timesteps",
-     )
-
-     parser.add_argument(
-         "--batch_size",
-         type=int,
-         default=32,
-         help="Batch size",
-     )
-
-     parser.add_argument(
-         "--num_workers",
-         type=int,
-         default=0,
-         help="Dataloader number of workers",
-     )
-
-     parser.add_argument(
-         "--data_config",
-         type=json.loads,
-         default='{"fn": "arabiner.data.datasets.DefaultDataset", "kwargs": {"max_seq_len": 512}}',
-         help="Dataset configurations",
-     )
-
-     parser.add_argument(
-         "--trainer_config",
-         type=json.loads,
-         default='{"fn": "arabiner.trainers.BertTrainer", "kwargs": {"max_epochs": 50}}',
-         help="Trainer configurations",
-     )
-
-     parser.add_argument(
-         "--network_config",
-         type=json.loads,
-         default='{"fn": "arabiner.nn.BertSeqTagger", "kwargs": '
-                 '{"dropout": 0.1, "bert_model": "aubmindlab/bert-base-arabertv2"}}',
-         help="Network configurations",
-     )
-
-     parser.add_argument(
-         "--optimizer",
-         type=json.loads,
-         default='{"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}}',
-         help="Optimizer configurations",
-     )
-
-     parser.add_argument(
-         "--lr_scheduler",
-         type=json.loads,
-         default='{"fn": "torch.optim.lr_scheduler.ExponentialLR", "kwargs": {"gamma": 1}}',
-         help="Learning rate scheduler configurations",
-     )
-
-     parser.add_argument(
-         "--loss",
-         type=json.loads,
-         default='{"fn": "torch.nn.CrossEntropyLoss", "kwargs": {}}',
-         help="Loss function configurations",
-     )
-
-     parser.add_argument(
-         "--overwrite",
-         action="store_true",
-         help="Overwrite output directory",
-     )
-
-     parser.add_argument(
-         "--seed",
-         type=int,
-         default=1,
-         help="Seed for random initialization",
-     )
-
-     args = parser.parse_args()
-
-     return args
-
-
- def main(args):
-     make_output_dirs(
-         args.output_path,
-         subdirs=("tensorboard", "checkpoints"),
-         overwrite=args.overwrite,
-     )
-
-     # Set the seed for randomization
-     set_seed(args.seed)
-
-     logging_config(os.path.join(args.output_path, "train.log"))
-     summary_writer = torch.utils.tensorboard.SummaryWriter(
-         os.path.join(args.output_path, "tensorboard")
-     )
-     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in args.gpus])
-
-     # Get the datasets and vocab for tags and tokens
-     datasets, vocab = parse_conll_files((args.train_path, args.val_path, args.test_path))
-
-     if "Nested" in args.network_config["fn"]:
-         args.network_config["kwargs"]["num_labels"] = [len(v) for v in vocab.tags[1:]]
-     else:
-         args.network_config["kwargs"]["num_labels"] = len(vocab.tags[0])
-
-     # Save tag vocab to desk
-     with open(os.path.join(args.output_path, "tag_vocab.pkl"), "wb") as fh:
-         pickle.dump(vocab.tags, fh)
-
-     # Write config to file
-     args_file = os.path.join(args.output_path, "args.json")
-     with open(args_file, "w") as fh:
-         logger.info("Writing config to %s", args_file)
-         json.dump(args.__dict__, fh, indent=4)
-
-     # From the datasets generate the dataloaders
-     args.data_config["kwargs"]["bert_model"] = args.network_config["kwargs"]["bert_model"]
-     train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
-         datasets, vocab, args.data_config, args.batch_size, args.num_workers
-     )
-
-     model = load_object(args.network_config["fn"], args.network_config["kwargs"])
-     model = torch.nn.DataParallel(model, device_ids=range(len(args.gpus)))
-
-     if torch.cuda.is_available():
-         model = model.cuda()
-
-     args.optimizer["kwargs"]["params"] = model.parameters()
-     optimizer = load_object(args.optimizer["fn"], args.optimizer["kwargs"])
-
-     args.lr_scheduler["kwargs"]["optimizer"] = optimizer
-     if "num_training_steps" in args.lr_scheduler["kwargs"]:
-         args.lr_scheduler["kwargs"]["num_training_steps"] = args.max_epochs * len(
-             train_dataloader
-         )
-
-     scheduler = load_object(args.lr_scheduler["fn"], args.lr_scheduler["kwargs"])
-     loss = load_object(args.loss["fn"], args.loss["kwargs"])
-
-     args.trainer_config["kwargs"].update({
-         "model": model,
-         "optimizer": optimizer,
-         "scheduler": scheduler,
-         "loss": loss,
-         "train_dataloader": train_dataloader,
-         "val_dataloader": val_dataloader,
-         "test_dataloader": test_dataloader,
-         "log_interval": args.log_interval,
-         "summary_writer": summary_writer,
-         "output_path": args.output_path
-     })
-
-     trainer = load_object(args.trainer_config["fn"], args.trainer_config["kwargs"])
-     trainer.train()
-     return
-
-
- if __name__ == "__main__":
-     main(parse_args())
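Almost everything in train.py is wired together through load_object(fn, kwargs), which receives a dotted path such as "torch.optim.AdamW" plus keyword arguments from the JSON config flags. The helper itself lives in the deleted nlptools/arabiner/utils/helpers.py (item 63), whose body is not shown in this diff; the sketch below is a plausible reconstruction of such a resolver, offered as an assumption rather than the actual implementation:

    import importlib
    import torch

    def load_object(name, kwargs):
        # Split "torch.optim.AdamW" into module path and attribute, import the module,
        # then instantiate the attribute with the supplied keyword arguments.
        module_name, attr_name = name.rsplit(".", 1)
        attr = getattr(importlib.import_module(module_name), attr_name)
        return attr(**kwargs)

    # Mirrors how train.py injects model.parameters() before building the optimizer.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = load_object("torch.optim.AdamW", {"lr": 0.0001, "params": params})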
nlptools/arabiner/data/__init__.py DELETED
@@ -1 +0,0 @@
- from nlptools.arabiner.data.datasets import NestedTagsDataset