SinaTools-0.1.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. SinaTools-0.1.1.data/data/nlptools/environment.yml +227 -0
  2. SinaTools-0.1.1.dist-info/AUTHORS.rst +13 -0
  3. SinaTools-0.1.1.dist-info/LICENSE +22 -0
  4. SinaTools-0.1.1.dist-info/METADATA +72 -0
  5. SinaTools-0.1.1.dist-info/RECORD +122 -0
  6. SinaTools-0.1.1.dist-info/WHEEL +6 -0
  7. SinaTools-0.1.1.dist-info/entry_points.txt +18 -0
  8. SinaTools-0.1.1.dist-info/top_level.txt +1 -0
  9. nlptools/CLI/DataDownload/download_files.py +71 -0
  10. nlptools/CLI/arabiner/bin/infer.py +117 -0
  11. nlptools/CLI/arabiner/bin/infer2.py +81 -0
  12. nlptools/CLI/morphology/ALMA_multi_word.py +75 -0
  13. nlptools/CLI/morphology/morph_analyzer.py +91 -0
  14. nlptools/CLI/salma/salma_tools.py +68 -0
  15. nlptools/CLI/utils/__init__.py +0 -0
  16. nlptools/CLI/utils/arStrip.py +99 -0
  17. nlptools/CLI/utils/corpus_tokenizer.py +74 -0
  18. nlptools/CLI/utils/implication.py +92 -0
  19. nlptools/CLI/utils/jaccard.py +96 -0
  20. nlptools/CLI/utils/latin_remove.py +51 -0
  21. nlptools/CLI/utils/remove_Punc.py +53 -0
  22. nlptools/CLI/utils/sentence_tokenizer.py +90 -0
  23. nlptools/CLI/utils/text_transliteration.py +77 -0
  24. nlptools/DataDownload/__init__.py +0 -0
  25. nlptools/DataDownload/downloader.py +185 -0
  26. nlptools/VERSION +1 -0
  27. nlptools/__init__.py +5 -0
  28. nlptools/arabert/__init__.py +1 -0
  29. nlptools/arabert/arabert/__init__.py +14 -0
  30. nlptools/arabert/arabert/create_classification_data.py +260 -0
  31. nlptools/arabert/arabert/create_pretraining_data.py +534 -0
  32. nlptools/arabert/arabert/extract_features.py +444 -0
  33. nlptools/arabert/arabert/lamb_optimizer.py +158 -0
  34. nlptools/arabert/arabert/modeling.py +1027 -0
  35. nlptools/arabert/arabert/optimization.py +202 -0
  36. nlptools/arabert/arabert/run_classifier.py +1078 -0
  37. nlptools/arabert/arabert/run_pretraining.py +593 -0
  38. nlptools/arabert/arabert/run_squad.py +1440 -0
  39. nlptools/arabert/arabert/tokenization.py +414 -0
  40. nlptools/arabert/araelectra/__init__.py +1 -0
  41. nlptools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +103 -0
  42. nlptools/arabert/araelectra/build_pretraining_dataset.py +230 -0
  43. nlptools/arabert/araelectra/build_pretraining_dataset_single_file.py +90 -0
  44. nlptools/arabert/araelectra/configure_finetuning.py +172 -0
  45. nlptools/arabert/araelectra/configure_pretraining.py +143 -0
  46. nlptools/arabert/araelectra/finetune/__init__.py +14 -0
  47. nlptools/arabert/araelectra/finetune/feature_spec.py +56 -0
  48. nlptools/arabert/araelectra/finetune/preprocessing.py +173 -0
  49. nlptools/arabert/araelectra/finetune/scorer.py +54 -0
  50. nlptools/arabert/araelectra/finetune/task.py +74 -0
  51. nlptools/arabert/araelectra/finetune/task_builder.py +70 -0
  52. nlptools/arabert/araelectra/flops_computation.py +215 -0
  53. nlptools/arabert/araelectra/model/__init__.py +14 -0
  54. nlptools/arabert/araelectra/model/modeling.py +1029 -0
  55. nlptools/arabert/araelectra/model/optimization.py +193 -0
  56. nlptools/arabert/araelectra/model/tokenization.py +355 -0
  57. nlptools/arabert/araelectra/pretrain/__init__.py +14 -0
  58. nlptools/arabert/araelectra/pretrain/pretrain_data.py +160 -0
  59. nlptools/arabert/araelectra/pretrain/pretrain_helpers.py +229 -0
  60. nlptools/arabert/araelectra/run_finetuning.py +323 -0
  61. nlptools/arabert/araelectra/run_pretraining.py +469 -0
  62. nlptools/arabert/araelectra/util/__init__.py +14 -0
  63. nlptools/arabert/araelectra/util/training_utils.py +112 -0
  64. nlptools/arabert/araelectra/util/utils.py +109 -0
  65. nlptools/arabert/aragpt2/__init__.py +2 -0
  66. nlptools/arabert/aragpt2/create_pretraining_data.py +95 -0
  67. nlptools/arabert/aragpt2/gpt2/__init__.py +2 -0
  68. nlptools/arabert/aragpt2/gpt2/lamb_optimizer.py +158 -0
  69. nlptools/arabert/aragpt2/gpt2/optimization.py +225 -0
  70. nlptools/arabert/aragpt2/gpt2/run_pretraining.py +397 -0
  71. nlptools/arabert/aragpt2/grover/__init__.py +0 -0
  72. nlptools/arabert/aragpt2/grover/dataloader.py +161 -0
  73. nlptools/arabert/aragpt2/grover/modeling.py +803 -0
  74. nlptools/arabert/aragpt2/grover/modeling_gpt2.py +1196 -0
  75. nlptools/arabert/aragpt2/grover/optimization_adafactor.py +234 -0
  76. nlptools/arabert/aragpt2/grover/train_tpu.py +187 -0
  77. nlptools/arabert/aragpt2/grover/utils.py +234 -0
  78. nlptools/arabert/aragpt2/train_bpe_tokenizer.py +59 -0
  79. nlptools/arabert/preprocess.py +818 -0
  80. nlptools/arabiner/__init__.py +0 -0
  81. nlptools/arabiner/bin/__init__.py +14 -0
  82. nlptools/arabiner/bin/eval.py +87 -0
  83. nlptools/arabiner/bin/infer.py +91 -0
  84. nlptools/arabiner/bin/process.py +140 -0
  85. nlptools/arabiner/bin/train.py +221 -0
  86. nlptools/arabiner/data/__init__.py +1 -0
  87. nlptools/arabiner/data/datasets.py +146 -0
  88. nlptools/arabiner/data/transforms.py +118 -0
  89. nlptools/arabiner/nn/BaseModel.py +22 -0
  90. nlptools/arabiner/nn/BertNestedTagger.py +34 -0
  91. nlptools/arabiner/nn/BertSeqTagger.py +17 -0
  92. nlptools/arabiner/nn/__init__.py +3 -0
  93. nlptools/arabiner/trainers/BaseTrainer.py +117 -0
  94. nlptools/arabiner/trainers/BertNestedTrainer.py +203 -0
  95. nlptools/arabiner/trainers/BertTrainer.py +163 -0
  96. nlptools/arabiner/trainers/__init__.py +3 -0
  97. nlptools/arabiner/utils/__init__.py +0 -0
  98. nlptools/arabiner/utils/data.py +124 -0
  99. nlptools/arabiner/utils/helpers.py +151 -0
  100. nlptools/arabiner/utils/metrics.py +69 -0
  101. nlptools/environment.yml +227 -0
  102. nlptools/install_env.py +13 -0
  103. nlptools/morphology/ALMA_multi_word.py +34 -0
  104. nlptools/morphology/__init__.py +52 -0
  105. nlptools/morphology/charsets.py +60 -0
  106. nlptools/morphology/morph_analyzer.py +170 -0
  107. nlptools/morphology/settings.py +8 -0
  108. nlptools/morphology/tokenizers_words.py +19 -0
  109. nlptools/nlptools.py +1 -0
  110. nlptools/salma/__init__.py +12 -0
  111. nlptools/salma/settings.py +31 -0
  112. nlptools/salma/views.py +459 -0
  113. nlptools/salma/wsd.py +126 -0
  114. nlptools/utils/__init__.py +0 -0
  115. nlptools/utils/corpus_tokenizer.py +73 -0
  116. nlptools/utils/implication.py +662 -0
  117. nlptools/utils/jaccard.py +247 -0
  118. nlptools/utils/parser.py +147 -0
  119. nlptools/utils/readfile.py +3 -0
  120. nlptools/utils/sentence_tokenizer.py +53 -0
  121. nlptools/utils/text_transliteration.py +232 -0
  122. nlptools/utils/utils.py +2 -0
nlptools/arabiner/bin/__init__.py
@@ -0,0 +1,14 @@
+ from nlptools.DataDownload import downloader
+ import os
+ from nlptools.arabiner.utils.helpers import load_checkpoint
+ import nlptools
+
+ nlptools.tagger = None
+ nlptools.tag_vocab = None
+ nlptools.train_config = None
+
+ filename = 'Wj27012000.tar'
+ path =downloader.get_appdatadir()
+ model_path = os.path.join(path, filename)
+ print('1',model_path)
+ nlptools.tagger, nlptools.tag_vocab, nlptools.train_config = load_checkpoint(model_path)
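For orientation, a minimal usage sketch (an assumption inferred from this initializer and the infer module below, not code shipped in the package): once the Wj27012000.tar checkpoint is present in the directory returned by downloader.get_appdatadir(), importing the bin package runs the loader above and infer.ner() can be called directly.

    # Usage sketch: assumes the Wj27012000.tar checkpoint already sits in the
    # app data directory; the import below executes the initializer above.
    from nlptools.arabiner.bin import infer

    print(infer.ner('ذهب محمد الى جامعة بيرزيت'))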
nlptools/arabiner/bin/eval.py
@@ -0,0 +1,87 @@
+ import os
+ import logging
+ import argparse
+ from collections import namedtuple
+ from nlptools.arabiner.utils.helpers import load_checkpoint, make_output_dirs, logging_config
+ from nlptools.arabiner.utils.data import get_dataloaders, parse_conll_files
+ from nlptools.arabiner.utils.metrics import compute_single_label_metrics, compute_nested_metrics
+
+ logger = logging.getLogger(__name__)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         required=True,
+         help="Path to save results",
+     )
+
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         required=True,
+         help="Model path",
+     )
+
+     parser.add_argument(
+         "--data_paths",
+         nargs="+",
+         type=str,
+         required=True,
+         help="Text or sequence to tag, this is in same format as training data with 'O' tag for all tokens",
+     )
+
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=32,
+         help="Batch size",
+     )
+
+     args = parser.parse_args()
+
+     return args
+
+
+ def main(args):
+     # Create directory to save predictions
+     make_output_dirs(args.output_path, overwrite=True)
+     logging_config(log_file=os.path.join(args.output_path, "eval.log"))
+
+     # Load tagger
+     tagger, tag_vocab, train_config = load_checkpoint(args.model_path)
+
+     # Convert text to a tagger dataset and index the tokens in args.text
+     datasets, vocab = parse_conll_files(args.data_paths)
+
+     vocabs = namedtuple("Vocab", ["tags", "tokens"])
+     vocab = vocabs(tokens=vocab.tokens, tags=tag_vocab)
+
+     # From the datasets generate the dataloaders
+     dataloaders = get_dataloaders(
+         datasets, vocab,
+         train_config.data_config,
+         batch_size=args.batch_size,
+         shuffle=[False] * len(datasets)
+     )
+
+     # Evaluate the model on each dataloader
+     for dataloader, input_file in zip(dataloaders, args.data_paths):
+         filename = os.path.basename(input_file)
+         predictions_file = os.path.join(args.output_path, f"predictions_{filename}")
+         _, segments, _, _ = tagger.eval(dataloader)
+         tagger.segments_to_file(segments, predictions_file)
+
+         if "Nested" in train_config.trainer_config["fn"]:
+             compute_nested_metrics(segments, vocab.tags[1:])
+         else:
+             compute_single_label_metrics(segments)
+
+
+ if __name__ == "__main__":
+     main(parse_args())
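The files passed via --data_paths appear to be CoNLL-style: one token per line followed by its tag(s), a blank line between sentences, and (per the help text) an 'O' tag on every token when the true labels are unknown. An illustrative input fragment (sample content, not taken from the package):

    ذهب O
    محمد O
    الى O
    جامعة O
    بيرزيت O

with an empty line separating consecutive sentences.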
nlptools/arabiner/bin/infer.py
@@ -0,0 +1,91 @@
+ import os
+ from collections import namedtuple
+ from nlptools.arabiner.utils.helpers import load_checkpoint
+ from nlptools.arabiner.utils.data import get_dataloaders, text2segments
+ from nlptools.DataDownload import downloader
+ import nlptools
+ def ner(text, batch_size=32):
+     """
+     This method takes a text as input, and a batch size, then performs named entity recognition (NER) on the input text and returns a list of tagged mentions.
+
+     Args:
+         text (str): The input text to perform NER on.
+         batch_size (int, optional): Batch size for inference. Default is 32.
+
+     Returns:
+         list: A list of lists containing token and label pairs for each segment.
+         Each inner list has the format ['token', 'label1 label2 ...'].
+     **Example:**
+
+     .. highlight:: python
+     .. code-block:: python
+
+         from nlptools.arabiner.bin import infer
+         infer.ner('ذهب محمد الى جامعة بيرزيت')
+
+         #the output
+         [['ذهب', 'O'],
+         ['محمد', 'B-PERS'],
+         ['الى', 'O'],
+         ['جامعة', 'B-ORG'],
+         ['بيرزيت', 'B-GPE I-ORG']]
+     """
+     # Load tagger
+     # filename = 'Wj27012000.tar'
+     # path =downloader.get_appdatadir()
+     # model_path = os.path.join(path, filename)
+     # print('1',model_path)
+     # tagger, tag_vocab, train_config = load_checkpoint(model_path)
+
+
+     # Convert text to a tagger dataset and index the tokens in args.text
+     dataset, token_vocab = text2segments(text)
+
+     vocabs = namedtuple("Vocab", ["tags", "tokens"])
+     vocab = vocabs(tokens=token_vocab, tags=nlptools.tag_vocab)
+
+     # From the datasets generate the dataloaders
+     dataloader = get_dataloaders(
+         (dataset,),
+         vocab,
+         nlptools.train_config.data_config,
+         batch_size=batch_size,
+         shuffle=(False,),
+     )[0]
+
+     # Perform inference on the text and get back the tagged segments
+     segments = nlptools.tagger.infer(dataloader)
+     segments_lists = []
+     # Print results
+     for segment in segments:
+         for token in segment:
+             segments_list = []
+             segments_list.append(token.text)
+             #print([t['tag'] for t in token.pred_tag])
+             list_of_tags = [t['tag'] for t in token.pred_tag]
+             list_of_tags = [i for i in list_of_tags if i not in('O',' ','')]
+             #print(list_of_tags)
+             if list_of_tags == []:
+                 segments_list.append(' '.join(['O']))
+             else:
+                 segments_list.append(' '.join(list_of_tags))
+             segments_lists.append(segments_list)
+     return segments_lists
+
+     #Print results
+     # for segment in segments:
+     #     s = [
+     #         (token.text, token.pred_tag[0]['tag'])
+     #         for token in segment
+     #         if token.pred_tag[0]['tag'] != 'O'
+     #     ]
+     #     print(", ".join([f"({token}, {tag})" for token, tag in s]))
+
+ def extract_tags(text):
+     tags = []
+     tokens = text.split()
+     for token in tokens:
+         tag = token.split("(")[-1].split(")")[0]
+         if tag != "O":
+             tags.append(tag)
+     return " ".join(tags)
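extract_tags() is undocumented; judging from its body it expects text whose tokens carry inline tags in parentheses, e.g. token(TAG), and it returns the non-O tags joined by spaces. An illustrative call with hypothetical input (importing the bin package still requires the checkpoint loaded by its __init__):

    from nlptools.arabiner.bin.infer import extract_tags

    # Non-"O" tags are collected in order and joined with spaces.
    print(extract_tags("ذهب(O) محمد(B-PERS) الى(O) جامعة(B-ORG) بيرزيت(I-ORG)"))
    # expected output: B-PERS B-ORG I-ORG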
nlptools/arabiner/bin/process.py
@@ -0,0 +1,140 @@
+ import os
+ import argparse
+ import csv
+ import logging
+ import numpy as np
+ from nlptools.arabiner.utils.helpers import logging_config
+ from nlptools.arabiner.utils.data import conll_to_segments
+
+ logger = logging.getLogger(__name__)
+
+
+ def to_conll_format(input_files, output_path, multi_label=False):
+     """
+     Parse data files and convert them into CoNLL format
+     :param input_files: List[str] - list of filenames
+     :param output_path: str - output path
+     :param multi_label: boolean - True to process data with mutli-class/multi-label
+     :return:
+     """
+     for input_file in input_files:
+         tokens = list()
+         prev_sent_id = None
+
+         with open(input_file, "r") as fh:
+             r = csv.reader(fh, delimiter="\t", quotechar=" ")
+             next(r)
+
+             for row in r:
+                 sent_id, token, labels = row[1], row[3], row[4].split()
+                 valid_labels = sum([1 for l in labels if "-" in l or l == "O"]) == len(labels)
+
+                 if not valid_labels:
+                     logging.warning("Invalid labels found %s", str(row))
+                     continue
+                 if not labels:
+                     logging.warning("Token %s has no label", str(row))
+                     continue
+                 if not token:
+                     logging.warning("Token %s is missing", str(row))
+                     continue
+                 if len(token.split()) > 1:
+                     logging.warning("Token %s has multiple tokens", str(row))
+                     continue
+
+                 if prev_sent_id is not None and sent_id != prev_sent_id:
+                     tokens.append([])
+
+                 if multi_label:
+                     tokens.append([token] + labels)
+                 else:
+                     tokens.append([token, labels[0]])
+
+                 prev_sent_id = sent_id
+
+         num_segments = sum([1 for token in tokens if not token])
+         logging.info("Found %d segments and %d tokens in %s", num_segments + 1, len(tokens) - num_segments, input_file)
+
+         filename = os.path.basename(input_file)
+         output_file = os.path.join(output_path, filename)
+
+         with open(output_file, "w") as fh:
+             fh.write("\n".join(" ".join(token) for token in tokens))
+             logging.info("Output file %s", output_file)
+
+
+ def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
+     segments = list()
+     filenames = ["train.txt", "val.txt", "test.txt"]
+
+     for input_file in input_files:
+         segments += conll_to_segments(input_file)
+
+     n = len(segments)
+     np.random.shuffle(segments)
+     datasets = np.split(segments, [int(train_ratio*n), int((train_ratio+dev_ratio)*n)])
+
+     # write data to files
+     for i in range(len(datasets)):
+         filename = os.path.join(output_path, filenames[i])
+
+         with open(filename, "w") as fh:
+             text = "\n\n".join(["\n".join([f"{token.text} {' '.join(token.gold_tag)}" for token in segment]) for segment in datasets[i]])
+             fh.write(text)
+             logging.info("Output file %s", filename)
+
+
+ def main(args):
+     if args.task == "to_conll_format":
+         to_conll_format(args.input_files, args.output_path, multi_label=args.multi_label)
+     if args.task == "train_dev_test_split":
+         train_dev_test_split(args.input_files, args.output_path, args.train_ratio, args.dev_ratio)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     parser.add_argument(
+         "--input_files",
+         type=str,
+         nargs="+",
+         required=True,
+         help="List of input files",
+     )
+
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         required=True,
+         help="Output path",
+     )
+
+     parser.add_argument(
+         "--train_ratio",
+         type=float,
+         required=False,
+         help="Training data ratio (percent of segments). Required with the task train_dev_test_split. "
+              "Files must in ConLL format",
+     )
+
+     parser.add_argument(
+         "--dev_ratio",
+         type=float,
+         required=False,
+         help="Dev/val data ratio (percent of segments). Required with the task train_dev_test_split. "
+              "Files must in ConLL format",
+     )
+
+     parser.add_argument(
+         "--task", required=True, choices=["to_conll_format", "train_dev_test_split"]
+     )
+
+     parser.add_argument(
+         "--multi_label", action='store_true'
+     )
+
+     args = parser.parse_args()
+     logging_config(os.path.join(args.output_path, "process.log"))
+     main(args)
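to_conll_format() reads tab-separated rows with the sentence id in column 1, the token in column 3, and space-separated labels in column 4 (a header row is skipped), then writes one "token label" line per token with an empty line between sentences. A hedged sketch of that flow, using hypothetical file names and sample rows (importing the bin package also assumes the NER checkpoint is installed, since its __init__ loads it):

    import os
    import tempfile
    from nlptools.arabiner.bin.process import to_conll_format

    in_dir, out_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
    sample = os.path.join(in_dir, "sample.tsv")
    rows = [
        "idx\tsent_id\tpos\ttoken\tlabels",  # header row, skipped via next(r)
        "0\ts1\t0\tذهب\tO",
        "1\ts1\t1\tمحمد\tB-PERS",
        "2\ts2\t0\tجامعة\tB-ORG",
    ]
    with open(sample, "w") as fh:
        fh.write("\n".join(rows))

    to_conll_format([sample], out_dir)
    print(open(os.path.join(out_dir, "sample.tsv")).read())
    # ذهب O
    # محمد B-PERS
    # (blank line marks the sentence boundary)
    # جامعة B-ORG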
nlptools/arabiner/bin/train.py
@@ -0,0 +1,221 @@
+ import os
+ import logging
+ import json
+ import argparse
+ import torch.utils.tensorboard
+ from torchvision import *
+ import pickle
+ from nlptools.arabiner.utils.data import get_dataloaders, parse_conll_files
+ from nlptools.arabiner.utils.helpers import logging_config, load_object, make_output_dirs, set_seed
+
+ logger = logging.getLogger(__name__)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         required=True,
+         help="Output path",
+     )
+
+     parser.add_argument(
+         "--train_path",
+         type=str,
+         required=True,
+         help="Path to training data",
+     )
+
+     parser.add_argument(
+         "--val_path",
+         type=str,
+         required=True,
+         help="Path to training data",
+     )
+
+     parser.add_argument(
+         "--test_path",
+         type=str,
+         required=True,
+         help="Path to training data",
+     )
+
+     parser.add_argument(
+         "--bert_model",
+         type=str,
+         default="aubmindlab/bert-base-arabertv2",
+         help="BERT model",
+     )
+
+     parser.add_argument(
+         "--gpus",
+         type=int,
+         nargs="+",
+         default=[0],
+         help="GPU IDs to train on",
+     )
+
+     parser.add_argument(
+         "--log_interval",
+         type=int,
+         default=10,
+         help="Log results every that many timesteps",
+     )
+
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=32,
+         help="Batch size",
+     )
+
+     parser.add_argument(
+         "--num_workers",
+         type=int,
+         default=0,
+         help="Dataloader number of workers",
+     )
+
+     parser.add_argument(
+         "--data_config",
+         type=json.loads,
+         default='{"fn": "arabiner.data.datasets.DefaultDataset", "kwargs": {"max_seq_len": 512}}',
+         help="Dataset configurations",
+     )
+
+     parser.add_argument(
+         "--trainer_config",
+         type=json.loads,
+         default='{"fn": "arabiner.trainers.BertTrainer", "kwargs": {"max_epochs": 50}}',
+         help="Trainer configurations",
+     )
+
+     parser.add_argument(
+         "--network_config",
+         type=json.loads,
+         default='{"fn": "arabiner.nn.BertSeqTagger", "kwargs": '
+                 '{"dropout": 0.1, "bert_model": "aubmindlab/bert-base-arabertv2"}}',
+         help="Network configurations",
+     )
+
+     parser.add_argument(
+         "--optimizer",
+         type=json.loads,
+         default='{"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}}',
+         help="Optimizer configurations",
+     )
+
+     parser.add_argument(
+         "--lr_scheduler",
+         type=json.loads,
+         default='{"fn": "torch.optim.lr_scheduler.ExponentialLR", "kwargs": {"gamma": 1}}',
+         help="Learning rate scheduler configurations",
+     )
+
+     parser.add_argument(
+         "--loss",
+         type=json.loads,
+         default='{"fn": "torch.nn.CrossEntropyLoss", "kwargs": {}}',
+         help="Loss function configurations",
+     )
+
+     parser.add_argument(
+         "--overwrite",
+         action="store_true",
+         help="Overwrite output directory",
+     )
+
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=1,
+         help="Seed for random initialization",
+     )
+
+     args = parser.parse_args()
+
+     return args
+
+
+ def main(args):
+     make_output_dirs(
+         args.output_path,
+         subdirs=("tensorboard", "checkpoints"),
+         overwrite=args.overwrite,
+     )
+
+     # Set the seed for randomization
+     set_seed(args.seed)
+
+     logging_config(os.path.join(args.output_path, "train.log"))
+     summary_writer = torch.utils.tensorboard.SummaryWriter(
+         os.path.join(args.output_path, "tensorboard")
+     )
+     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in args.gpus])
+
+     # Get the datasets and vocab for tags and tokens
+     datasets, vocab = parse_conll_files((args.train_path, args.val_path, args.test_path))
+
+     if "Nested" in args.network_config["fn"]:
+         args.network_config["kwargs"]["num_labels"] = [len(v) for v in vocab.tags[1:]]
+     else:
+         args.network_config["kwargs"]["num_labels"] = len(vocab.tags[0])
+
+     # Save tag vocab to desk
+     with open(os.path.join(args.output_path, "tag_vocab.pkl"), "wb") as fh:
+         pickle.dump(vocab.tags, fh)
+
+     # Write config to file
+     args_file = os.path.join(args.output_path, "args.json")
+     with open(args_file, "w") as fh:
+         logger.info("Writing config to %s", args_file)
+         json.dump(args.__dict__, fh, indent=4)
+
+     # From the datasets generate the dataloaders
+     args.data_config["kwargs"]["bert_model"] = args.network_config["kwargs"]["bert_model"]
+     train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
+         datasets, vocab, args.data_config, args.batch_size, args.num_workers
+     )
+
+     model = load_object(args.network_config["fn"], args.network_config["kwargs"])
+     model = torch.nn.DataParallel(model, device_ids=range(len(args.gpus)))
+
+     if torch.cuda.is_available():
+         model = model.cuda()
+
+     args.optimizer["kwargs"]["params"] = model.parameters()
+     optimizer = load_object(args.optimizer["fn"], args.optimizer["kwargs"])
+
+     args.lr_scheduler["kwargs"]["optimizer"] = optimizer
+     if "num_training_steps" in args.lr_scheduler["kwargs"]:
+         args.lr_scheduler["kwargs"]["num_training_steps"] = args.max_epochs * len(
+             train_dataloader
+         )
+
+     scheduler = load_object(args.lr_scheduler["fn"], args.lr_scheduler["kwargs"])
+     loss = load_object(args.loss["fn"], args.loss["kwargs"])
+
+     args.trainer_config["kwargs"].update({
+         "model": model,
+         "optimizer": optimizer,
+         "scheduler": scheduler,
+         "loss": loss,
+         "train_dataloader": train_dataloader,
+         "val_dataloader": val_dataloader,
+         "test_dataloader": test_dataloader,
+         "log_interval": args.log_interval,
+         "summary_writer": summary_writer,
+         "output_path": args.output_path
+     })
+
+     trainer = load_object(args.trainer_config["fn"], args.trainer_config["kwargs"])
+     trainer.train()
+     return
+
+
+ if __name__ == "__main__":
+     main(parse_args())
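The --data_config, --trainer_config, --network_config, --optimizer, --lr_scheduler and --loss options all share one JSON shape: a dotted import path under "fn" and constructor arguments under "kwargs", resolved at runtime by load_object. The helper below is only an illustration of that pattern under my own assumptions; it is not the package's load_object implementation:

    import importlib

    def build_from_config(config, **extra):
        # Resolve the dotted "fn" path to a callable and instantiate it with
        # the JSON "kwargs" plus any runtime-only arguments (e.g. params).
        module_name, attr = config["fn"].rsplit(".", 1)
        fn = getattr(importlib.import_module(module_name), attr)
        return fn(**{**config["kwargs"], **extra})

    # e.g. the default --optimizer value:
    # optimizer = build_from_config(
    #     {"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}},
    #     params=model.parameters(),
    # )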
nlptools/arabiner/data/__init__.py
@@ -0,0 +1 @@
+ from nlptools.arabiner.data.datasets import NestedTagsDataset