SinaTools-0.1.1-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SinaTools-0.1.1.data/data/nlptools/environment.yml +227 -0
- SinaTools-0.1.1.dist-info/AUTHORS.rst +13 -0
- SinaTools-0.1.1.dist-info/LICENSE +22 -0
- SinaTools-0.1.1.dist-info/METADATA +72 -0
- SinaTools-0.1.1.dist-info/RECORD +122 -0
- SinaTools-0.1.1.dist-info/WHEEL +6 -0
- SinaTools-0.1.1.dist-info/entry_points.txt +18 -0
- SinaTools-0.1.1.dist-info/top_level.txt +1 -0
- nlptools/CLI/DataDownload/download_files.py +71 -0
- nlptools/CLI/arabiner/bin/infer.py +117 -0
- nlptools/CLI/arabiner/bin/infer2.py +81 -0
- nlptools/CLI/morphology/ALMA_multi_word.py +75 -0
- nlptools/CLI/morphology/morph_analyzer.py +91 -0
- nlptools/CLI/salma/salma_tools.py +68 -0
- nlptools/CLI/utils/__init__.py +0 -0
- nlptools/CLI/utils/arStrip.py +99 -0
- nlptools/CLI/utils/corpus_tokenizer.py +74 -0
- nlptools/CLI/utils/implication.py +92 -0
- nlptools/CLI/utils/jaccard.py +96 -0
- nlptools/CLI/utils/latin_remove.py +51 -0
- nlptools/CLI/utils/remove_Punc.py +53 -0
- nlptools/CLI/utils/sentence_tokenizer.py +90 -0
- nlptools/CLI/utils/text_transliteration.py +77 -0
- nlptools/DataDownload/__init__.py +0 -0
- nlptools/DataDownload/downloader.py +185 -0
- nlptools/VERSION +1 -0
- nlptools/__init__.py +5 -0
- nlptools/arabert/__init__.py +1 -0
- nlptools/arabert/arabert/__init__.py +14 -0
- nlptools/arabert/arabert/create_classification_data.py +260 -0
- nlptools/arabert/arabert/create_pretraining_data.py +534 -0
- nlptools/arabert/arabert/extract_features.py +444 -0
- nlptools/arabert/arabert/lamb_optimizer.py +158 -0
- nlptools/arabert/arabert/modeling.py +1027 -0
- nlptools/arabert/arabert/optimization.py +202 -0
- nlptools/arabert/arabert/run_classifier.py +1078 -0
- nlptools/arabert/arabert/run_pretraining.py +593 -0
- nlptools/arabert/arabert/run_squad.py +1440 -0
- nlptools/arabert/arabert/tokenization.py +414 -0
- nlptools/arabert/araelectra/__init__.py +1 -0
- nlptools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +103 -0
- nlptools/arabert/araelectra/build_pretraining_dataset.py +230 -0
- nlptools/arabert/araelectra/build_pretraining_dataset_single_file.py +90 -0
- nlptools/arabert/araelectra/configure_finetuning.py +172 -0
- nlptools/arabert/araelectra/configure_pretraining.py +143 -0
- nlptools/arabert/araelectra/finetune/__init__.py +14 -0
- nlptools/arabert/araelectra/finetune/feature_spec.py +56 -0
- nlptools/arabert/araelectra/finetune/preprocessing.py +173 -0
- nlptools/arabert/araelectra/finetune/scorer.py +54 -0
- nlptools/arabert/araelectra/finetune/task.py +74 -0
- nlptools/arabert/araelectra/finetune/task_builder.py +70 -0
- nlptools/arabert/araelectra/flops_computation.py +215 -0
- nlptools/arabert/araelectra/model/__init__.py +14 -0
- nlptools/arabert/araelectra/model/modeling.py +1029 -0
- nlptools/arabert/araelectra/model/optimization.py +193 -0
- nlptools/arabert/araelectra/model/tokenization.py +355 -0
- nlptools/arabert/araelectra/pretrain/__init__.py +14 -0
- nlptools/arabert/araelectra/pretrain/pretrain_data.py +160 -0
- nlptools/arabert/araelectra/pretrain/pretrain_helpers.py +229 -0
- nlptools/arabert/araelectra/run_finetuning.py +323 -0
- nlptools/arabert/araelectra/run_pretraining.py +469 -0
- nlptools/arabert/araelectra/util/__init__.py +14 -0
- nlptools/arabert/araelectra/util/training_utils.py +112 -0
- nlptools/arabert/araelectra/util/utils.py +109 -0
- nlptools/arabert/aragpt2/__init__.py +2 -0
- nlptools/arabert/aragpt2/create_pretraining_data.py +95 -0
- nlptools/arabert/aragpt2/gpt2/__init__.py +2 -0
- nlptools/arabert/aragpt2/gpt2/lamb_optimizer.py +158 -0
- nlptools/arabert/aragpt2/gpt2/optimization.py +225 -0
- nlptools/arabert/aragpt2/gpt2/run_pretraining.py +397 -0
- nlptools/arabert/aragpt2/grover/__init__.py +0 -0
- nlptools/arabert/aragpt2/grover/dataloader.py +161 -0
- nlptools/arabert/aragpt2/grover/modeling.py +803 -0
- nlptools/arabert/aragpt2/grover/modeling_gpt2.py +1196 -0
- nlptools/arabert/aragpt2/grover/optimization_adafactor.py +234 -0
- nlptools/arabert/aragpt2/grover/train_tpu.py +187 -0
- nlptools/arabert/aragpt2/grover/utils.py +234 -0
- nlptools/arabert/aragpt2/train_bpe_tokenizer.py +59 -0
- nlptools/arabert/preprocess.py +818 -0
- nlptools/arabiner/__init__.py +0 -0
- nlptools/arabiner/bin/__init__.py +14 -0
- nlptools/arabiner/bin/eval.py +87 -0
- nlptools/arabiner/bin/infer.py +91 -0
- nlptools/arabiner/bin/process.py +140 -0
- nlptools/arabiner/bin/train.py +221 -0
- nlptools/arabiner/data/__init__.py +1 -0
- nlptools/arabiner/data/datasets.py +146 -0
- nlptools/arabiner/data/transforms.py +118 -0
- nlptools/arabiner/nn/BaseModel.py +22 -0
- nlptools/arabiner/nn/BertNestedTagger.py +34 -0
- nlptools/arabiner/nn/BertSeqTagger.py +17 -0
- nlptools/arabiner/nn/__init__.py +3 -0
- nlptools/arabiner/trainers/BaseTrainer.py +117 -0
- nlptools/arabiner/trainers/BertNestedTrainer.py +203 -0
- nlptools/arabiner/trainers/BertTrainer.py +163 -0
- nlptools/arabiner/trainers/__init__.py +3 -0
- nlptools/arabiner/utils/__init__.py +0 -0
- nlptools/arabiner/utils/data.py +124 -0
- nlptools/arabiner/utils/helpers.py +151 -0
- nlptools/arabiner/utils/metrics.py +69 -0
- nlptools/environment.yml +227 -0
- nlptools/install_env.py +13 -0
- nlptools/morphology/ALMA_multi_word.py +34 -0
- nlptools/morphology/__init__.py +52 -0
- nlptools/morphology/charsets.py +60 -0
- nlptools/morphology/morph_analyzer.py +170 -0
- nlptools/morphology/settings.py +8 -0
- nlptools/morphology/tokenizers_words.py +19 -0
- nlptools/nlptools.py +1 -0
- nlptools/salma/__init__.py +12 -0
- nlptools/salma/settings.py +31 -0
- nlptools/salma/views.py +459 -0
- nlptools/salma/wsd.py +126 -0
- nlptools/utils/__init__.py +0 -0
- nlptools/utils/corpus_tokenizer.py +73 -0
- nlptools/utils/implication.py +662 -0
- nlptools/utils/jaccard.py +247 -0
- nlptools/utils/parser.py +147 -0
- nlptools/utils/readfile.py +3 -0
- nlptools/utils/sentence_tokenizer.py +53 -0
- nlptools/utils/text_transliteration.py +232 -0
- nlptools/utils/utils.py +2 -0
@@ -0,0 +1,14 @@
+from nlptools.DataDownload import downloader
+import os
+from nlptools.arabiner.utils.helpers import load_checkpoint
+import nlptools
+
+nlptools.tagger = None
+nlptools.tag_vocab = None
+nlptools.train_config = None
+
+filename = 'Wj27012000.tar'
+path = downloader.get_appdatadir()
+model_path = os.path.join(path, filename)
+print('1', model_path)
+nlptools.tagger, nlptools.tag_vocab, nlptools.train_config = load_checkpoint(model_path)
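Note: importing `nlptools.arabiner.bin` runs the module-level code above, which loads the `Wj27012000.tar` checkpoint from the application data directory. Below is a minimal sketch, not part of the package, of a pre-flight check under the assumption that the checkpoint must already have been downloaded there:

```python
# Hypothetical pre-flight check (not shipped with the package): importing
# nlptools.arabiner.bin fails unless the checkpoint archive already exists
# in the application data directory returned by downloader.get_appdatadir().
import os
from nlptools.DataDownload import downloader

model_path = os.path.join(downloader.get_appdatadir(), 'Wj27012000.tar')
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Expected NER checkpoint at {model_path}")
```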
@@ -0,0 +1,87 @@
+import os
+import logging
+import argparse
+from collections import namedtuple
+from nlptools.arabiner.utils.helpers import load_checkpoint, make_output_dirs, logging_config
+from nlptools.arabiner.utils.data import get_dataloaders, parse_conll_files
+from nlptools.arabiner.utils.metrics import compute_single_label_metrics, compute_nested_metrics
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        required=True,
+        help="Path to save results",
+    )
+
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="Model path",
+    )
+
+    parser.add_argument(
+        "--data_paths",
+        nargs="+",
+        type=str,
+        required=True,
+        help="Text or sequence to tag; same format as the training data, with the 'O' tag for all tokens",
+    )
+
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=32,
+        help="Batch size",
+    )
+
+    args = parser.parse_args()
+
+    return args
+
+
+def main(args):
+    # Create directory to save predictions
+    make_output_dirs(args.output_path, overwrite=True)
+    logging_config(log_file=os.path.join(args.output_path, "eval.log"))
+
+    # Load tagger
+    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)
+
+    # Parse the CoNLL files into datasets and build the token vocabulary
+    datasets, vocab = parse_conll_files(args.data_paths)
+
+    vocabs = namedtuple("Vocab", ["tags", "tokens"])
+    vocab = vocabs(tokens=vocab.tokens, tags=tag_vocab)
+
+    # From the datasets generate the dataloaders
+    dataloaders = get_dataloaders(
+        datasets, vocab,
+        train_config.data_config,
+        batch_size=args.batch_size,
+        shuffle=[False] * len(datasets)
+    )
+
+    # Evaluate the model on each dataloader
+    for dataloader, input_file in zip(dataloaders, args.data_paths):
+        filename = os.path.basename(input_file)
+        predictions_file = os.path.join(args.output_path, f"predictions_{filename}")
+        _, segments, _, _ = tagger.eval(dataloader)
+        tagger.segments_to_file(segments, predictions_file)
+
+        if "Nested" in train_config.trainer_config["fn"]:
+            compute_nested_metrics(segments, vocab.tags[1:])
+        else:
+            compute_single_label_metrics(segments)
+
+
+if __name__ == "__main__":
+    main(parse_args())
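For reference, a minimal sketch of driving `eval.py` programmatically rather than via its argparse CLI. All paths are placeholders, and importing the `nlptools.arabiner.bin` package assumes the NER checkpoint is already installed (see the package `__init__` above):

```python
from argparse import Namespace
from nlptools.arabiner.bin import eval as ner_eval  # module name shadows the builtin

# Field names mirror the argparse options defined in eval.py; values are placeholders.
args = Namespace(
    output_path="./eval_out",              # eval.log and predictions_* are written here
    model_path="./models/Wj27012000.tar",  # placeholder checkpoint path
    data_paths=["./data/test.txt"],        # CoNLL-formatted file(s)
    batch_size=32,
)
ner_eval.main(args)
```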
@@ -0,0 +1,91 @@
+import os
+from collections import namedtuple
+from nlptools.arabiner.utils.helpers import load_checkpoint
+from nlptools.arabiner.utils.data import get_dataloaders, text2segments
+from nlptools.DataDownload import downloader
+import nlptools
+def ner(text, batch_size=32):
+    """
+    Performs named entity recognition (NER) on the input text, in batches of the given size, and returns a list of token/label pairs.
+
+    Args:
+        text (str): The input text to perform NER on.
+        batch_size (int, optional): Batch size for inference. Default is 32.
+
+    Returns:
+        list: A list of lists containing token and label pairs for each segment.
+            Each inner list has the format ['token', 'label1 label2 ...'].
+    **Example:**
+
+    .. highlight:: python
+    .. code-block:: python
+
+        from nlptools.arabiner.bin import infer
+        infer.ner('ذهب محمد الى جامعة بيرزيت')
+
+        # the output
+        [['ذهب', 'O'],
+         ['محمد', 'B-PERS'],
+         ['الى', 'O'],
+         ['جامعة', 'B-ORG'],
+         ['بيرزيت', 'B-GPE I-ORG']]
+    """
+    # Load tagger
+    # filename = 'Wj27012000.tar'
+    # path = downloader.get_appdatadir()
+    # model_path = os.path.join(path, filename)
+    # print('1', model_path)
+    # tagger, tag_vocab, train_config = load_checkpoint(model_path)
+
+
+    # Convert the text to a tagger dataset and index its tokens
+    dataset, token_vocab = text2segments(text)
+
+    vocabs = namedtuple("Vocab", ["tags", "tokens"])
+    vocab = vocabs(tokens=token_vocab, tags=nlptools.tag_vocab)
+
+    # From the dataset generate the dataloader
+    dataloader = get_dataloaders(
+        (dataset,),
+        vocab,
+        nlptools.train_config.data_config,
+        batch_size=batch_size,
+        shuffle=(False,),
+    )[0]
+
+    # Perform inference on the text and get back the tagged segments
+    segments = nlptools.tagger.infer(dataloader)
+    segments_lists = []
+    # Collect token/label pairs
+    for segment in segments:
+        for token in segment:
+            segments_list = []
+            segments_list.append(token.text)
+            # print([t['tag'] for t in token.pred_tag])
+            list_of_tags = [t['tag'] for t in token.pred_tag]
+            list_of_tags = [i for i in list_of_tags if i not in ('O', ' ', '')]
+            # print(list_of_tags)
+            if list_of_tags == []:
+                segments_list.append(' '.join(['O']))
+            else:
+                segments_list.append(' '.join(list_of_tags))
+            segments_lists.append(segments_list)
+    return segments_lists
+
+    # Print results
+    # for segment in segments:
+    #     s = [
+    #         (token.text, token.pred_tag[0]['tag'])
+    #         for token in segment
+    #         if token.pred_tag[0]['tag'] != 'O'
+    #     ]
+    #     print(", ".join([f"({token}, {tag})" for token, tag in s]))
+
+def extract_tags(text):
+    tags = []
+    tokens = text.split()
+    for token in tokens:
+        tag = token.split("(")[-1].split(")")[0]
+        if tag != "O":
+            tags.append(tag)
+    return " ".join(tags)
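`extract_tags` is undocumented; judging from its implementation it expects whitespace-separated tokens annotated inline as `token(TAG)` and returns the non-`O` tags joined by spaces. A small illustration under that assumption (the input format is inferred, not documented by the package, and importing the `bin` package loads the NER checkpoint as a side effect):

```python
from nlptools.arabiner.bin.infer import extract_tags

# Each token carries its tag in parentheses; 'O' tags are dropped from the result.
annotated = "ذهب(O) محمد(B-PERS) الى(O) بيرزيت(B-GPE)"
print(extract_tags(annotated))  # -> "B-PERS B-GPE"
```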
@@ -0,0 +1,140 @@
+import os
+import argparse
+import csv
+import logging
+import numpy as np
+from nlptools.arabiner.utils.helpers import logging_config
+from nlptools.arabiner.utils.data import conll_to_segments
+
+logger = logging.getLogger(__name__)
+
+
+def to_conll_format(input_files, output_path, multi_label=False):
+    """
+    Parse data files and convert them into CoNLL format
+    :param input_files: List[str] - list of filenames
+    :param output_path: str - output path
+    :param multi_label: boolean - True to process multi-class/multi-label data
+    :return:
+    """
+    for input_file in input_files:
+        tokens = list()
+        prev_sent_id = None
+
+        with open(input_file, "r") as fh:
+            r = csv.reader(fh, delimiter="\t", quotechar=" ")
+            next(r)
+
+            for row in r:
+                sent_id, token, labels = row[1], row[3], row[4].split()
+                valid_labels = sum([1 for l in labels if "-" in l or l == "O"]) == len(labels)
+
+                if not valid_labels:
+                    logging.warning("Invalid labels found %s", str(row))
+                    continue
+                if not labels:
+                    logging.warning("Token %s has no label", str(row))
+                    continue
+                if not token:
+                    logging.warning("Token %s is missing", str(row))
+                    continue
+                if len(token.split()) > 1:
+                    logging.warning("Token %s has multiple tokens", str(row))
+                    continue
+
+                if prev_sent_id is not None and sent_id != prev_sent_id:
+                    tokens.append([])
+
+                if multi_label:
+                    tokens.append([token] + labels)
+                else:
+                    tokens.append([token, labels[0]])
+
+                prev_sent_id = sent_id
+
+        num_segments = sum([1 for token in tokens if not token])
+        logging.info("Found %d segments and %d tokens in %s", num_segments + 1, len(tokens) - num_segments, input_file)
+
+        filename = os.path.basename(input_file)
+        output_file = os.path.join(output_path, filename)
+
+        with open(output_file, "w") as fh:
+            fh.write("\n".join(" ".join(token) for token in tokens))
+            logging.info("Output file %s", output_file)
+
+
+def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
+    segments = list()
+    filenames = ["train.txt", "val.txt", "test.txt"]
+
+    for input_file in input_files:
+        segments += conll_to_segments(input_file)
+
+    n = len(segments)
+    np.random.shuffle(segments)
+    datasets = np.split(segments, [int(train_ratio*n), int((train_ratio+dev_ratio)*n)])
+
+    # write data to files
+    for i in range(len(datasets)):
+        filename = os.path.join(output_path, filenames[i])
+
+        with open(filename, "w") as fh:
+            text = "\n\n".join(["\n".join([f"{token.text} {' '.join(token.gold_tag)}" for token in segment]) for segment in datasets[i]])
+            fh.write(text)
+            logging.info("Output file %s", filename)
+
+
+def main(args):
+    if args.task == "to_conll_format":
+        to_conll_format(args.input_files, args.output_path, multi_label=args.multi_label)
+    if args.task == "train_dev_test_split":
+        train_dev_test_split(args.input_files, args.output_path, args.train_ratio, args.dev_ratio)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--input_files",
+        type=str,
+        nargs="+",
+        required=True,
+        help="List of input files",
+    )
+
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        required=True,
+        help="Output path",
+    )
+
+    parser.add_argument(
+        "--train_ratio",
+        type=float,
+        required=False,
+        help="Training data ratio (percent of segments). Required for the train_dev_test_split task. "
+             "Files must be in CoNLL format",
+    )
+
+    parser.add_argument(
+        "--dev_ratio",
+        type=float,
+        required=False,
+        help="Dev/val data ratio (percent of segments). Required for the train_dev_test_split task. "
+             "Files must be in CoNLL format",
+    )
+
+    parser.add_argument(
+        "--task", required=True, choices=["to_conll_format", "train_dev_test_split"]
+    )
+
+    parser.add_argument(
+        "--multi_label", action='store_true'
+    )
+
+    args = parser.parse_args()
+    logging_config(os.path.join(args.output_path, "process.log"))
+    main(args)
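A sketch of calling the split helper in `process.py` directly instead of through its CLI; the file names below are placeholders, and the inputs are assumed to already be in CoNLL format:

```python
from nlptools.arabiner.bin.process import train_dev_test_split

# 70% train, 10% val, remaining 20% test, written as train.txt / val.txt / test.txt
train_dev_test_split(
    input_files=["./corpus/all.txt"],   # placeholder CoNLL-formatted input(s)
    output_path="./corpus/splits",      # placeholder output directory
    train_ratio=0.7,
    dev_ratio=0.1,
)
```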
@@ -0,0 +1,221 @@
+import os
+import logging
+import json
+import argparse
+import torch.utils.tensorboard
+from torchvision import *
+import pickle
+from nlptools.arabiner.utils.data import get_dataloaders, parse_conll_files
+from nlptools.arabiner.utils.helpers import logging_config, load_object, make_output_dirs, set_seed
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        required=True,
+        help="Output path",
+    )
+
+    parser.add_argument(
+        "--train_path",
+        type=str,
+        required=True,
+        help="Path to training data",
+    )
+
+    parser.add_argument(
+        "--val_path",
+        type=str,
+        required=True,
+        help="Path to validation data",
+    )
+
+    parser.add_argument(
+        "--test_path",
+        type=str,
+        required=True,
+        help="Path to test data",
+    )
+
+    parser.add_argument(
+        "--bert_model",
+        type=str,
+        default="aubmindlab/bert-base-arabertv2",
+        help="BERT model",
+    )
+
+    parser.add_argument(
+        "--gpus",
+        type=int,
+        nargs="+",
+        default=[0],
+        help="GPU IDs to train on",
+    )
+
+    parser.add_argument(
+        "--log_interval",
+        type=int,
+        default=10,
+        help="Log results every this many steps",
+    )
+
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=32,
+        help="Batch size",
+    )
+
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=0,
+        help="Dataloader number of workers",
+    )
+
+    parser.add_argument(
+        "--data_config",
+        type=json.loads,
+        default='{"fn": "arabiner.data.datasets.DefaultDataset", "kwargs": {"max_seq_len": 512}}',
+        help="Dataset configurations",
+    )
+
+    parser.add_argument(
+        "--trainer_config",
+        type=json.loads,
+        default='{"fn": "arabiner.trainers.BertTrainer", "kwargs": {"max_epochs": 50}}',
+        help="Trainer configurations",
+    )
+
+    parser.add_argument(
+        "--network_config",
+        type=json.loads,
+        default='{"fn": "arabiner.nn.BertSeqTagger", "kwargs": '
+                '{"dropout": 0.1, "bert_model": "aubmindlab/bert-base-arabertv2"}}',
+        help="Network configurations",
+    )
+
+    parser.add_argument(
+        "--optimizer",
+        type=json.loads,
+        default='{"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}}',
+        help="Optimizer configurations",
+    )
+
+    parser.add_argument(
+        "--lr_scheduler",
+        type=json.loads,
+        default='{"fn": "torch.optim.lr_scheduler.ExponentialLR", "kwargs": {"gamma": 1}}',
+        help="Learning rate scheduler configurations",
+    )
+
+    parser.add_argument(
+        "--loss",
+        type=json.loads,
+        default='{"fn": "torch.nn.CrossEntropyLoss", "kwargs": {}}',
+        help="Loss function configurations",
+    )
+
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="Overwrite output directory",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=1,
+        help="Seed for random initialization",
+    )
+
+    args = parser.parse_args()
+
+    return args
+
+
+def main(args):
+    make_output_dirs(
+        args.output_path,
+        subdirs=("tensorboard", "checkpoints"),
+        overwrite=args.overwrite,
+    )
+
+    # Set the seed for randomization
+    set_seed(args.seed)
+
+    logging_config(os.path.join(args.output_path, "train.log"))
+    summary_writer = torch.utils.tensorboard.SummaryWriter(
+        os.path.join(args.output_path, "tensorboard")
+    )
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in args.gpus])
+
+    # Get the datasets and vocab for tags and tokens
+    datasets, vocab = parse_conll_files((args.train_path, args.val_path, args.test_path))
+
+    if "Nested" in args.network_config["fn"]:
+        args.network_config["kwargs"]["num_labels"] = [len(v) for v in vocab.tags[1:]]
+    else:
+        args.network_config["kwargs"]["num_labels"] = len(vocab.tags[0])
+
+    # Save tag vocab to disk
+    with open(os.path.join(args.output_path, "tag_vocab.pkl"), "wb") as fh:
+        pickle.dump(vocab.tags, fh)
+
+    # Write config to file
+    args_file = os.path.join(args.output_path, "args.json")
+    with open(args_file, "w") as fh:
+        logger.info("Writing config to %s", args_file)
+        json.dump(args.__dict__, fh, indent=4)
+
+    # From the datasets generate the dataloaders
+    args.data_config["kwargs"]["bert_model"] = args.network_config["kwargs"]["bert_model"]
+    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
+        datasets, vocab, args.data_config, args.batch_size, args.num_workers
+    )
+
+    model = load_object(args.network_config["fn"], args.network_config["kwargs"])
+    model = torch.nn.DataParallel(model, device_ids=range(len(args.gpus)))
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    args.optimizer["kwargs"]["params"] = model.parameters()
+    optimizer = load_object(args.optimizer["fn"], args.optimizer["kwargs"])
+
+    args.lr_scheduler["kwargs"]["optimizer"] = optimizer
+    if "num_training_steps" in args.lr_scheduler["kwargs"]:
+        args.lr_scheduler["kwargs"]["num_training_steps"] = args.max_epochs * len(
+            train_dataloader
+        )
+
+    scheduler = load_object(args.lr_scheduler["fn"], args.lr_scheduler["kwargs"])
+    loss = load_object(args.loss["fn"], args.loss["kwargs"])
+
+    args.trainer_config["kwargs"].update({
+        "model": model,
+        "optimizer": optimizer,
+        "scheduler": scheduler,
+        "loss": loss,
+        "train_dataloader": train_dataloader,
+        "val_dataloader": val_dataloader,
+        "test_dataloader": test_dataloader,
+        "log_interval": args.log_interval,
+        "summary_writer": summary_writer,
+        "output_path": args.output_path
+    })
+
+    trainer = load_object(args.trainer_config["fn"], args.trainer_config["kwargs"])
+    trainer.train()
+    return
+
+
+if __name__ == "__main__":
+    main(parse_args())
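The trainer, network, optimizer, scheduler, and loss are all configured through `{"fn": ..., "kwargs": ...}` blobs that `load_object` resolves at runtime, as the argparse defaults above show. Below is a sketch of an equivalent programmatic call that mirrors those CLI defaults; only the data paths and output path are placeholders, and importing the `bin` package assumes the NER checkpoint is already in place:

```python
from argparse import Namespace
from nlptools.arabiner.bin.train import main

# Fields mirror the options defined in train.py; config dicts reuse the CLI defaults.
args = Namespace(
    output_path="./runs/arabert_ner",   # placeholder output directory
    train_path="./data/train.txt",      # placeholder CoNLL files
    val_path="./data/val.txt",
    test_path="./data/test.txt",
    bert_model="aubmindlab/bert-base-arabertv2",
    gpus=[0],
    log_interval=10,
    batch_size=32,
    num_workers=0,
    data_config={"fn": "arabiner.data.datasets.DefaultDataset", "kwargs": {"max_seq_len": 512}},
    trainer_config={"fn": "arabiner.trainers.BertTrainer", "kwargs": {"max_epochs": 50}},
    network_config={"fn": "arabiner.nn.BertSeqTagger",
                    "kwargs": {"dropout": 0.1, "bert_model": "aubmindlab/bert-base-arabertv2"}},
    optimizer={"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}},
    lr_scheduler={"fn": "torch.optim.lr_scheduler.ExponentialLR", "kwargs": {"gamma": 1}},
    loss={"fn": "torch.nn.CrossEntropyLoss", "kwargs": {}},
    overwrite=True,
    seed=1,
)
main(args)
```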
@@ -0,0 +1 @@
+from nlptools.arabiner.data.datasets import NestedTagsDataset