SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to their public registry. It is provided for informational purposes only.
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
- SinaTools-1.0.1.dist-info/RECORD +73 -0
- sinatools/VERSION +1 -1
- sinatools/ner/__init__.py +5 -7
- sinatools/ner/trainers/BertNestedTrainer.py +203 -203
- sinatools/ner/trainers/BertTrainer.py +163 -163
- sinatools/ner/trainers/__init__.py +2 -2
- SinaTools-0.1.40.dist-info/RECORD +0 -123
- sinatools/arabert/arabert/__init__.py +0 -14
- sinatools/arabert/arabert/create_classification_data.py +0 -260
- sinatools/arabert/arabert/create_pretraining_data.py +0 -534
- sinatools/arabert/arabert/extract_features.py +0 -444
- sinatools/arabert/arabert/lamb_optimizer.py +0 -158
- sinatools/arabert/arabert/modeling.py +0 -1027
- sinatools/arabert/arabert/optimization.py +0 -202
- sinatools/arabert/arabert/run_classifier.py +0 -1078
- sinatools/arabert/arabert/run_pretraining.py +0 -593
- sinatools/arabert/arabert/run_squad.py +0 -1440
- sinatools/arabert/arabert/tokenization.py +0 -414
- sinatools/arabert/araelectra/__init__.py +0 -1
- sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
- sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
- sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
- sinatools/arabert/araelectra/configure_finetuning.py +0 -172
- sinatools/arabert/araelectra/configure_pretraining.py +0 -143
- sinatools/arabert/araelectra/finetune/__init__.py +0 -14
- sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
- sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
- sinatools/arabert/araelectra/finetune/scorer.py +0 -54
- sinatools/arabert/araelectra/finetune/task.py +0 -74
- sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
- sinatools/arabert/araelectra/flops_computation.py +0 -215
- sinatools/arabert/araelectra/model/__init__.py +0 -14
- sinatools/arabert/araelectra/model/modeling.py +0 -1029
- sinatools/arabert/araelectra/model/optimization.py +0 -193
- sinatools/arabert/araelectra/model/tokenization.py +0 -355
- sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
- sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
- sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
- sinatools/arabert/araelectra/run_finetuning.py +0 -323
- sinatools/arabert/araelectra/run_pretraining.py +0 -469
- sinatools/arabert/araelectra/util/__init__.py +0 -14
- sinatools/arabert/araelectra/util/training_utils.py +0 -112
- sinatools/arabert/araelectra/util/utils.py +0 -109
- sinatools/arabert/aragpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
- sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
- sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
- sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
- sinatools/arabert/aragpt2/grover/__init__.py +0 -0
- sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
- sinatools/arabert/aragpt2/grover/modeling.py +0 -803
- sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
- sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
- sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
- sinatools/arabert/aragpt2/grover/utils.py +0 -234
- sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
- {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
sinatools/arabert/arabert/run_pretraining.py
@@ -1,593 +0,0 @@
-# coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Run masked LM/next sentence masked_lm pre-training for BERT."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import modeling
-import optimization
-import tensorflow as tf
-
-flags = tf.flags
-
-FLAGS = flags.FLAGS
-
-## Required parameters
-flags.DEFINE_string(
-    "bert_config_file",
-    None,
-    "The config json file corresponding to the pre-trained BERT model. "
-    "This specifies the model architecture.",
-)
-
-flags.DEFINE_string(
-    "input_file", None, "Input TF example files (can be a glob or comma separated)."
-)
-
-flags.DEFINE_string(
-    "output_dir",
-    None,
-    "The output directory where the model checkpoints will be written.",
-)
-
-## Other parameters
-flags.DEFINE_string(
-    "init_checkpoint",
-    None,
-    "Initial checkpoint (usually from a pre-trained BERT model).",
-)
-
-flags.DEFINE_integer(
-    "max_seq_length",
-    128,
-    "The maximum total input sequence length after WordPiece tokenization. "
-    "Sequences longer than this will be truncated, and sequences shorter "
-    "than this will be padded. Must match data generation.",
-)
-
-flags.DEFINE_integer(
-    "max_predictions_per_seq",
-    20,
-    "Maximum number of masked LM predictions per sequence. "
-    "Must match data generation.",
-)
-
-flags.DEFINE_bool("do_train", False, "Whether to run training.")
-
-flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
-
-flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
-
-flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
-
-flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.")
-
-flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"],
-                  "The optimizer for training.")
-
-flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
-
-flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
-
-flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
-
-flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.")
-
-flags.DEFINE_integer(
-    "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
-)
-
-flags.DEFINE_integer(
-    "iterations_per_loop", 1000, "How many steps to make in each estimator call."
-)
-
-flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
-
-flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
-
-tf.flags.DEFINE_string(
-    "tpu_name",
-    None,
-    "The Cloud TPU to use for training. This should be either the name "
-    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
-    "url.",
-)
-
-tf.flags.DEFINE_string(
-    "tpu_zone",
-    None,
-    "[Optional] GCE zone where the Cloud TPU is located in. If not "
-    "specified, we will attempt to automatically detect the GCE project from "
-    "metadata.",
-)
-
-tf.flags.DEFINE_string(
-    "gcp_project",
-    None,
-    "[Optional] Project name for the Cloud TPU-enabled project. If not "
-    "specified, we will attempt to automatically detect the GCE project from "
-    "metadata.",
-)
-
-tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
-
-flags.DEFINE_integer(
-    "num_tpu_cores",
-    8,
-    "Only used if `use_tpu` is True. Total number of TPU cores to use.",
-)
-
-flags.DEFINE_integer("keep_checkpoint_max", 10,
-                     "How many checkpoints to keep.")
-
-
-def model_fn_builder(
-    bert_config,
-    init_checkpoint,
-    learning_rate,
-    num_train_steps,
-    num_warmup_steps,
-    use_tpu,
-    use_one_hot_embeddings,
-    optimizer,
-    poly_power,
-    start_warmup_step,
-):
-    """Returns `model_fn` closure for TPUEstimator."""
-
-    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
-        """The `model_fn` for TPUEstimator."""
-
-        tf.logging.info("*** Features ***")
-        for name in sorted(features.keys()):
-            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))
-
-        input_ids = features["input_ids"]
-        input_mask = features["input_mask"]
-        segment_ids = features["segment_ids"]
-        masked_lm_positions = features["masked_lm_positions"]
-        masked_lm_ids = features["masked_lm_ids"]
-        masked_lm_weights = features["masked_lm_weights"]
-        next_sentence_labels = features["next_sentence_labels"]
-
-        is_training = mode == tf.estimator.ModeKeys.TRAIN
-
-        model = modeling.BertModel(
-            config=bert_config,
-            is_training=is_training,
-            input_ids=input_ids,
-            input_mask=input_mask,
-            token_type_ids=segment_ids,
-            use_one_hot_embeddings=use_one_hot_embeddings,
-        )
-
-        (
-            masked_lm_loss,
-            masked_lm_example_loss,
-            masked_lm_log_probs,
-        ) = get_masked_lm_output(
-            bert_config,
-            model.get_sequence_output(),
-            model.get_embedding_table(),
-            masked_lm_positions,
-            masked_lm_ids,
-            masked_lm_weights,
-        )
-
-        (
-            next_sentence_loss,
-            next_sentence_example_loss,
-            next_sentence_log_probs,
-        ) = get_next_sentence_output(
-            bert_config, model.get_pooled_output(), next_sentence_labels
-        )
-
-        total_loss = masked_lm_loss + next_sentence_loss
-
-        tvars = tf.trainable_variables()
-
-        initialized_variable_names = {}
-        scaffold_fn = None
-        if init_checkpoint:
-            (
-                assignment_map,
-                initialized_variable_names,
-            ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
-            if use_tpu:
-
-                def tpu_scaffold():
-                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-                    return tf.train.Scaffold()
-
-                scaffold_fn = tpu_scaffold
-            else:
-                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-
-        tf.logging.info("**** Trainable Variables ****")
-        for var in tvars:
-            init_string = ""
-            if var.name in initialized_variable_names:
-                init_string = ", *INIT_FROM_CKPT*"
-            tf.logging.info(
-                "  name = %s, shape = %s%s", var.name, var.shape, init_string
-            )
-
-        output_spec = None
-        if mode == tf.estimator.ModeKeys.TRAIN:
-            train_op = optimization.create_optimizer(
-                total_loss,
-                learning_rate,
-                num_train_steps,
-                num_warmup_steps,
-                use_tpu,
-                optimizer,
-                poly_power,
-                start_warmup_step,
-            )
-
-            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
-            )
-        elif mode == tf.estimator.ModeKeys.EVAL:
-
-            def metric_fn(
-                masked_lm_example_loss,
-                masked_lm_log_probs,
-                masked_lm_ids,
-                masked_lm_weights,
-                next_sentence_example_loss,
-                next_sentence_log_probs,
-                next_sentence_labels,
-            ):
-                """Computes the loss and accuracy of the model."""
-                masked_lm_log_probs = tf.reshape(
-                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]]
-                )
-                masked_lm_predictions = tf.argmax(
-                    masked_lm_log_probs, axis=-1, output_type=tf.int32
-                )
-                masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
-                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
-                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
-                masked_lm_accuracy = tf.metrics.accuracy(
-                    labels=masked_lm_ids,
-                    predictions=masked_lm_predictions,
-                    weights=masked_lm_weights,
-                )
-                masked_lm_mean_loss = tf.metrics.mean(
-                    values=masked_lm_example_loss, weights=masked_lm_weights
-                )
-
-                next_sentence_log_probs = tf.reshape(
-                    next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]
-                )
-                next_sentence_predictions = tf.argmax(
-                    next_sentence_log_probs, axis=-1, output_type=tf.int32
-                )
-                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
-                next_sentence_accuracy = tf.metrics.accuracy(
-                    labels=next_sentence_labels, predictions=next_sentence_predictions
-                )
-                next_sentence_mean_loss = tf.metrics.mean(
-                    values=next_sentence_example_loss
-                )
-
-                return {
-                    "masked_lm_accuracy": masked_lm_accuracy,
-                    "masked_lm_loss": masked_lm_mean_loss,
-                    "next_sentence_accuracy": next_sentence_accuracy,
-                    "next_sentence_loss": next_sentence_mean_loss,
-                }
-
-            eval_metrics = (
-                metric_fn,
-                [
-                    masked_lm_example_loss,
-                    masked_lm_log_probs,
-                    masked_lm_ids,
-                    masked_lm_weights,
-                    next_sentence_example_loss,
-                    next_sentence_log_probs,
-                    next_sentence_labels,
-                ],
-            )
-            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                mode=mode,
-                loss=total_loss,
-                eval_metrics=eval_metrics,
-                scaffold_fn=scaffold_fn,
-            )
-        else:
-            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
-
-        return output_spec
-
-    return model_fn
-
-
-def get_masked_lm_output(
-    bert_config, input_tensor, output_weights, positions, label_ids, label_weights
-):
-    """Get loss and log probs for the masked LM."""
-    input_tensor = gather_indexes(input_tensor, positions)
-
-    with tf.variable_scope("cls/predictions"):
-        # We apply one more non-linear transformation before the output layer.
-        # This matrix is not used after pre-training.
-        with tf.variable_scope("transform"):
-            input_tensor = tf.layers.dense(
-                input_tensor,
-                units=bert_config.hidden_size,
-                activation=modeling.get_activation(bert_config.hidden_act),
-                kernel_initializer=modeling.create_initializer(
-                    bert_config.initializer_range
-                ),
-            )
-            input_tensor = modeling.layer_norm(input_tensor)
-
-        # The output weights are the same as the input embeddings, but there is
-        # an output-only bias for each token.
-        output_bias = tf.get_variable(
-            "output_bias",
-            shape=[bert_config.vocab_size],
-            initializer=tf.zeros_initializer(),
-        )
-        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
-        logits = tf.nn.bias_add(logits, output_bias)
-        log_probs = tf.nn.log_softmax(logits, axis=-1)
-
-        label_ids = tf.reshape(label_ids, [-1])
-        label_weights = tf.reshape(label_weights, [-1])
-
-        one_hot_labels = tf.one_hot(
-            label_ids, depth=bert_config.vocab_size, dtype=tf.float32
-        )
-
-        # The `positions` tensor might be zero-padded (if the sequence is too
-        # short to have the maximum number of predictions). The `label_weights`
-        # tensor has a value of 1.0 for every real prediction and 0.0 for the
-        # padding predictions.
-        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
-        numerator = tf.reduce_sum(label_weights * per_example_loss)
-        denominator = tf.reduce_sum(label_weights) + 1e-5
-        loss = numerator / denominator
-
-    return (loss, per_example_loss, log_probs)
-
-
-def get_next_sentence_output(bert_config, input_tensor, labels):
-    """Get loss and log probs for the next sentence prediction."""
-
-    # Simple binary classification. Note that 0 is "next sentence" and 1 is
-    # "random sentence". This weight matrix is not used after pre-training.
-    with tf.variable_scope("cls/seq_relationship"):
-        output_weights = tf.get_variable(
-            "output_weights",
-            shape=[2, bert_config.hidden_size],
-            initializer=modeling.create_initializer(bert_config.initializer_range),
-        )
-        output_bias = tf.get_variable(
-            "output_bias", shape=[2], initializer=tf.zeros_initializer()
-        )
-
-        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
-        logits = tf.nn.bias_add(logits, output_bias)
-        log_probs = tf.nn.log_softmax(logits, axis=-1)
-        labels = tf.reshape(labels, [-1])
-        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
-        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
-        loss = tf.reduce_mean(per_example_loss)
-        return (loss, per_example_loss, log_probs)
-
-
-def gather_indexes(sequence_tensor, positions):
-    """Gathers the vectors at the specific positions over a minibatch."""
-    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
-    batch_size = sequence_shape[0]
-    seq_length = sequence_shape[1]
-    width = sequence_shape[2]
-
-    flat_offsets = tf.reshape(
-        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]
-    )
-    flat_positions = tf.reshape(positions + flat_offsets, [-1])
-    flat_sequence_tensor = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
-    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
-    return output_tensor
-
-
-def input_fn_builder(
-    input_files, max_seq_length, max_predictions_per_seq, is_training, num_cpu_threads=4
-):
-    """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-    def input_fn(params):
-        """The actual input function."""
-        batch_size = params["batch_size"]
-
-        name_to_features = {
-            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
-            "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
-            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
-            "masked_lm_positions": tf.FixedLenFeature(
-                [max_predictions_per_seq], tf.int64
-            ),
-            "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
-            "masked_lm_weights": tf.FixedLenFeature(
-                [max_predictions_per_seq], tf.float32
-            ),
-            "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
-        }
-
-        # For training, we want a lot of parallel reading and shuffling.
-        # For eval, we want no shuffling and parallel reading doesn't matter.
-        if is_training:
-            d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
-            d = d.repeat()
-            d = d.shuffle(buffer_size=len(input_files))
-
-            # `cycle_length` is the number of parallel files that get read.
-            cycle_length = min(num_cpu_threads, len(input_files))
-
-            # `sloppy` mode means that the interleaving is not exact. This adds
-            # even more randomness to the training pipeline.
-            d = d.apply(
-                tf.contrib.data.parallel_interleave(
-                    tf.data.TFRecordDataset,
-                    sloppy=is_training,
-                    cycle_length=cycle_length,
-                )
-            )
-            d = d.shuffle(buffer_size=100)
-        else:
-            d = tf.data.TFRecordDataset(input_files)
-            # Since we evaluate for a fixed number of steps we don't want to encounter
-            # out-of-range exceptions.
-            d = d.repeat()
-
-        # We must `drop_remainder` on training because the TPU requires fixed
-        # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
-        # and we *don't* want to drop the remainder, otherwise we wont cover
-        # every sample.
-        d = d.apply(
-            tf.contrib.data.map_and_batch(
-                lambda record: _decode_record(record, name_to_features),
-                batch_size=batch_size,
-                num_parallel_batches=num_cpu_threads,
-                drop_remainder=True,
-            )
-        )
-        return d
-
-    return input_fn
-
-
-def _decode_record(record, name_to_features):
-    """Decodes a record to a TensorFlow example."""
-    example = tf.parse_single_example(record, name_to_features)
-
-    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-    # So cast all int64 to int32.
-    for name in list(example.keys()):
-        t = example[name]
-        if t.dtype == tf.int64:
-            t = tf.to_int32(t)
-        example[name] = t
-
-    return example
-
-
-def main(_):
-    tf.logging.set_verbosity(tf.logging.INFO)
-    logger = tf.get_logger()
-    logger.propagate = False
-
-    if not FLAGS.do_train and not FLAGS.do_eval:
-        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
-
-    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-
-    tf.gfile.MakeDirs(FLAGS.output_dir)
-
-    input_files = []
-    for input_pattern in FLAGS.input_file.split(","):
-        input_files.extend(tf.gfile.Glob(input_pattern))
-
-    # tf.logging.info("*** Input Files ***")
-    # for input_file in input_files:
-    #   tf.logging.info("  %s" % input_file)
-
-    tpu_cluster_resolver = None
-    if FLAGS.use_tpu and FLAGS.tpu_name:
-        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
-            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
-        )
-
-    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
-    run_config = tf.contrib.tpu.RunConfig(
-        cluster=tpu_cluster_resolver,
-        master=FLAGS.master,
-        model_dir=FLAGS.output_dir,
-        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
-        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
-        tpu_config=tf.contrib.tpu.TPUConfig(
-            iterations_per_loop=FLAGS.iterations_per_loop,
-            num_shards=FLAGS.num_tpu_cores,
-            per_host_input_for_training=is_per_host,
-        ),
-    )
-
-    model_fn = model_fn_builder(
-        bert_config=bert_config,
-        init_checkpoint=FLAGS.init_checkpoint,
-        learning_rate=FLAGS.learning_rate,
-        num_train_steps=FLAGS.num_train_steps,
-        num_warmup_steps=FLAGS.num_warmup_steps,
-        use_tpu=FLAGS.use_tpu,
-        use_one_hot_embeddings=FLAGS.use_tpu,
-        optimizer=FLAGS.optimizer,
-        poly_power=FLAGS.poly_power,
-        start_warmup_step=FLAGS.start_warmup_step
-    )
-
-    # If TPU is not available, this will fall back to normal Estimator on CPU
-    # or GPU.
-    estimator = tf.contrib.tpu.TPUEstimator(
-        use_tpu=FLAGS.use_tpu,
-        model_fn=model_fn,
-        config=run_config,
-        train_batch_size=FLAGS.train_batch_size,
-        eval_batch_size=FLAGS.eval_batch_size,
-    )
-
-    if FLAGS.do_train:
-        tf.logging.info("***** Running training *****")
-        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
-        train_input_fn = input_fn_builder(
-            input_files=input_files,
-            max_seq_length=FLAGS.max_seq_length,
-            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
-            is_training=True,
-        )
-        estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
-
-    if FLAGS.do_eval:
-        tf.logging.info("***** Running evaluation *****")
-        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
-
-        eval_input_fn = input_fn_builder(
-            input_files=input_files,
-            max_seq_length=FLAGS.max_seq_length,
-            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
-            is_training=False,
-        )
-
-        result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
-
-        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
-        with tf.gfile.GFile(output_eval_file, "w") as writer:
-            tf.logging.info("***** Eval results *****")
-            for key in sorted(result.keys()):
-                tf.logging.info("  %s = %s", key, str(result[key]))
-                writer.write("%s = %s\n" % (key, str(result[key])))
-
-if __name__ == "__main__":
-    flags.mark_flag_as_required("input_file")
-    flags.mark_flag_as_required("bert_config_file")
-    flags.mark_flag_as_required("output_dir")
-    tf.app.run()