SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (64)
  1. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
  2. SinaTools-1.0.1.dist-info/RECORD +73 -0
  3. sinatools/VERSION +1 -1
  4. sinatools/ner/__init__.py +5 -7
  5. sinatools/ner/trainers/BertNestedTrainer.py +203 -203
  6. sinatools/ner/trainers/BertTrainer.py +163 -163
  7. sinatools/ner/trainers/__init__.py +2 -2
  8. SinaTools-0.1.40.dist-info/RECORD +0 -123
  9. sinatools/arabert/arabert/__init__.py +0 -14
  10. sinatools/arabert/arabert/create_classification_data.py +0 -260
  11. sinatools/arabert/arabert/create_pretraining_data.py +0 -534
  12. sinatools/arabert/arabert/extract_features.py +0 -444
  13. sinatools/arabert/arabert/lamb_optimizer.py +0 -158
  14. sinatools/arabert/arabert/modeling.py +0 -1027
  15. sinatools/arabert/arabert/optimization.py +0 -202
  16. sinatools/arabert/arabert/run_classifier.py +0 -1078
  17. sinatools/arabert/arabert/run_pretraining.py +0 -593
  18. sinatools/arabert/arabert/run_squad.py +0 -1440
  19. sinatools/arabert/arabert/tokenization.py +0 -414
  20. sinatools/arabert/araelectra/__init__.py +0 -1
  21. sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
  22. sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
  23. sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
  24. sinatools/arabert/araelectra/configure_finetuning.py +0 -172
  25. sinatools/arabert/araelectra/configure_pretraining.py +0 -143
  26. sinatools/arabert/araelectra/finetune/__init__.py +0 -14
  27. sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
  28. sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
  29. sinatools/arabert/araelectra/finetune/scorer.py +0 -54
  30. sinatools/arabert/araelectra/finetune/task.py +0 -74
  31. sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
  32. sinatools/arabert/araelectra/flops_computation.py +0 -215
  33. sinatools/arabert/araelectra/model/__init__.py +0 -14
  34. sinatools/arabert/araelectra/model/modeling.py +0 -1029
  35. sinatools/arabert/araelectra/model/optimization.py +0 -193
  36. sinatools/arabert/araelectra/model/tokenization.py +0 -355
  37. sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
  38. sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
  39. sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
  40. sinatools/arabert/araelectra/run_finetuning.py +0 -323
  41. sinatools/arabert/araelectra/run_pretraining.py +0 -469
  42. sinatools/arabert/araelectra/util/__init__.py +0 -14
  43. sinatools/arabert/araelectra/util/training_utils.py +0 -112
  44. sinatools/arabert/araelectra/util/utils.py +0 -109
  45. sinatools/arabert/aragpt2/__init__.py +0 -2
  46. sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
  47. sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
  48. sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
  49. sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
  50. sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
  51. sinatools/arabert/aragpt2/grover/__init__.py +0 -0
  52. sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
  53. sinatools/arabert/aragpt2/grover/modeling.py +0 -803
  54. sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
  55. sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
  56. sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
  57. sinatools/arabert/aragpt2/grover/utils.py +0 -234
  58. sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
  59. {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
  60. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
  61. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
  62. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
  63. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
  64. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
sinatools/arabert/arabert/run_squad.py
@@ -1,1440 +0,0 @@
- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Run BERT on SQuAD 1.1 and SQuAD 2.0."""
-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import collections
- import json
- import math
- import os
- import random
- import modeling
- import optimization
- import tokenization
- import six
- import tensorflow as tf
-
- flags = tf.flags
-
- FLAGS = flags.FLAGS
-
- ## Required parameters
- flags.DEFINE_string(
-     "bert_config_file",
-     None,
-     "The config json file corresponding to the pre-trained BERT model. "
-     "This specifies the model architecture.",
- )
-
- flags.DEFINE_string(
-     "vocab_file", None, "The vocabulary file that the BERT model was trained on."
- )
-
- flags.DEFINE_string(
-     "output_dir",
-     None,
-     "The output directory where the model checkpoints will be written.",
- )
-
- ## Other parameters
- flags.DEFINE_string(
-     "train_file", None, "SQuAD json for training. E.g., train-v1.1.json"
- )
-
- flags.DEFINE_string(
-     "predict_file",
-     None,
-     "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
- )
-
- flags.DEFINE_string(
-     "init_checkpoint",
-     None,
-     "Initial checkpoint (usually from a pre-trained BERT model).",
- )
-
- flags.DEFINE_bool(
-     "do_lower_case",
-     True,
-     "Whether to lower case the input text. Should be True for uncased "
-     "models and False for cased models.",
- )
-
- flags.DEFINE_integer(
-     "max_seq_length",
-     384,
-     "The maximum total input sequence length after WordPiece tokenization. "
-     "Sequences longer than this will be truncated, and sequences shorter "
-     "than this will be padded.",
- )
-
- flags.DEFINE_integer(
-     "doc_stride",
-     128,
-     "When splitting up a long document into chunks, how much stride to "
-     "take between chunks.",
- )
-
- flags.DEFINE_integer(
-     "max_query_length",
-     64,
-     "The maximum number of tokens for the question. Questions longer than "
-     "this will be truncated to this length.",
- )
-
- flags.DEFINE_bool("do_train", False, "Whether to run training.")
-
- flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
-
- flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
-
- flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predictions.")
-
- flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
-
- flags.DEFINE_float(
-     "num_train_epochs", 3.0, "Total number of training epochs to perform."
- )
-
- flags.DEFINE_float(
-     "warmup_proportion",
-     0.1,
-     "Proportion of training to perform linear learning rate warmup for. "
-     "E.g., 0.1 = 10% of training.",
- )
-
- flags.DEFINE_integer(
-     "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
- )
-
- flags.DEFINE_integer(
-     "iterations_per_loop", 1000, "How many steps to make in each estimator call."
- )
-
- flags.DEFINE_integer(
-     "n_best_size",
-     20,
-     "The total number of n-best predictions to generate in the "
-     "nbest_predictions.json output file.",
- )
-
- flags.DEFINE_integer(
-     "max_answer_length",
-     30,
-     "The maximum length of an answer that can be generated. This is needed "
-     "because the start and end predictions are not conditioned on one another.",
- )
-
- flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
-
- tf.flags.DEFINE_string(
-     "tpu_name",
-     None,
-     "The Cloud TPU to use for training. This should be either the name "
-     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
-     "url.",
- )
-
- tf.flags.DEFINE_string(
-     "tpu_zone",
-     None,
-     "[Optional] GCE zone where the Cloud TPU is located in. If not "
-     "specified, we will attempt to automatically detect the GCE project from "
-     "metadata.",
- )
-
- tf.flags.DEFINE_string(
-     "gcp_project",
-     None,
-     "[Optional] Project name for the Cloud TPU-enabled project. If not "
-     "specified, we will attempt to automatically detect the GCE project from "
-     "metadata.",
- )
-
- tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
-
- flags.DEFINE_integer(
-     "num_tpu_cores",
-     8,
-     "Only used if `use_tpu` is True. Total number of TPU cores to use.",
- )
-
- flags.DEFINE_bool(
-     "verbose_logging",
-     False,
-     "If true, all of the warnings related to data processing will be printed. "
-     "A number of warnings are expected for a normal SQuAD evaluation.",
- )
-
- flags.DEFINE_bool(
-     "version_2_with_negative",
-     False,
-     "If true, the SQuAD examples contain some that do not have an answer.",
- )
-
- flags.DEFINE_float(
-     "null_score_diff_threshold",
-     0.0,
-     "If null_score - best_non_null is greater than the threshold predict null.",
- )
-
-
- class SquadExample(object):
-     """A single training/test example for simple sequence classification.
-
-     For examples without an answer, the start and end position are -1.
-     """
-
-     def __init__(
-         self,
-         qas_id,
-         question_text,
-         doc_tokens,
-         orig_answer_text=None,
-         start_position=None,
-         end_position=None,
-         is_impossible=False,
-     ):
-         self.qas_id = qas_id
-         self.question_text = question_text
-         self.doc_tokens = doc_tokens
-         self.orig_answer_text = orig_answer_text
-         self.start_position = start_position
-         self.end_position = end_position
-         self.is_impossible = is_impossible
-
-     def __str__(self):
-         return self.__repr__()
-
-     def __repr__(self):
-         s = ""
-         s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
-         s += ", question_text: %s" % (tokenization.printable_text(self.question_text))
-         s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
-         if self.start_position:
-             s += ", start_position: %d" % (self.start_position)
-         if self.start_position:
-             s += ", end_position: %d" % (self.end_position)
-         if self.start_position:
-             s += ", is_impossible: %r" % (self.is_impossible)
-         return s
-
-
- class InputFeatures(object):
-     """A single set of features of data."""
-
-     def __init__(
-         self,
-         unique_id,
-         example_index,
-         doc_span_index,
-         tokens,
-         token_to_orig_map,
-         token_is_max_context,
-         input_ids,
-         input_mask,
-         segment_ids,
-         start_position=None,
-         end_position=None,
-         is_impossible=None,
-     ):
-         self.unique_id = unique_id
-         self.example_index = example_index
-         self.doc_span_index = doc_span_index
-         self.tokens = tokens
-         self.token_to_orig_map = token_to_orig_map
-         self.token_is_max_context = token_is_max_context
-         self.input_ids = input_ids
-         self.input_mask = input_mask
-         self.segment_ids = segment_ids
-         self.start_position = start_position
-         self.end_position = end_position
-         self.is_impossible = is_impossible
-
-
- def read_squad_examples(input_file, is_training):
-     """Read a SQuAD json file into a list of SquadExample."""
-     with tf.gfile.Open(input_file, "r") as reader:
-         input_data = json.load(reader)["data"]
-
-     def is_whitespace(c):
-         if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
-             return True
-         return False
-
-     examples = []
-     for entry in input_data:
-         for paragraph in entry["paragraphs"]:
-             paragraph_text = paragraph["context"]
-             doc_tokens = []
-             char_to_word_offset = []
-             prev_is_whitespace = True
-             for c in paragraph_text:
-                 if is_whitespace(c):
-                     prev_is_whitespace = True
-                 else:
-                     if prev_is_whitespace:
-                         doc_tokens.append(c)
-                     else:
-                         doc_tokens[-1] += c
-                     prev_is_whitespace = False
-                 char_to_word_offset.append(len(doc_tokens) - 1)
-
-             for qa in paragraph["qas"]:
-                 qas_id = qa["id"]
-                 question_text = qa["question"]
-                 start_position = None
-                 end_position = None
-                 orig_answer_text = None
-                 is_impossible = False
-                 if is_training:
-
-                     if FLAGS.version_2_with_negative:
-                         is_impossible = qa["is_impossible"]
-                     if (len(qa["answers"]) != 1) and (not is_impossible):
-                         raise ValueError(
-                             "For training, each question should have exactly 1 answer."
-                         )
-                     if not is_impossible:
-                         answer = qa["answers"][0]
-                         orig_answer_text = answer["text"]
-                         answer_offset = answer["answer_start"]
-                         answer_length = len(orig_answer_text)
-                         start_position = char_to_word_offset[answer_offset]
-                         end_position = char_to_word_offset[
-                             answer_offset + answer_length - 1
-                         ]
-                         # Only add answers where the text can be exactly recovered from the
-                         # document. If this CAN'T happen it's likely due to weird Unicode
-                         # stuff so we will just skip the example.
-                         #
-                         # Note that this means for training mode, every example is NOT
-                         # guaranteed to be preserved.
-                         actual_text = " ".join(
-                             doc_tokens[start_position : (end_position + 1)]
-                         )
-                         cleaned_answer_text = " ".join(
-                             tokenization.whitespace_tokenize(orig_answer_text)
-                         )
-                         if actual_text.find(cleaned_answer_text) == -1:
-                             tf.logging.warning(
-                                 "Could not find answer: '%s' vs. '%s'",
-                                 actual_text,
-                                 cleaned_answer_text,
-                             )
-                             continue
-                     else:
-                         start_position = -1
-                         end_position = -1
-                         orig_answer_text = ""
-
-                 example = SquadExample(
-                     qas_id=qas_id,
-                     question_text=question_text,
-                     doc_tokens=doc_tokens,
-                     orig_answer_text=orig_answer_text,
-                     start_position=start_position,
-                     end_position=end_position,
-                     is_impossible=is_impossible,
-                 )
-                 examples.append(example)
-
-     return examples
-
-
- def convert_examples_to_features(
-     examples,
-     tokenizer,
-     max_seq_length,
-     doc_stride,
-     max_query_length,
-     is_training,
-     output_fn,
- ):
-     """Loads a data file into a list of `InputBatch`s."""
-
-     unique_id = 1000000000
-
-     for (example_index, example) in enumerate(examples):
-         query_tokens = tokenizer.tokenize(example.question_text)
-
-         if len(query_tokens) > max_query_length:
-             query_tokens = query_tokens[0:max_query_length]
-
-         tok_to_orig_index = []
-         orig_to_tok_index = []
-         all_doc_tokens = []
-         for (i, token) in enumerate(example.doc_tokens):
-             orig_to_tok_index.append(len(all_doc_tokens))
-             sub_tokens = tokenizer.tokenize(token)
-             for sub_token in sub_tokens:
-                 tok_to_orig_index.append(i)
-                 all_doc_tokens.append(sub_token)
-
-         tok_start_position = None
-         tok_end_position = None
-         if is_training and example.is_impossible:
-             tok_start_position = -1
-             tok_end_position = -1
-         if is_training and not example.is_impossible:
-             tok_start_position = orig_to_tok_index[example.start_position]
-             if example.end_position < len(example.doc_tokens) - 1:
-                 tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
-             else:
-                 tok_end_position = len(all_doc_tokens) - 1
-             (tok_start_position, tok_end_position) = _improve_answer_span(
-                 all_doc_tokens,
-                 tok_start_position,
-                 tok_end_position,
-                 tokenizer,
-                 example.orig_answer_text,
-             )
-
-         # The -3 accounts for [CLS], [SEP] and [SEP]
-         max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
-
-         # We can have documents that are longer than the maximum sequence length.
-         # To deal with this we do a sliding window approach, where we take chunks
-         # of the up to our max length with a stride of `doc_stride`.
-         _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
-             "DocSpan", ["start", "length"]
-         )
-         doc_spans = []
-         start_offset = 0
-         while start_offset < len(all_doc_tokens):
-             length = len(all_doc_tokens) - start_offset
-             if length > max_tokens_for_doc:
-                 length = max_tokens_for_doc
-             doc_spans.append(_DocSpan(start=start_offset, length=length))
-             if start_offset + length == len(all_doc_tokens):
-                 break
-             start_offset += min(length, doc_stride)
-
-         for (doc_span_index, doc_span) in enumerate(doc_spans):
-             tokens = []
-             token_to_orig_map = {}
-             token_is_max_context = {}
-             segment_ids = []
-             tokens.append("[CLS]")
-             segment_ids.append(0)
-             for token in query_tokens:
-                 tokens.append(token)
-                 segment_ids.append(0)
-             tokens.append("[SEP]")
-             segment_ids.append(0)
-
-             for i in range(doc_span.length):
-                 split_token_index = doc_span.start + i
-                 token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
-
-                 is_max_context = _check_is_max_context(
-                     doc_spans, doc_span_index, split_token_index
-                 )
-                 token_is_max_context[len(tokens)] = is_max_context
-                 tokens.append(all_doc_tokens[split_token_index])
-                 segment_ids.append(1)
-             tokens.append("[SEP]")
-             segment_ids.append(1)
-
-             input_ids = tokenizer.convert_tokens_to_ids(tokens)
-
-             # The mask has 1 for real tokens and 0 for padding tokens. Only real
-             # tokens are attended to.
-             input_mask = [1] * len(input_ids)
-
-             # Zero-pad up to the sequence length.
-             while len(input_ids) < max_seq_length:
-                 input_ids.append(0)
-                 input_mask.append(0)
-                 segment_ids.append(0)
-
-             assert len(input_ids) == max_seq_length
-             assert len(input_mask) == max_seq_length
-             assert len(segment_ids) == max_seq_length
-
-             start_position = None
-             end_position = None
-             if is_training and not example.is_impossible:
-                 # For training, if our document chunk does not contain an annotation
-                 # we throw it out, since there is nothing to predict.
-                 doc_start = doc_span.start
-                 doc_end = doc_span.start + doc_span.length - 1
-                 out_of_span = False
-                 if not (
-                     tok_start_position >= doc_start and tok_end_position <= doc_end
-                 ):
-                     out_of_span = True
-                 if out_of_span:
-                     start_position = 0
-                     end_position = 0
-                 else:
-                     doc_offset = len(query_tokens) + 2
-                     start_position = tok_start_position - doc_start + doc_offset
-                     end_position = tok_end_position - doc_start + doc_offset
-
-             if is_training and example.is_impossible:
-                 start_position = 0
-                 end_position = 0
-
-             if example_index < 20:
-                 tf.logging.info("*** Example ***")
-                 tf.logging.info("unique_id: %s" % (unique_id))
-                 tf.logging.info("example_index: %s" % (example_index))
-                 tf.logging.info("doc_span_index: %s" % (doc_span_index))
-                 tf.logging.info(
-                     "tokens: %s"
-                     % " ".join([tokenization.printable_text(x) for x in tokens])
-                 )
-                 tf.logging.info(
-                     "token_to_orig_map: %s"
-                     % " ".join(
-                         [
-                             "%d:%d" % (x, y)
-                             for (x, y) in six.iteritems(token_to_orig_map)
-                         ]
-                     )
-                 )
-                 tf.logging.info(
-                     "token_is_max_context: %s"
-                     % " ".join(
-                         [
-                             "%d:%s" % (x, y)
-                             for (x, y) in six.iteritems(token_is_max_context)
-                         ]
-                     )
-                 )
-                 tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-                 tf.logging.info(
-                     "input_mask: %s" % " ".join([str(x) for x in input_mask])
-                 )
-                 tf.logging.info(
-                     "segment_ids: %s" % " ".join([str(x) for x in segment_ids])
-                 )
-                 if is_training and example.is_impossible:
-                     tf.logging.info("impossible example")
-                 if is_training and not example.is_impossible:
-                     answer_text = " ".join(tokens[start_position : (end_position + 1)])
-                     tf.logging.info("start_position: %d" % (start_position))
-                     tf.logging.info("end_position: %d" % (end_position))
-                     tf.logging.info(
-                         "answer: %s" % (tokenization.printable_text(answer_text))
-                     )
-
-             feature = InputFeatures(
-                 unique_id=unique_id,
-                 example_index=example_index,
-                 doc_span_index=doc_span_index,
-                 tokens=tokens,
-                 token_to_orig_map=token_to_orig_map,
-                 token_is_max_context=token_is_max_context,
-                 input_ids=input_ids,
-                 input_mask=input_mask,
-                 segment_ids=segment_ids,
-                 start_position=start_position,
-                 end_position=end_position,
-                 is_impossible=example.is_impossible,
-             )
-
-             # Run callback
-             output_fn(feature)
-
-             unique_id += 1
-
-
- def _improve_answer_span(
-     doc_tokens, input_start, input_end, tokenizer, orig_answer_text
- ):
-     """Returns tokenized answer spans that better match the annotated answer."""
-
-     # The SQuAD annotations are character based. We first project them to
-     # whitespace-tokenized words. But then after WordPiece tokenization, we can
-     # often find a "better match". For example:
-     #
-     # Question: What year was John Smith born?
-     # Context: The leader was John Smith (1895-1943).
-     # Answer: 1895
-     #
-     # The original whitespace-tokenized answer will be "(1895-1943).". However
-     # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
-     # the exact answer, 1895.
-     #
-     # However, this is not always possible. Consider the following:
-     #
-     # Question: What country is the top exporter of electronics?
-     # Context: The Japanese electronics industry is the largest in the world.
-     # Answer: Japan
-     #
-     # In this case, the annotator chose "Japan" as a character sub-span of
-     # the word "Japanese". Since our WordPiece tokenizer does not split
-     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
-     # in SQuAD, but does happen.
-     tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
-
-     for new_start in range(input_start, input_end + 1):
-         for new_end in range(input_end, new_start - 1, -1):
-             text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
-             if text_span == tok_answer_text:
-                 return (new_start, new_end)
-
-     return (input_start, input_end)
-
-
- def _check_is_max_context(doc_spans, cur_span_index, position):
-     """Check if this is the 'max context' doc span for the token."""
-
-     # Because of the sliding window approach taken to scoring documents, a single
-     # token can appear in multiple documents. E.g.
-     # Doc: the man went to the store and bought a gallon of milk
-     # Span A: the man went to the
-     # Span B: to the store and bought
-     # Span C: and bought a gallon of
-     # ...
-     #
-     # Now the word 'bought' will have two scores from spans B and C. We only
-     # want to consider the score with "maximum context", which we define as
-     # the *minimum* of its left and right context (the *sum* of left and
-     # right context will always be the same, of course).
-     #
-     # In the example the maximum context for 'bought' would be span C since
-     # it has 1 left context and 3 right context, while span B has 4 left context
-     # and 0 right context.
-     best_score = None
-     best_span_index = None
-     for (span_index, doc_span) in enumerate(doc_spans):
-         end = doc_span.start + doc_span.length - 1
-         if position < doc_span.start:
-             continue
-         if position > end:
-             continue
-         num_left_context = position - doc_span.start
-         num_right_context = end - position
-         score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
-         if best_score is None or score > best_score:
-             best_score = score
-             best_span_index = span_index
-
-     return cur_span_index == best_span_index
-
-
- def create_model(
-     bert_config, is_training, input_ids, input_mask, segment_ids, use_one_hot_embeddings
- ):
-     """Creates a classification model."""
-     model = modeling.BertModel(
-         config=bert_config,
-         is_training=is_training,
-         input_ids=input_ids,
-         input_mask=input_mask,
-         token_type_ids=segment_ids,
-         use_one_hot_embeddings=use_one_hot_embeddings,
-     )
-
-     final_hidden = model.get_sequence_output()
-
-     final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
-     batch_size = final_hidden_shape[0]
-     seq_length = final_hidden_shape[1]
-     hidden_size = final_hidden_shape[2]
-
-     output_weights = tf.get_variable(
-         "cls/squad/output_weights",
-         [2, hidden_size],
-         initializer=tf.truncated_normal_initializer(stddev=0.02),
-     )
-
-     output_bias = tf.get_variable(
-         "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()
-     )
-
-     final_hidden_matrix = tf.reshape(
-         final_hidden, [batch_size * seq_length, hidden_size]
-     )
-     logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
-     logits = tf.nn.bias_add(logits, output_bias)
-
-     logits = tf.reshape(logits, [batch_size, seq_length, 2])
-     logits = tf.transpose(logits, [2, 0, 1])
-
-     unstacked_logits = tf.unstack(logits, axis=0)
-
-     (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
-
-     return (start_logits, end_logits)
-
-
- def model_fn_builder(
-     bert_config,
-     init_checkpoint,
-     learning_rate,
-     num_train_steps,
-     num_warmup_steps,
-     use_tpu,
-     use_one_hot_embeddings,
- ):
-     """Returns `model_fn` closure for TPUEstimator."""
-
-     def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
-         """The `model_fn` for TPUEstimator."""
-
-         tf.logging.info("*** Features ***")
-         for name in sorted(features.keys()):
-             tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
-
-         unique_ids = features["unique_ids"]
-         input_ids = features["input_ids"]
-         input_mask = features["input_mask"]
-         segment_ids = features["segment_ids"]
-
-         is_training = mode == tf.estimator.ModeKeys.TRAIN
-
-         (start_logits, end_logits) = create_model(
-             bert_config=bert_config,
-             is_training=is_training,
-             input_ids=input_ids,
-             input_mask=input_mask,
-             segment_ids=segment_ids,
-             use_one_hot_embeddings=use_one_hot_embeddings,
-         )
-
-         tvars = tf.trainable_variables()
-
-         initialized_variable_names = {}
-         scaffold_fn = None
-         if init_checkpoint:
-             (
-                 assignment_map,
-                 initialized_variable_names,
-             ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
-             if use_tpu:
-
-                 def tpu_scaffold():
-                     tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-                     return tf.train.Scaffold()
-
-                 scaffold_fn = tpu_scaffold
-             else:
-                 tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-
-         tf.logging.info("**** Trainable Variables ****")
-         for var in tvars:
-             init_string = ""
-             if var.name in initialized_variable_names:
-                 init_string = ", *INIT_FROM_CKPT*"
-             tf.logging.info(
-                 " name = %s, shape = %s%s", var.name, var.shape, init_string
-             )
-
-         output_spec = None
-         if mode == tf.estimator.ModeKeys.TRAIN:
-             seq_length = modeling.get_shape_list(input_ids)[1]
-
-             def compute_loss(logits, positions):
-                 one_hot_positions = tf.one_hot(
-                     positions, depth=seq_length, dtype=tf.float32
-                 )
-                 log_probs = tf.nn.log_softmax(logits, axis=-1)
-                 loss = -tf.reduce_mean(
-                     tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
-                 )
-                 return loss
-
-             start_positions = features["start_positions"]
-             end_positions = features["end_positions"]
-
-             start_loss = compute_loss(start_logits, start_positions)
-             end_loss = compute_loss(end_logits, end_positions)
-
-             total_loss = (start_loss + end_loss) / 2.0
-
-             train_op = optimization.create_optimizer(
-                 total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
-             )
-
-             output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                 mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
-             )
-         elif mode == tf.estimator.ModeKeys.PREDICT:
-             predictions = {
-                 "unique_ids": unique_ids,
-                 "start_logits": start_logits,
-                 "end_logits": end_logits,
-             }
-             output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                 mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
-             )
-         else:
-             raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % (mode))
-
-         return output_spec
-
-     return model_fn
-
-
- def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
-     """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-     name_to_features = {
-         "unique_ids": tf.FixedLenFeature([], tf.int64),
-         "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
-         "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
-         "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
-     }
-
-     if is_training:
-         name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
-         name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)
-
-     def _decode_record(record, name_to_features):
-         """Decodes a record to a TensorFlow example."""
-         example = tf.parse_single_example(record, name_to_features)
-
-         # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-         # So cast all int64 to int32.
-         for name in list(example.keys()):
-             t = example[name]
-             if t.dtype == tf.int64:
-                 t = tf.to_int32(t)
-             example[name] = t
-
-         return example
-
-     def input_fn(params):
-         """The actual input function."""
-         batch_size = params["batch_size"]
-
-         # For training, we want a lot of parallel reading and shuffling.
-         # For eval, we want no shuffling and parallel reading doesn't matter.
-         d = tf.data.TFRecordDataset(input_file)
-         if is_training:
-             d = d.repeat()
-             d = d.shuffle(buffer_size=100)
-
-         d = d.apply(
-             tf.contrib.data.map_and_batch(
-                 lambda record: _decode_record(record, name_to_features),
-                 batch_size=batch_size,
-                 drop_remainder=drop_remainder,
-             )
-         )
-
-         return d
-
-     return input_fn
-
-
- RawResult = collections.namedtuple(
-     "RawResult", ["unique_id", "start_logits", "end_logits"]
- )
-
-
- def write_predictions(
-     all_examples,
-     all_features,
-     all_results,
-     n_best_size,
-     max_answer_length,
-     do_lower_case,
-     output_prediction_file,
-     output_nbest_file,
-     output_null_log_odds_file,
- ):
-     """Write final predictions to the json file and log-odds of null if needed."""
-     tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
-     tf.logging.info("Writing nbest to: %s" % (output_nbest_file))
-
-     example_index_to_features = collections.defaultdict(list)
-     for feature in all_features:
-         example_index_to_features[feature.example_index].append(feature)
-
-     unique_id_to_result = {}
-     for result in all_results:
-         unique_id_to_result[result.unique_id] = result
-
-     _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
-         "PrelimPrediction",
-         ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
-     )
-
-     all_predictions = collections.OrderedDict()
-     all_nbest_json = collections.OrderedDict()
-     scores_diff_json = collections.OrderedDict()
-
-     for (example_index, example) in enumerate(all_examples):
-         features = example_index_to_features[example_index]
-
-         prelim_predictions = []
-         # keep track of the minimum score of null start+end of position 0
-         score_null = 1000000  # large and positive
-         min_null_feature_index = 0  # the paragraph slice with min null score
-         null_start_logit = 0  # the start logit at the slice with min null score
-         null_end_logit = 0  # the end logit at the slice with min null score
-         for (feature_index, feature) in enumerate(features):
-             result = unique_id_to_result[feature.unique_id]
-             start_indexes = _get_best_indexes(result.start_logits, n_best_size)
-             end_indexes = _get_best_indexes(result.end_logits, n_best_size)
-             # if we could have irrelevant answers, get the min score of irrelevant
-             if FLAGS.version_2_with_negative:
-                 feature_null_score = result.start_logits[0] + result.end_logits[0]
-                 if feature_null_score < score_null:
-                     score_null = feature_null_score
-                     min_null_feature_index = feature_index
-                     null_start_logit = result.start_logits[0]
-                     null_end_logit = result.end_logits[0]
-             for start_index in start_indexes:
-                 for end_index in end_indexes:
-                     # We could hypothetically create invalid predictions, e.g., predict
-                     # that the start of the span is in the question. We throw out all
-                     # invalid predictions.
-                     if start_index >= len(feature.tokens):
-                         continue
-                     if end_index >= len(feature.tokens):
-                         continue
-                     if start_index not in feature.token_to_orig_map:
-                         continue
-                     if end_index not in feature.token_to_orig_map:
-                         continue
-                     if not feature.token_is_max_context.get(start_index, False):
-                         continue
-                     if end_index < start_index:
-                         continue
-                     length = end_index - start_index + 1
-                     if length > max_answer_length:
-                         continue
-                     prelim_predictions.append(
-                         _PrelimPrediction(
-                             feature_index=feature_index,
-                             start_index=start_index,
-                             end_index=end_index,
-                             start_logit=result.start_logits[start_index],
-                             end_logit=result.end_logits[end_index],
-                         )
-                     )
-
-         if FLAGS.version_2_with_negative:
-             prelim_predictions.append(
-                 _PrelimPrediction(
-                     feature_index=min_null_feature_index,
-                     start_index=0,
-                     end_index=0,
-                     start_logit=null_start_logit,
-                     end_logit=null_end_logit,
-                 )
-             )
-         prelim_predictions = sorted(
-             prelim_predictions,
-             key=lambda x: (x.start_logit + x.end_logit),
-             reverse=True,
-         )
-
-         _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
-             "NbestPrediction", ["text", "start_logit", "end_logit"]
-         )
-
-         seen_predictions = {}
-         nbest = []
-         for pred in prelim_predictions:
-             if len(nbest) >= n_best_size:
-                 break
-             feature = features[pred.feature_index]
-             if pred.start_index > 0:  # this is a non-null prediction
-                 tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
-                 orig_doc_start = feature.token_to_orig_map[pred.start_index]
-                 orig_doc_end = feature.token_to_orig_map[pred.end_index]
-                 orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
-                 tok_text = " ".join(tok_tokens)
-
-                 # De-tokenize WordPieces that have been split off.
-                 tok_text = tok_text.replace(" ##", "")
-                 tok_text = tok_text.replace("##", "")
-
-                 # Clean whitespace
-                 tok_text = tok_text.strip()
-                 tok_text = " ".join(tok_text.split())
-                 orig_text = " ".join(orig_tokens)
-
-                 final_text = get_final_text(tok_text, orig_text, do_lower_case)
-                 if final_text in seen_predictions:
-                     continue
-
-                 seen_predictions[final_text] = True
-             else:
-                 final_text = ""
-                 seen_predictions[final_text] = True
-
-             nbest.append(
-                 _NbestPrediction(
-                     text=final_text,
-                     start_logit=pred.start_logit,
-                     end_logit=pred.end_logit,
-                 )
-             )
-
-         # if we didn't include the empty option in the n-best, include it
-         if FLAGS.version_2_with_negative:
-             if "" not in seen_predictions:
-                 nbest.append(
-                     _NbestPrediction(
-                         text="", start_logit=null_start_logit, end_logit=null_end_logit
-                     )
-                 )
-         # In very rare edge cases we could have no valid predictions. So we
-         # just create a nonce prediction in this case to avoid failure.
-         if not nbest:
-             nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
-
-         assert len(nbest) >= 1
-
-         total_scores = []
-         best_non_null_entry = None
-         for entry in nbest:
-             total_scores.append(entry.start_logit + entry.end_logit)
-             if not best_non_null_entry:
-                 if entry.text:
-                     best_non_null_entry = entry
-
-         probs = _compute_softmax(total_scores)
-
-         nbest_json = []
-         for (i, entry) in enumerate(nbest):
-             output = collections.OrderedDict()
-             output["text"] = entry.text
-             output["probability"] = probs[i]
-             output["start_logit"] = entry.start_logit
-             output["end_logit"] = entry.end_logit
-             nbest_json.append(output)
-
-         assert len(nbest_json) >= 1
-
-         if not FLAGS.version_2_with_negative:
-             all_predictions[example.qas_id] = nbest_json[0]["text"]
-         else:
-             # predict "" iff the null score - the score of best non-null > threshold
-             score_diff = (
-                 score_null
-                 - best_non_null_entry.start_logit
-                 - (best_non_null_entry.end_logit)
-             )
-             scores_diff_json[example.qas_id] = score_diff
-             if score_diff > FLAGS.null_score_diff_threshold:
-                 all_predictions[example.qas_id] = ""
-             else:
-                 all_predictions[example.qas_id] = best_non_null_entry.text
-
-         all_nbest_json[example.qas_id] = nbest_json
-
-     with tf.gfile.GFile(output_prediction_file, "w") as writer:
-         writer.write(json.dumps(all_predictions, indent=4) + "\n")
-
-     with tf.gfile.GFile(output_nbest_file, "w") as writer:
-         writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
-
-     if FLAGS.version_2_with_negative:
-         with tf.gfile.GFile(output_null_log_odds_file, "w") as writer:
-             writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
-
-
- def get_final_text(pred_text, orig_text, do_lower_case):
-     """Project the tokenized prediction back to the original text."""
-
-     # When we created the data, we kept track of the alignment between original
-     # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
-     # now `orig_text` contains the span of our original text corresponding to the
-     # span that we predicted.
-     #
-     # However, `orig_text` may contain extra characters that we don't want in
-     # our prediction.
-     #
-     # For example, let's say:
-     # pred_text = steve smith
-     # orig_text = Steve Smith's
-     #
-     # We don't want to return `orig_text` because it contains the extra "'s".
-     #
-     # We don't want to return `pred_text` because it's already been normalized
-     # (the SQuAD eval script also does punctuation stripping/lower casing but
-     # our tokenizer does additional normalization like stripping accent
-     # characters).
-     #
-     # What we really want to return is "Steve Smith".
-     #
-     # Therefore, we have to apply a semi-complicated alignment heuristic between
-     # `pred_text` and `orig_text` to get a character-to-character alignment. This
-     # can fail in certain cases in which case we just return `orig_text`.
-
-     def _strip_spaces(text):
-         ns_chars = []
-         ns_to_s_map = collections.OrderedDict()
-         for (i, c) in enumerate(text):
-             if c == " ":
-                 continue
-             ns_to_s_map[len(ns_chars)] = i
-             ns_chars.append(c)
-         ns_text = "".join(ns_chars)
-         return (ns_text, ns_to_s_map)
-
-     # We first tokenize `orig_text`, strip whitespace from the result
-     # and `pred_text`, and check if they are the same length. If they are
-     # NOT the same length, the heuristic has failed. If they are the same
-     # length, we assume the characters are one-to-one aligned.
-     tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
-
-     tok_text = " ".join(tokenizer.tokenize(orig_text))
-
-     start_position = tok_text.find(pred_text)
-     if start_position == -1:
-         if FLAGS.verbose_logging:
-             tf.logging.info(
-                 "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)
-             )
-         return orig_text
-     end_position = start_position + len(pred_text) - 1
-
-     (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
-     (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
-
-     if len(orig_ns_text) != len(tok_ns_text):
-         if FLAGS.verbose_logging:
-             tf.logging.info(
-                 "Length not equal after stripping spaces: '%s' vs '%s'",
-                 orig_ns_text,
-                 tok_ns_text,
-             )
-         return orig_text
-
-     # We then project the characters in `pred_text` back to `orig_text` using
-     # the character-to-character alignment.
-     tok_s_to_ns_map = {}
-     for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
-         tok_s_to_ns_map[tok_index] = i
-
-     orig_start_position = None
-     if start_position in tok_s_to_ns_map:
-         ns_start_position = tok_s_to_ns_map[start_position]
-         if ns_start_position in orig_ns_to_s_map:
-             orig_start_position = orig_ns_to_s_map[ns_start_position]
-
-     if orig_start_position is None:
-         if FLAGS.verbose_logging:
-             tf.logging.info("Couldn't map start position")
-         return orig_text
-
-     orig_end_position = None
-     if end_position in tok_s_to_ns_map:
-         ns_end_position = tok_s_to_ns_map[end_position]
-         if ns_end_position in orig_ns_to_s_map:
-             orig_end_position = orig_ns_to_s_map[ns_end_position]
-
-     if orig_end_position is None:
-         if FLAGS.verbose_logging:
-             tf.logging.info("Couldn't map end position")
-         return orig_text
-
-     output_text = orig_text[orig_start_position : (orig_end_position + 1)]
-     return output_text
-
-
- def _get_best_indexes(logits, n_best_size):
-     """Get the n-best logits from a list."""
-     index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
-
-     best_indexes = []
-     for i in range(len(index_and_score)):
-         if i >= n_best_size:
-             break
-         best_indexes.append(index_and_score[i][0])
-     return best_indexes
-
-
- def _compute_softmax(scores):
-     """Compute softmax probability over raw logits."""
-     if not scores:
-         return []
-
-     max_score = None
-     for score in scores:
-         if max_score is None or score > max_score:
-             max_score = score
-
-     exp_scores = []
-     total_sum = 0.0
-     for score in scores:
-         x = math.exp(score - max_score)
-         exp_scores.append(x)
-         total_sum += x
-
-     probs = []
-     for score in exp_scores:
-         probs.append(score / total_sum)
-     return probs
-
-
- class FeatureWriter(object):
-     """Writes InputFeature to TF example file."""
-
-     def __init__(self, filename, is_training):
-         self.filename = filename
-         self.is_training = is_training
-         self.num_features = 0
-         self._writer = tf.python_io.TFRecordWriter(filename)
-
-     def process_feature(self, feature):
-         """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
-         self.num_features += 1
-
-         def create_int_feature(values):
-             feature = tf.train.Feature(
-                 int64_list=tf.train.Int64List(value=list(values))
-             )
-             return feature
-
-         features = collections.OrderedDict()
-         features["unique_ids"] = create_int_feature([feature.unique_id])
-         features["input_ids"] = create_int_feature(feature.input_ids)
-         features["input_mask"] = create_int_feature(feature.input_mask)
-         features["segment_ids"] = create_int_feature(feature.segment_ids)
-
-         if self.is_training:
-             features["start_positions"] = create_int_feature([feature.start_position])
-             features["end_positions"] = create_int_feature([feature.end_position])
-             impossible = 0
-             if feature.is_impossible:
-                 impossible = 1
-             features["is_impossible"] = create_int_feature([impossible])
-
-         tf_example = tf.train.Example(features=tf.train.Features(feature=features))
-         self._writer.write(tf_example.SerializeToString())
-
-     def close(self):
-         self._writer.close()
-
-
- def validate_flags_or_throw(bert_config):
-     """Validate the input FLAGS or throw an exception."""
-     tokenization.validate_case_matches_checkpoint(
-         FLAGS.do_lower_case, FLAGS.init_checkpoint
-     )
-
-     if not FLAGS.do_train and not FLAGS.do_predict:
-         raise ValueError("At least one of `do_train` or `do_predict` must be True.")
-
-     if FLAGS.do_train:
-         if not FLAGS.train_file:
-             raise ValueError(
-                 "If `do_train` is True, then `train_file` must be specified."
-             )
-     if FLAGS.do_predict:
-         if not FLAGS.predict_file:
-             raise ValueError(
-                 "If `do_predict` is True, then `predict_file` must be specified."
-             )
-
-     if FLAGS.max_seq_length > bert_config.max_position_embeddings:
-         raise ValueError(
-             "Cannot use sequence length %d because the BERT model "
-             "was only trained up to sequence length %d"
-             % (FLAGS.max_seq_length, bert_config.max_position_embeddings)
-         )
-
-     if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
-         raise ValueError(
-             "The max_seq_length (%d) must be greater than max_query_length "
-             "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)
-         )
-
-
- def main(_):
-     tf.logging.set_verbosity(tf.logging.INFO)
-     logger = tf.get_logger()
-     logger.propagate = False
-
-     bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-
-     validate_flags_or_throw(bert_config)
-
-     tf.gfile.MakeDirs(FLAGS.output_dir)
-
-     tokenizer = tokenization.FullTokenizer(
-         vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
-     )
-
-     tpu_cluster_resolver = None
-     if FLAGS.use_tpu and FLAGS.tpu_name:
-         tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
-             FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
-         )
-
-     is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
-     run_config = tf.contrib.tpu.RunConfig(
-         cluster=tpu_cluster_resolver,
-         master=FLAGS.master,
-         model_dir=FLAGS.output_dir,
-         save_checkpoints_steps=FLAGS.save_checkpoints_steps,
-         tpu_config=tf.contrib.tpu.TPUConfig(
-             iterations_per_loop=FLAGS.iterations_per_loop,
-             num_shards=FLAGS.num_tpu_cores,
-             per_host_input_for_training=is_per_host,
-         ),
-     )
-
-     train_examples = None
-     num_train_steps = None
-     num_warmup_steps = None
-     if FLAGS.do_train:
-         train_examples = read_squad_examples(
-             input_file=FLAGS.train_file, is_training=True
-         )
-         num_train_steps = int(
-             len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs
-         )
-         num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
-
-         # Pre-shuffle the input to avoid having to make a very large shuffle
-         # buffer in the `input_fn`.
-         rng = random.Random(12345)
-         rng.shuffle(train_examples)
-
-     model_fn = model_fn_builder(
-         bert_config=bert_config,
-         init_checkpoint=FLAGS.init_checkpoint,
-         learning_rate=FLAGS.learning_rate,
-         num_train_steps=num_train_steps,
-         num_warmup_steps=num_warmup_steps,
-         use_tpu=FLAGS.use_tpu,
-         use_one_hot_embeddings=FLAGS.use_tpu,
-     )
-
-     # If TPU is not available, this will fall back to normal Estimator on CPU
-     # or GPU.
-     estimator = tf.contrib.tpu.TPUEstimator(
-         use_tpu=FLAGS.use_tpu,
-         model_fn=model_fn,
-         config=run_config,
-         train_batch_size=FLAGS.train_batch_size,
-         predict_batch_size=FLAGS.predict_batch_size,
-     )
-
-     if FLAGS.do_train:
-         # We write to a temporary file to avoid storing very large constant tensors
-         # in memory.
-         train_writer = FeatureWriter(
-             filename=os.path.join(FLAGS.output_dir, "train.tf_record"), is_training=True
-         )
-         convert_examples_to_features(
-             examples=train_examples,
-             tokenizer=tokenizer,
-             max_seq_length=FLAGS.max_seq_length,
-             doc_stride=FLAGS.doc_stride,
-             max_query_length=FLAGS.max_query_length,
-             is_training=True,
-             output_fn=train_writer.process_feature,
-         )
-         train_writer.close()
-
-         tf.logging.info("***** Running training *****")
-         tf.logging.info(" Num orig examples = %d", len(train_examples))
-         tf.logging.info(" Num split examples = %d", train_writer.num_features)
-         tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
-         tf.logging.info(" Num steps = %d", num_train_steps)
-         del train_examples
-
-         train_input_fn = input_fn_builder(
-             input_file=train_writer.filename,
-             seq_length=FLAGS.max_seq_length,
-             is_training=True,
-             drop_remainder=True,
-         )
-         estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
-
-     if FLAGS.do_predict:
-         eval_examples = read_squad_examples(
-             input_file=FLAGS.predict_file, is_training=False
-         )
-
-         eval_writer = FeatureWriter(
-             filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), is_training=False
-         )
-         eval_features = []
-
-         def append_feature(feature):
-             eval_features.append(feature)
-             eval_writer.process_feature(feature)
-
-         convert_examples_to_features(
-             examples=eval_examples,
-             tokenizer=tokenizer,
-             max_seq_length=FLAGS.max_seq_length,
-             doc_stride=FLAGS.doc_stride,
-             max_query_length=FLAGS.max_query_length,
-             is_training=False,
-             output_fn=append_feature,
-         )
-         eval_writer.close()
-
-         tf.logging.info("***** Running predictions *****")
-         tf.logging.info(" Num orig examples = %d", len(eval_examples))
-         tf.logging.info(" Num split examples = %d", len(eval_features))
-         tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
-
-         all_results = []
-
-         predict_input_fn = input_fn_builder(
-             input_file=eval_writer.filename,
-             seq_length=FLAGS.max_seq_length,
-             is_training=False,
-             drop_remainder=False,
-         )
-
-         # If running eval on the TPU, you will need to specify the number of
-         # steps.
-         all_results = []
-         for result in estimator.predict(predict_input_fn, yield_single_examples=True):
-             if len(all_results) % 1000 == 0:
-                 tf.logging.info("Processing example: %d" % (len(all_results)))
-             unique_id = int(result["unique_ids"])
-             start_logits = [float(x) for x in result["start_logits"].flat]
-             end_logits = [float(x) for x in result["end_logits"].flat]
-             all_results.append(
-                 RawResult(
-                     unique_id=unique_id,
-                     start_logits=start_logits,
-                     end_logits=end_logits,
-                 )
-             )
-
-         output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
-         output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
-         output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")
-
-         write_predictions(
-             eval_examples,
-             eval_features,
-             all_results,
-             FLAGS.n_best_size,
-             FLAGS.max_answer_length,
-             FLAGS.do_lower_case,
-             output_prediction_file,
-             output_nbest_file,
-             output_null_log_odds_file,
-         )
-
-
- if __name__ == "__main__":
-     flags.mark_flag_as_required("vocab_file")
-     flags.mark_flag_as_required("bert_config_file")
-     flags.mark_flag_as_required("output_dir")
-     tf.app.run()