SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl
This diff shows the content changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
- SinaTools-1.0.1.dist-info/RECORD +73 -0
- sinatools/VERSION +1 -1
- sinatools/ner/__init__.py +5 -7
- sinatools/ner/trainers/BertNestedTrainer.py +203 -203
- sinatools/ner/trainers/BertTrainer.py +163 -163
- sinatools/ner/trainers/__init__.py +2 -2
- SinaTools-0.1.40.dist-info/RECORD +0 -123
- sinatools/arabert/arabert/__init__.py +0 -14
- sinatools/arabert/arabert/create_classification_data.py +0 -260
- sinatools/arabert/arabert/create_pretraining_data.py +0 -534
- sinatools/arabert/arabert/extract_features.py +0 -444
- sinatools/arabert/arabert/lamb_optimizer.py +0 -158
- sinatools/arabert/arabert/modeling.py +0 -1027
- sinatools/arabert/arabert/optimization.py +0 -202
- sinatools/arabert/arabert/run_classifier.py +0 -1078
- sinatools/arabert/arabert/run_pretraining.py +0 -593
- sinatools/arabert/arabert/run_squad.py +0 -1440
- sinatools/arabert/arabert/tokenization.py +0 -414
- sinatools/arabert/araelectra/__init__.py +0 -1
- sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
- sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
- sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
- sinatools/arabert/araelectra/configure_finetuning.py +0 -172
- sinatools/arabert/araelectra/configure_pretraining.py +0 -143
- sinatools/arabert/araelectra/finetune/__init__.py +0 -14
- sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
- sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
- sinatools/arabert/araelectra/finetune/scorer.py +0 -54
- sinatools/arabert/araelectra/finetune/task.py +0 -74
- sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
- sinatools/arabert/araelectra/flops_computation.py +0 -215
- sinatools/arabert/araelectra/model/__init__.py +0 -14
- sinatools/arabert/araelectra/model/modeling.py +0 -1029
- sinatools/arabert/araelectra/model/optimization.py +0 -193
- sinatools/arabert/araelectra/model/tokenization.py +0 -355
- sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
- sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
- sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
- sinatools/arabert/araelectra/run_finetuning.py +0 -323
- sinatools/arabert/araelectra/run_pretraining.py +0 -469
- sinatools/arabert/araelectra/util/__init__.py +0 -14
- sinatools/arabert/araelectra/util/training_utils.py +0 -112
- sinatools/arabert/araelectra/util/utils.py +0 -109
- sinatools/arabert/aragpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
- sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
- sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
- sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
- sinatools/arabert/aragpt2/grover/__init__.py +0 -0
- sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
- sinatools/arabert/aragpt2/grover/modeling.py +0 -803
- sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
- sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
- sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
- sinatools/arabert/aragpt2/grover/utils.py +0 -234
- sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
- {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
@@ -1,534 +0,0 @@
-# coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Create masked LM/next sentence masked_lm TF examples for BERT."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import random
-import tokenization
-import tensorflow as tf
-
-flags = tf.flags
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string(
-    "input_file", None, "Input raw text file (or comma-separated list of files)."
-)
-
-flags.DEFINE_string(
-    "output_file", None, "Output TF example file (or comma-separated list of files)."
-)
-
-flags.DEFINE_string(
-    "vocab_file", None, "The vocabulary file that the BERT model was trained on."
-)
-
-flags.DEFINE_bool(
-    "do_lower_case",
-    True,
-    "Whether to lower case the input text. Should be True for uncased "
-    "models and False for cased models.",
-)
-
-flags.DEFINE_bool(
-    "do_whole_word_mask",
-    False,
-    "Whether to use whole word masking rather than per-WordPiece masking.",
-)
-
-flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
-
-flags.DEFINE_integer(
-    "max_predictions_per_seq",
-    20,
-    "Maximum number of masked LM predictions per sequence.",
-)
-
-flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
-
-flags.DEFINE_integer(
-    "dupe_factor",
-    10,
-    "Number of times to duplicate the input data (with different masks).",
-)
-
-flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
-
-flags.DEFINE_float(
-    "short_seq_prob",
-    0.1,
-    "Probability of creating sequences which are shorter than the " "maximum length.",
-)
-
-
-class TrainingInstance(object):
-    """A single training instance (sentence pair)."""
-
-    def __init__(
-        self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next
-    ):
-        self.tokens = tokens
-        self.segment_ids = segment_ids
-        self.is_random_next = is_random_next
-        self.masked_lm_positions = masked_lm_positions
-        self.masked_lm_labels = masked_lm_labels
-
-    def __str__(self):
-        s = ""
-        s += "tokens: %s\n" % (
-            " ".join([tokenization.printable_text(x) for x in self.tokens])
-        )
-        s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
-        s += "is_random_next: %s\n" % self.is_random_next
-        s += "masked_lm_positions: %s\n" % (
-            " ".join([str(x) for x in self.masked_lm_positions])
-        )
-        s += "masked_lm_labels: %s\n" % (
-            " ".join([tokenization.printable_text(x) for x in self.masked_lm_labels])
-        )
-        s += "\n"
-        return s
-
-    def __repr__(self):
-        return self.__str__()
-
-
-def write_instance_to_example_files(
-    instances, tokenizer, max_seq_length, max_predictions_per_seq, output_files
-):
-    """Create TF example files from `TrainingInstance`s."""
-    writers = []
-    for output_file in output_files:
-        writers.append(tf.python_io.TFRecordWriter(output_file))
-
-    writer_index = 0
-
-    total_written = 0
-    for (inst_index, instance) in enumerate(instances):
-        input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
-        input_mask = [1] * len(input_ids)
-        segment_ids = list(instance.segment_ids)
-        assert len(input_ids) <= max_seq_length
-
-        while len(input_ids) < max_seq_length:
-            input_ids.append(0)
-            input_mask.append(0)
-            segment_ids.append(0)
-
-        assert len(input_ids) == max_seq_length
-        assert len(input_mask) == max_seq_length
-        assert len(segment_ids) == max_seq_length
-
-        masked_lm_positions = list(instance.masked_lm_positions)
-        masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
-        masked_lm_weights = [1.0] * len(masked_lm_ids)
-
-        while len(masked_lm_positions) < max_predictions_per_seq:
-            masked_lm_positions.append(0)
-            masked_lm_ids.append(0)
-            masked_lm_weights.append(0.0)
-
-        next_sentence_label = 1 if instance.is_random_next else 0
-
-        features = collections.OrderedDict()
-        features["input_ids"] = create_int_feature(input_ids)
-        features["input_mask"] = create_int_feature(input_mask)
-        features["segment_ids"] = create_int_feature(segment_ids)
-        features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
-        features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
-        features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
-        features["next_sentence_labels"] = create_int_feature([next_sentence_label])
-
-        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
-
-        writers[writer_index].write(tf_example.SerializeToString())
-        writer_index = (writer_index + 1) % len(writers)
-
-        total_written += 1
-
-        if inst_index < 20:
-            tf.logging.info("*** Example ***")
-            tf.logging.info(
-                "tokens: %s"
-                % " ".join([tokenization.printable_text(x) for x in instance.tokens])
-            )
-
-            for feature_name in features.keys():
-                feature = features[feature_name]
-                values = []
-                if feature.int64_list.value:
-                    values = feature.int64_list.value
-                elif feature.float_list.value:
-                    values = feature.float_list.value
-                tf.logging.info(
-                    "%s: %s" % (feature_name, " ".join([str(x) for x in values]))
-                )
-
-    for writer in writers:
-        writer.close()
-
-    tf.logging.info("Wrote %d total instances", total_written)
-
-
-def create_int_feature(values):
-    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
-    return feature
-
-
-def create_float_feature(values):
-    feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
-    return feature
-
-
-def create_training_instances(
-    input_files,
-    tokenizer,
-    max_seq_length,
-    dupe_factor,
-    short_seq_prob,
-    masked_lm_prob,
-    max_predictions_per_seq,
-    rng,
-):
-    """Create `TrainingInstance`s from raw text."""
-    all_documents = [[]]
-
-    # Input file format:
-    # (1) One sentence per line. These should ideally be actual sentences, not
-    # entire paragraphs or arbitrary spans of text. (Because we use the
-    # sentence boundaries for the "next sentence prediction" task).
-    # (2) Blank lines between documents. Document boundaries are needed so
-    # that the "next sentence prediction" task doesn't span between documents.
-    for input_file in input_files:
-        with tf.gfile.GFile(input_file, "r") as reader:
-            while True:
-                line = tokenization.convert_to_unicode(reader.readline())
-                if not line:
-                    break
-                line = line.strip()
-
-                # Empty lines are used as document delimiters
-                if not line:
-                    all_documents.append([])
-                tokens = tokenizer.tokenize(line)
-                if tokens:
-                    all_documents[-1].append(tokens)
-
-    # Remove empty documents
-    all_documents = [x for x in all_documents if x]
-    rng.shuffle(all_documents)
-
-    vocab_words = list(tokenizer.vocab.keys())
-    instances = []
-    for _ in range(dupe_factor):
-        for document_index in range(len(all_documents)):
-            instances.extend(
-                create_instances_from_document(
-                    all_documents,
-                    document_index,
-                    max_seq_length,
-                    short_seq_prob,
-                    masked_lm_prob,
-                    max_predictions_per_seq,
-                    vocab_words,
-                    rng,
-                )
-            )
-
-    rng.shuffle(instances)
-    return instances
-
-
-def create_instances_from_document(
-    all_documents,
-    document_index,
-    max_seq_length,
-    short_seq_prob,
-    masked_lm_prob,
-    max_predictions_per_seq,
-    vocab_words,
-    rng,
-):
-    """Creates `TrainingInstance`s for a single document."""
-    document = all_documents[document_index]
-
-    # Account for [CLS], [SEP], [SEP]
-    max_num_tokens = max_seq_length - 3
-
-    # We *usually* want to fill up the entire sequence since we are padding
-    # to `max_seq_length` anyways, so short sequences are generally wasted
-    # computation. However, we *sometimes*
-    # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
-    # sequences to minimize the mismatch between pre-training and fine-tuning.
-    # The `target_seq_length` is just a rough target however, whereas
-    # `max_seq_length` is a hard limit.
-    target_seq_length = max_num_tokens
-    if rng.random() < short_seq_prob:
-        target_seq_length = rng.randint(2, max_num_tokens)
-
-    # We DON'T just concatenate all of the tokens from a document into a long
-    # sequence and choose an arbitrary split point because this would make the
-    # next sentence prediction task too easy. Instead, we split the input into
-    # segments "A" and "B" based on the actual "sentences" provided by the user
-    # input.
-    instances = []
-    current_chunk = []
-    current_length = 0
-    i = 0
-    while i < len(document):
-        segment = document[i]
-        current_chunk.append(segment)
-        current_length += len(segment)
-        if i == len(document) - 1 or current_length >= target_seq_length:
-            if current_chunk:
-                # `a_end` is how many segments from `current_chunk` go into the `A`
-                # (first) sentence.
-                a_end = 1
-                if len(current_chunk) >= 2:
-                    a_end = rng.randint(1, len(current_chunk) - 1)
-
-                tokens_a = []
-                for j in range(a_end):
-                    tokens_a.extend(current_chunk[j])
-
-                tokens_b = []
-                # Random next
-                is_random_next = False
-                if len(current_chunk) == 1 or rng.random() < 0.5:
-                    is_random_next = True
-                    target_b_length = target_seq_length - len(tokens_a)
-
-                    # This should rarely go for more than one iteration for large
-                    # corpora. However, just to be careful, we try to make sure that
-                    # the random document is not the same as the document
-                    # we're processing.
-                    for _ in range(10):
-                        random_document_index = rng.randint(0, len(all_documents) - 1)
-                        if random_document_index != document_index:
-                            break
-
-                    random_document = all_documents[random_document_index]
-                    random_start = rng.randint(0, len(random_document) - 1)
-                    for j in range(random_start, len(random_document)):
-                        tokens_b.extend(random_document[j])
-                        if len(tokens_b) >= target_b_length:
-                            break
-                    # We didn't actually use these segments so we "put them back" so
-                    # they don't go to waste.
-                    num_unused_segments = len(current_chunk) - a_end
-                    i -= num_unused_segments
-                # Actual next
-                else:
-                    is_random_next = False
-                    for j in range(a_end, len(current_chunk)):
-                        tokens_b.extend(current_chunk[j])
-                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
-
-                assert len(tokens_a) >= 1
-                assert len(tokens_b) >= 1
-
-                tokens = []
-                segment_ids = []
-                tokens.append("[CLS]")
-                segment_ids.append(0)
-                for token in tokens_a:
-                    tokens.append(token)
-                    segment_ids.append(0)
-
-                tokens.append("[SEP]")
-                segment_ids.append(0)
-
-                for token in tokens_b:
-                    tokens.append(token)
-                    segment_ids.append(1)
-                tokens.append("[SEP]")
-                segment_ids.append(1)
-
-                (
-                    tokens,
-                    masked_lm_positions,
-                    masked_lm_labels,
-                ) = create_masked_lm_predictions(
-                    tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng
-                )
-                instance = TrainingInstance(
-                    tokens=tokens,
-                    segment_ids=segment_ids,
-                    is_random_next=is_random_next,
-                    masked_lm_positions=masked_lm_positions,
-                    masked_lm_labels=masked_lm_labels,
-                )
-                instances.append(instance)
-            current_chunk = []
-            current_length = 0
-        i += 1
-
-    return instances
-
-
-MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
-
-
-def create_masked_lm_predictions(
-    tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng
-):
-    """Creates the predictions for the masked LM objective."""
-
-    cand_indexes = []
-    for (i, token) in enumerate(tokens):
-        if token == "[CLS]" or token == "[SEP]":
-            continue
-        # Whole Word Masking means that if we mask all of the wordpieces
-        # corresponding to an original word. When a word has been split into
-        # WordPieces, the first token does not have any marker and any subsequence
-        # tokens are prefixed with ##. So whenever we see the ## token, we
-        # append it to the previous set of word indexes.
-        #
-        # Note that Whole Word Masking does *not* change the training code
-        # at all -- we still predict each WordPiece independently, softmaxed
-        # over the entire vocabulary.
-        if (
-            FLAGS.do_whole_word_mask
-            and len(cand_indexes) >= 1
-            and token.startswith("##")
-        ):
-            cand_indexes[-1].append(i)
-        else:
-            cand_indexes.append([i])
-
-    rng.shuffle(cand_indexes)
-
-    output_tokens = list(tokens)
-
-    num_to_predict = min(
-        max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))
-    )
-
-    masked_lms = []
-    covered_indexes = set()
-    for index_set in cand_indexes:
-        if len(masked_lms) >= num_to_predict:
-            break
-        # If adding a whole-word mask would exceed the maximum number of
-        # predictions, then just skip this candidate.
-        if len(masked_lms) + len(index_set) > num_to_predict:
-            continue
-        is_any_index_covered = False
-        for index in index_set:
-            if index in covered_indexes:
-                is_any_index_covered = True
-                break
-        if is_any_index_covered:
-            continue
-        for index in index_set:
-            covered_indexes.add(index)
-
-            masked_token = None
-            # 80% of the time, replace with [MASK]
-            if rng.random() < 0.8:
-                masked_token = "[MASK]"
-            else:
-                # 10% of the time, keep original
-                if rng.random() < 0.5:
-                    masked_token = tokens[index]
-                # 10% of the time, replace with random word
-                else:
-                    masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-            output_tokens[index] = masked_token
-
-            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
-    assert len(masked_lms) <= num_to_predict
-    masked_lms = sorted(masked_lms, key=lambda x: x.index)
-
-    masked_lm_positions = []
-    masked_lm_labels = []
-    for p in masked_lms:
-        masked_lm_positions.append(p.index)
-        masked_lm_labels.append(p.label)
-
-    return (output_tokens, masked_lm_positions, masked_lm_labels)
-
-
-def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
-    """Truncates a pair of sequences to a maximum sequence length."""
-    while True:
-        total_length = len(tokens_a) + len(tokens_b)
-        if total_length <= max_num_tokens:
-            break
-
-        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
-        assert len(trunc_tokens) >= 1
-
-        # We want to sometimes truncate from the front and sometimes from the
-        # back to add more randomness and avoid biases.
-        if rng.random() < 0.5:
-            del trunc_tokens[0]
-        else:
-            trunc_tokens.pop()
-
-
-def main(_):
-    tf.logging.set_verbosity(tf.logging.INFO)
-    logger = tf.get_logger()
-    logger.propagate = False
-
-    tokenizer = tokenization.FullTokenizer(
-        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
-    )
-
-    input_files = []
-    for input_pattern in FLAGS.input_file.split(","):
-        input_files.extend(tf.gfile.Glob(input_pattern))
-
-    tf.logging.info("*** Reading from input files ***")
-    for input_file in input_files:
-        tf.logging.info(" %s", input_file)
-
-    rng = random.Random(FLAGS.random_seed)
-    instances = create_training_instances(
-        input_files,
-        tokenizer,
-        FLAGS.max_seq_length,
-        FLAGS.dupe_factor,
-        FLAGS.short_seq_prob,
-        FLAGS.masked_lm_prob,
-        FLAGS.max_predictions_per_seq,
-        rng,
-    )
-
-    output_files = FLAGS.output_file.split(",")
-    tf.logging.info("*** Writing to output files ***")
-    for output_file in output_files:
-        tf.logging.info(" %s", output_file)
-
-    write_instance_to_example_files(
-        instances,
-        tokenizer,
-        FLAGS.max_seq_length,
-        FLAGS.max_predictions_per_seq,
-        output_files,
-    )
-
-
-if __name__ == "__main__":
-    flags.mark_flag_as_required("input_file")
-    flags.mark_flag_as_required("output_file")
-    flags.mark_flag_as_required("vocab_file")
-    tf.app.run()
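For reference, the removed script serialized each TrainingInstance as a tf.train.Example with seven fixed-length features (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels). Below is a minimal sketch for reading such records back, assuming the script's default flag values (max_seq_length=128, max_predictions_per_seq=20) and TF 2.x-style tf.io APIs rather than the TF 1.x calls used in the removed file; the helper name decode_record is illustrative, not part of the package.

import tensorflow as tf

MAX_SEQ_LENGTH = 128          # default of the removed script's max_seq_length flag
MAX_PREDICTIONS_PER_SEQ = 20  # default of the removed script's max_predictions_per_seq flag

# Feature spec mirroring what write_instance_to_example_files() serialized.
name_to_features = {
    "input_ids": tf.io.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "input_mask": tf.io.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "masked_lm_positions": tf.io.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ], tf.int64),
    "masked_lm_ids": tf.io.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ], tf.int64),
    "masked_lm_weights": tf.io.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ], tf.float32),
    "next_sentence_labels": tf.io.FixedLenFeature([1], tf.int64),
}

def decode_record(record):
    """Parse one serialized tf.train.Example written by the removed script."""
    return tf.io.parse_single_example(record, name_to_features)

# Example usage (hypothetical file name):
# dataset = tf.data.TFRecordDataset("pretraining_data.tfrecord").map(decode_record)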