SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
  2. SinaTools-1.0.1.dist-info/RECORD +73 -0
  3. sinatools/VERSION +1 -1
  4. sinatools/ner/__init__.py +5 -7
  5. sinatools/ner/trainers/BertNestedTrainer.py +203 -203
  6. sinatools/ner/trainers/BertTrainer.py +163 -163
  7. sinatools/ner/trainers/__init__.py +2 -2
  8. SinaTools-0.1.40.dist-info/RECORD +0 -123
  9. sinatools/arabert/arabert/__init__.py +0 -14
  10. sinatools/arabert/arabert/create_classification_data.py +0 -260
  11. sinatools/arabert/arabert/create_pretraining_data.py +0 -534
  12. sinatools/arabert/arabert/extract_features.py +0 -444
  13. sinatools/arabert/arabert/lamb_optimizer.py +0 -158
  14. sinatools/arabert/arabert/modeling.py +0 -1027
  15. sinatools/arabert/arabert/optimization.py +0 -202
  16. sinatools/arabert/arabert/run_classifier.py +0 -1078
  17. sinatools/arabert/arabert/run_pretraining.py +0 -593
  18. sinatools/arabert/arabert/run_squad.py +0 -1440
  19. sinatools/arabert/arabert/tokenization.py +0 -414
  20. sinatools/arabert/araelectra/__init__.py +0 -1
  21. sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
  22. sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
  23. sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
  24. sinatools/arabert/araelectra/configure_finetuning.py +0 -172
  25. sinatools/arabert/araelectra/configure_pretraining.py +0 -143
  26. sinatools/arabert/araelectra/finetune/__init__.py +0 -14
  27. sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
  28. sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
  29. sinatools/arabert/araelectra/finetune/scorer.py +0 -54
  30. sinatools/arabert/araelectra/finetune/task.py +0 -74
  31. sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
  32. sinatools/arabert/araelectra/flops_computation.py +0 -215
  33. sinatools/arabert/araelectra/model/__init__.py +0 -14
  34. sinatools/arabert/araelectra/model/modeling.py +0 -1029
  35. sinatools/arabert/araelectra/model/optimization.py +0 -193
  36. sinatools/arabert/araelectra/model/tokenization.py +0 -355
  37. sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
  38. sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
  39. sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
  40. sinatools/arabert/araelectra/run_finetuning.py +0 -323
  41. sinatools/arabert/araelectra/run_pretraining.py +0 -469
  42. sinatools/arabert/araelectra/util/__init__.py +0 -14
  43. sinatools/arabert/araelectra/util/training_utils.py +0 -112
  44. sinatools/arabert/araelectra/util/utils.py +0 -109
  45. sinatools/arabert/aragpt2/__init__.py +0 -2
  46. sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
  47. sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
  48. sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
  49. sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
  50. sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
  51. sinatools/arabert/aragpt2/grover/__init__.py +0 -0
  52. sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
  53. sinatools/arabert/aragpt2/grover/modeling.py +0 -803
  54. sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
  55. sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
  56. sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
  57. sinatools/arabert/aragpt2/grover/utils.py +0 -234
  58. sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
  59. {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
  60. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
  61. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
  62. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
  63. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
  64. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
sinatools/arabert/arabert/run_pretraining.py
@@ -1,593 +0,0 @@
- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Run masked LM/next sentence masked_lm pre-training for BERT."""
-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import os
- import modeling
- import optimization
- import tensorflow as tf
-
- flags = tf.flags
-
- FLAGS = flags.FLAGS
-
- ## Required parameters
- flags.DEFINE_string(
-     "bert_config_file",
-     None,
-     "The config json file corresponding to the pre-trained BERT model. "
-     "This specifies the model architecture.",
- )
-
- flags.DEFINE_string(
-     "input_file", None, "Input TF example files (can be a glob or comma separated)."
- )
-
- flags.DEFINE_string(
-     "output_dir",
-     None,
-     "The output directory where the model checkpoints will be written.",
- )
-
- ## Other parameters
- flags.DEFINE_string(
-     "init_checkpoint",
-     None,
-     "Initial checkpoint (usually from a pre-trained BERT model).",
- )
-
- flags.DEFINE_integer(
-     "max_seq_length",
-     128,
-     "The maximum total input sequence length after WordPiece tokenization. "
-     "Sequences longer than this will be truncated, and sequences shorter "
-     "than this will be padded. Must match data generation.",
- )
-
- flags.DEFINE_integer(
-     "max_predictions_per_seq",
-     20,
-     "Maximum number of masked LM predictions per sequence. "
-     "Must match data generation.",
- )
-
- flags.DEFINE_bool("do_train", False, "Whether to run training.")
-
- flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
-
- flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
-
- flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
-
- flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.")
-
- flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"],
-                   "The optimizer for training.")
-
- flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
-
- flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
-
- flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
-
- flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.")
-
- flags.DEFINE_integer(
-     "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
- )
-
- flags.DEFINE_integer(
-     "iterations_per_loop", 1000, "How many steps to make in each estimator call."
- )
-
- flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
-
- flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
-
- tf.flags.DEFINE_string(
-     "tpu_name",
-     None,
-     "The Cloud TPU to use for training. This should be either the name "
-     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
-     "url.",
- )
-
- tf.flags.DEFINE_string(
-     "tpu_zone",
-     None,
-     "[Optional] GCE zone where the Cloud TPU is located in. If not "
-     "specified, we will attempt to automatically detect the GCE project from "
-     "metadata.",
- )
-
- tf.flags.DEFINE_string(
-     "gcp_project",
-     None,
-     "[Optional] Project name for the Cloud TPU-enabled project. If not "
-     "specified, we will attempt to automatically detect the GCE project from "
-     "metadata.",
- )
-
- tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
-
- flags.DEFINE_integer(
-     "num_tpu_cores",
-     8,
-     "Only used if `use_tpu` is True. Total number of TPU cores to use.",
- )
-
- flags.DEFINE_integer("keep_checkpoint_max", 10,
-                      "How many checkpoints to keep.")
-
-
- def model_fn_builder(
-     bert_config,
-     init_checkpoint,
-     learning_rate,
-     num_train_steps,
-     num_warmup_steps,
-     use_tpu,
-     use_one_hot_embeddings,
-     optimizer,
-     poly_power,
-     start_warmup_step,
- ):
-     """Returns `model_fn` closure for TPUEstimator."""
-
-     def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
-         """The `model_fn` for TPUEstimator."""
-
-         tf.logging.info("*** Features ***")
-         for name in sorted(features.keys()):
-             tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
-
-         input_ids = features["input_ids"]
-         input_mask = features["input_mask"]
-         segment_ids = features["segment_ids"]
-         masked_lm_positions = features["masked_lm_positions"]
-         masked_lm_ids = features["masked_lm_ids"]
-         masked_lm_weights = features["masked_lm_weights"]
-         next_sentence_labels = features["next_sentence_labels"]
-
-         is_training = mode == tf.estimator.ModeKeys.TRAIN
-
-         model = modeling.BertModel(
-             config=bert_config,
-             is_training=is_training,
-             input_ids=input_ids,
-             input_mask=input_mask,
-             token_type_ids=segment_ids,
-             use_one_hot_embeddings=use_one_hot_embeddings,
-         )
-
-         (
-             masked_lm_loss,
-             masked_lm_example_loss,
-             masked_lm_log_probs,
-         ) = get_masked_lm_output(
-             bert_config,
-             model.get_sequence_output(),
-             model.get_embedding_table(),
-             masked_lm_positions,
-             masked_lm_ids,
-             masked_lm_weights,
-         )
-
-         (
-             next_sentence_loss,
-             next_sentence_example_loss,
-             next_sentence_log_probs,
-         ) = get_next_sentence_output(
-             bert_config, model.get_pooled_output(), next_sentence_labels
-         )
-
-         total_loss = masked_lm_loss + next_sentence_loss
-
-         tvars = tf.trainable_variables()
-
-         initialized_variable_names = {}
-         scaffold_fn = None
-         if init_checkpoint:
-             (
-                 assignment_map,
-                 initialized_variable_names,
-             ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
-             if use_tpu:
-
-                 def tpu_scaffold():
-                     tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-                     return tf.train.Scaffold()
-
-                 scaffold_fn = tpu_scaffold
-             else:
-                 tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-
-         tf.logging.info("**** Trainable Variables ****")
-         for var in tvars:
-             init_string = ""
-             if var.name in initialized_variable_names:
-                 init_string = ", *INIT_FROM_CKPT*"
-             tf.logging.info(
-                 " name = %s, shape = %s%s", var.name, var.shape, init_string
-             )
-
-         output_spec = None
-         if mode == tf.estimator.ModeKeys.TRAIN:
-             train_op = optimization.create_optimizer(
-                 total_loss,
-                 learning_rate,
-                 num_train_steps,
-                 num_warmup_steps,
-                 use_tpu,
-                 optimizer,
-                 poly_power,
-                 start_warmup_step,
-             )
-
-             output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                 mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
-             )
-         elif mode == tf.estimator.ModeKeys.EVAL:
-
-             def metric_fn(
-                 masked_lm_example_loss,
-                 masked_lm_log_probs,
-                 masked_lm_ids,
-                 masked_lm_weights,
-                 next_sentence_example_loss,
-                 next_sentence_log_probs,
-                 next_sentence_labels,
-             ):
-                 """Computes the loss and accuracy of the model."""
-                 masked_lm_log_probs = tf.reshape(
-                     masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]]
-                 )
-                 masked_lm_predictions = tf.argmax(
-                     masked_lm_log_probs, axis=-1, output_type=tf.int32
-                 )
-                 masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
-                 masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
-                 masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
-                 masked_lm_accuracy = tf.metrics.accuracy(
-                     labels=masked_lm_ids,
-                     predictions=masked_lm_predictions,
-                     weights=masked_lm_weights,
-                 )
-                 masked_lm_mean_loss = tf.metrics.mean(
-                     values=masked_lm_example_loss, weights=masked_lm_weights
-                 )
-
-                 next_sentence_log_probs = tf.reshape(
-                     next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]
-                 )
-                 next_sentence_predictions = tf.argmax(
-                     next_sentence_log_probs, axis=-1, output_type=tf.int32
-                 )
-                 next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
-                 next_sentence_accuracy = tf.metrics.accuracy(
-                     labels=next_sentence_labels, predictions=next_sentence_predictions
-                 )
-                 next_sentence_mean_loss = tf.metrics.mean(
-                     values=next_sentence_example_loss
-                 )
-
-                 return {
-                     "masked_lm_accuracy": masked_lm_accuracy,
-                     "masked_lm_loss": masked_lm_mean_loss,
-                     "next_sentence_accuracy": next_sentence_accuracy,
-                     "next_sentence_loss": next_sentence_mean_loss,
-                 }
-
-             eval_metrics = (
-                 metric_fn,
-                 [
-                     masked_lm_example_loss,
-                     masked_lm_log_probs,
-                     masked_lm_ids,
-                     masked_lm_weights,
-                     next_sentence_example_loss,
-                     next_sentence_log_probs,
-                     next_sentence_labels,
-                 ],
-             )
-             output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                 mode=mode,
-                 loss=total_loss,
-                 eval_metrics=eval_metrics,
-                 scaffold_fn=scaffold_fn,
-             )
-         else:
-             raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
-
-         return output_spec
-
-     return model_fn
-
-
- def get_masked_lm_output(
-     bert_config, input_tensor, output_weights, positions, label_ids, label_weights
- ):
-     """Get loss and log probs for the masked LM."""
-     input_tensor = gather_indexes(input_tensor, positions)
-
-     with tf.variable_scope("cls/predictions"):
-         # We apply one more non-linear transformation before the output layer.
-         # This matrix is not used after pre-training.
-         with tf.variable_scope("transform"):
-             input_tensor = tf.layers.dense(
-                 input_tensor,
-                 units=bert_config.hidden_size,
-                 activation=modeling.get_activation(bert_config.hidden_act),
-                 kernel_initializer=modeling.create_initializer(
-                     bert_config.initializer_range
-                 ),
-             )
-             input_tensor = modeling.layer_norm(input_tensor)
-
-         # The output weights are the same as the input embeddings, but there is
-         # an output-only bias for each token.
-         output_bias = tf.get_variable(
-             "output_bias",
-             shape=[bert_config.vocab_size],
-             initializer=tf.zeros_initializer(),
-         )
-         logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
-         logits = tf.nn.bias_add(logits, output_bias)
-         log_probs = tf.nn.log_softmax(logits, axis=-1)
-
-         label_ids = tf.reshape(label_ids, [-1])
-         label_weights = tf.reshape(label_weights, [-1])
-
-         one_hot_labels = tf.one_hot(
-             label_ids, depth=bert_config.vocab_size, dtype=tf.float32
-         )
-
-         # The `positions` tensor might be zero-padded (if the sequence is too
-         # short to have the maximum number of predictions). The `label_weights`
-         # tensor has a value of 1.0 for every real prediction and 0.0 for the
-         # padding predictions.
-         per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
-         numerator = tf.reduce_sum(label_weights * per_example_loss)
-         denominator = tf.reduce_sum(label_weights) + 1e-5
-         loss = numerator / denominator
-
-     return (loss, per_example_loss, log_probs)
-
-
- def get_next_sentence_output(bert_config, input_tensor, labels):
-     """Get loss and log probs for the next sentence prediction."""
-
-     # Simple binary classification. Note that 0 is "next sentence" and 1 is
-     # "random sentence". This weight matrix is not used after pre-training.
-     with tf.variable_scope("cls/seq_relationship"):
-         output_weights = tf.get_variable(
-             "output_weights",
-             shape=[2, bert_config.hidden_size],
-             initializer=modeling.create_initializer(bert_config.initializer_range),
-         )
-         output_bias = tf.get_variable(
-             "output_bias", shape=[2], initializer=tf.zeros_initializer()
-         )
-
-         logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
-         logits = tf.nn.bias_add(logits, output_bias)
-         log_probs = tf.nn.log_softmax(logits, axis=-1)
-         labels = tf.reshape(labels, [-1])
-         one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
-         per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
-         loss = tf.reduce_mean(per_example_loss)
-         return (loss, per_example_loss, log_probs)
-
-
- def gather_indexes(sequence_tensor, positions):
-     """Gathers the vectors at the specific positions over a minibatch."""
-     sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
-     batch_size = sequence_shape[0]
-     seq_length = sequence_shape[1]
-     width = sequence_shape[2]
-
-     flat_offsets = tf.reshape(
-         tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]
-     )
-     flat_positions = tf.reshape(positions + flat_offsets, [-1])
-     flat_sequence_tensor = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
-     output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
-     return output_tensor
-
-
- def input_fn_builder(
-     input_files, max_seq_length, max_predictions_per_seq, is_training, num_cpu_threads=4
- ):
-     """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-     def input_fn(params):
-         """The actual input function."""
-         batch_size = params["batch_size"]
-
-         name_to_features = {
-             "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
-             "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
-             "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
-             "masked_lm_positions": tf.FixedLenFeature(
-                 [max_predictions_per_seq], tf.int64
-             ),
-             "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
-             "masked_lm_weights": tf.FixedLenFeature(
-                 [max_predictions_per_seq], tf.float32
-             ),
-             "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
-         }
-
-         # For training, we want a lot of parallel reading and shuffling.
-         # For eval, we want no shuffling and parallel reading doesn't matter.
-         if is_training:
-             d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
-             d = d.repeat()
-             d = d.shuffle(buffer_size=len(input_files))
-
-             # `cycle_length` is the number of parallel files that get read.
-             cycle_length = min(num_cpu_threads, len(input_files))
-
-             # `sloppy` mode means that the interleaving is not exact. This adds
-             # even more randomness to the training pipeline.
-             d = d.apply(
-                 tf.contrib.data.parallel_interleave(
-                     tf.data.TFRecordDataset,
-                     sloppy=is_training,
-                     cycle_length=cycle_length,
-                 )
-             )
-             d = d.shuffle(buffer_size=100)
-         else:
-             d = tf.data.TFRecordDataset(input_files)
-             # Since we evaluate for a fixed number of steps we don't want to encounter
-             # out-of-range exceptions.
-             d = d.repeat()
-
-         # We must `drop_remainder` on training because the TPU requires fixed
-         # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
-         # and we *don't* want to drop the remainder, otherwise we wont cover
-         # every sample.
-         d = d.apply(
-             tf.contrib.data.map_and_batch(
-                 lambda record: _decode_record(record, name_to_features),
-                 batch_size=batch_size,
-                 num_parallel_batches=num_cpu_threads,
-                 drop_remainder=True,
-             )
-         )
-         return d
-
-     return input_fn
-
-
- def _decode_record(record, name_to_features):
-     """Decodes a record to a TensorFlow example."""
-     example = tf.parse_single_example(record, name_to_features)
-
-     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-     # So cast all int64 to int32.
-     for name in list(example.keys()):
-         t = example[name]
-         if t.dtype == tf.int64:
-             t = tf.to_int32(t)
-         example[name] = t
-
-     return example
-
-
- def main(_):
-     tf.logging.set_verbosity(tf.logging.INFO)
-     logger = tf.get_logger()
-     logger.propagate = False
-
-     if not FLAGS.do_train and not FLAGS.do_eval:
-         raise ValueError("At least one of `do_train` or `do_eval` must be True.")
-
-     bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-
-     tf.gfile.MakeDirs(FLAGS.output_dir)
-
-     input_files = []
-     for input_pattern in FLAGS.input_file.split(","):
-         input_files.extend(tf.gfile.Glob(input_pattern))
-
-     # tf.logging.info("*** Input Files ***")
-     # for input_file in input_files:
-     #     tf.logging.info(" %s" % input_file)
-
-     tpu_cluster_resolver = None
-     if FLAGS.use_tpu and FLAGS.tpu_name:
-         tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
-             FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
-         )
-
-     is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
-     run_config = tf.contrib.tpu.RunConfig(
-         cluster=tpu_cluster_resolver,
-         master=FLAGS.master,
-         model_dir=FLAGS.output_dir,
-         save_checkpoints_steps=FLAGS.save_checkpoints_steps,
-         keep_checkpoint_max=FLAGS.keep_checkpoint_max,
-         tpu_config=tf.contrib.tpu.TPUConfig(
-             iterations_per_loop=FLAGS.iterations_per_loop,
-             num_shards=FLAGS.num_tpu_cores,
-             per_host_input_for_training=is_per_host,
-         ),
-     )
-
-     model_fn = model_fn_builder(
-         bert_config=bert_config,
-         init_checkpoint=FLAGS.init_checkpoint,
-         learning_rate=FLAGS.learning_rate,
-         num_train_steps=FLAGS.num_train_steps,
-         num_warmup_steps=FLAGS.num_warmup_steps,
-         use_tpu=FLAGS.use_tpu,
-         use_one_hot_embeddings=FLAGS.use_tpu,
-         optimizer=FLAGS.optimizer,
-         poly_power=FLAGS.poly_power,
-         start_warmup_step=FLAGS.start_warmup_step
-     )
-
-     # If TPU is not available, this will fall back to normal Estimator on CPU
-     # or GPU.
-     estimator = tf.contrib.tpu.TPUEstimator(
-         use_tpu=FLAGS.use_tpu,
-         model_fn=model_fn,
-         config=run_config,
-         train_batch_size=FLAGS.train_batch_size,
-         eval_batch_size=FLAGS.eval_batch_size,
-     )
-
-     if FLAGS.do_train:
-         tf.logging.info("***** Running training *****")
-         tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
-         train_input_fn = input_fn_builder(
-             input_files=input_files,
-             max_seq_length=FLAGS.max_seq_length,
-             max_predictions_per_seq=FLAGS.max_predictions_per_seq,
-             is_training=True,
-         )
-         estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
-
-     if FLAGS.do_eval:
-         tf.logging.info("***** Running evaluation *****")
-         tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
-
-         eval_input_fn = input_fn_builder(
-             input_files=input_files,
-             max_seq_length=FLAGS.max_seq_length,
-             max_predictions_per_seq=FLAGS.max_predictions_per_seq,
-             is_training=False,
-         )
-
-         result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
-
-         output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
-         with tf.gfile.GFile(output_eval_file, "w") as writer:
-             tf.logging.info("***** Eval results *****")
-             for key in sorted(result.keys()):
-                 tf.logging.info(" %s = %s", key, str(result[key]))
-                 writer.write("%s = %s\n" % (key, str(result[key])))
-
- if __name__ == "__main__":
-     flags.mark_flag_as_required("input_file")
-     flags.mark_flag_as_required("bert_config_file")
-     flags.mark_flag_as_required("output_dir")
-     tf.app.run()