SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
- SinaTools-1.0.1.dist-info/RECORD +73 -0
- sinatools/VERSION +1 -1
- sinatools/ner/__init__.py +5 -7
- sinatools/ner/trainers/BertNestedTrainer.py +203 -203
- sinatools/ner/trainers/BertTrainer.py +163 -163
- sinatools/ner/trainers/__init__.py +2 -2
- SinaTools-0.1.40.dist-info/RECORD +0 -123
- sinatools/arabert/arabert/__init__.py +0 -14
- sinatools/arabert/arabert/create_classification_data.py +0 -260
- sinatools/arabert/arabert/create_pretraining_data.py +0 -534
- sinatools/arabert/arabert/extract_features.py +0 -444
- sinatools/arabert/arabert/lamb_optimizer.py +0 -158
- sinatools/arabert/arabert/modeling.py +0 -1027
- sinatools/arabert/arabert/optimization.py +0 -202
- sinatools/arabert/arabert/run_classifier.py +0 -1078
- sinatools/arabert/arabert/run_pretraining.py +0 -593
- sinatools/arabert/arabert/run_squad.py +0 -1440
- sinatools/arabert/arabert/tokenization.py +0 -414
- sinatools/arabert/araelectra/__init__.py +0 -1
- sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
- sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
- sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
- sinatools/arabert/araelectra/configure_finetuning.py +0 -172
- sinatools/arabert/araelectra/configure_pretraining.py +0 -143
- sinatools/arabert/araelectra/finetune/__init__.py +0 -14
- sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
- sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
- sinatools/arabert/araelectra/finetune/scorer.py +0 -54
- sinatools/arabert/araelectra/finetune/task.py +0 -74
- sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
- sinatools/arabert/araelectra/flops_computation.py +0 -215
- sinatools/arabert/araelectra/model/__init__.py +0 -14
- sinatools/arabert/araelectra/model/modeling.py +0 -1029
- sinatools/arabert/araelectra/model/optimization.py +0 -193
- sinatools/arabert/araelectra/model/tokenization.py +0 -355
- sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
- sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
- sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
- sinatools/arabert/araelectra/run_finetuning.py +0 -323
- sinatools/arabert/araelectra/run_pretraining.py +0 -469
- sinatools/arabert/araelectra/util/__init__.py +0 -14
- sinatools/arabert/araelectra/util/training_utils.py +0 -112
- sinatools/arabert/araelectra/util/utils.py +0 -109
- sinatools/arabert/aragpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
- sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
- sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
- sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
- sinatools/arabert/aragpt2/grover/__init__.py +0 -0
- sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
- sinatools/arabert/aragpt2/grover/modeling.py +0 -803
- sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
- sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
- sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
- sinatools/arabert/aragpt2/grover/utils.py +0 -234
- sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
- {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
@@ -1,1440 +0,0 @@
|
|
1
|
-
# coding=utf-8
|
2
|
-
# Copyright 2018 The Google AI Language Team Authors.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
"""Run BERT on SQuAD 1.1 and SQuAD 2.0."""
|
16
|
-
|
17
|
-
from __future__ import absolute_import
|
18
|
-
from __future__ import division
|
19
|
-
from __future__ import print_function
|
20
|
-
|
21
|
-
import collections
|
22
|
-
import json
|
23
|
-
import math
|
24
|
-
import os
|
25
|
-
import random
|
26
|
-
import modeling
|
27
|
-
import optimization
|
28
|
-
import tokenization
|
29
|
-
import six
|
30
|
-
import tensorflow as tf
|
31
|
-
|
32
|
-
# Command-line flags for the SQuAD runner (TF1 flags API). `flags` is used
# consistently below; `tf.flags` and `flags` are the same module object.
flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "bert_config_file",
    None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.",
)

flags.DEFINE_string(
    "vocab_file", None, "The vocabulary file that the BERT model was trained on."
)

flags.DEFINE_string(
    "output_dir",
    None,
    "The output directory where the model checkpoints will be written.",
)

## Other parameters
flags.DEFINE_string(
    "train_file", None, "SQuAD json for training. E.g., train-v1.1.json"
)

flags.DEFINE_string(
    "predict_file",
    None,
    "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
)

flags.DEFINE_string(
    "init_checkpoint",
    None,
    "Initial checkpoint (usually from a pre-trained BERT model).",
)

flags.DEFINE_bool(
    "do_lower_case",
    True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.",
)

flags.DEFINE_integer(
    "max_seq_length",
    384,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.",
)

flags.DEFINE_integer(
    "doc_stride",
    128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.",
)

flags.DEFINE_integer(
    "max_query_length",
    64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.",
)

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float(
    "num_train_epochs", 3.0, "Total number of training epochs to perform."
)

flags.DEFINE_float(
    "warmup_proportion",
    0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.",
)

flags.DEFINE_integer(
    "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
)

flags.DEFINE_integer(
    "iterations_per_loop", 1000, "How many steps to make in each estimator call."
)

flags.DEFINE_integer(
    "n_best_size",
    20,
    "The total number of n-best predictions to generate in the "
    "nbest_predictions.json output file.",
)

flags.DEFINE_integer(
    "max_answer_length",
    30,
    "The maximum length of an answer that can be generated. This is needed "
    "because the start and end predictions are not conditioned on one another.",
)

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string(
    "tpu_name",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.",
)

# The help text previously said the *project* would be auto-detected (copied
# from gcp_project below); this flag is the zone.
flags.DEFINE_string(
    "tpu_zone",
    None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE zone from "
    "metadata.",
)

flags.DEFINE_string(
    "gcp_project",
    None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.",
)

flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores",
    8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.",
)

flags.DEFINE_bool(
    "verbose_logging",
    False,
    "If true, all of the warnings related to data processing will be printed. "
    "A number of warnings are expected for a normal SQuAD evaluation.",
)

flags.DEFINE_bool(
    "version_2_with_negative",
    False,
    "If true, the SQuAD examples contain some that do not have an answer.",
)

flags.DEFINE_float(
    "null_score_diff_threshold",
    0.0,
    "If null_score - best_non_null is greater than the threshold predict null.",
)
|
195
|
-
|
196
|
-
|
197
|
-
class SquadExample(object):
    """A single SQuAD question paired with its whitespace-tokenized context.

    ``start_position`` and ``end_position`` are word indices into
    ``doc_tokens``. ``read_squad_examples`` sets them to -1 for training
    questions without an answer, and leaves them as ``None`` when answers
    are not read (i.e. outside training).
    """

    def __init__(
        self,
        qas_id,
        question_text,
        doc_tokens,
        orig_answer_text=None,
        start_position=None,
        end_position=None,
        is_impossible=False,
    ):
        self.qas_id = qas_id  # question identifier
        self.question_text = question_text
        self.doc_tokens = doc_tokens  # context split into whitespace-separated words
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position  # first answer word index, or None/-1
        self.end_position = end_position  # last answer word index (inclusive), or None/-1
        self.is_impossible = is_impossible

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Return a one-line, human-readable description of the example."""
        s = ""
        s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
        s += ", question_text: %s" % (tokenization.printable_text(self.question_text))
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        # Compare against None rather than truthiness: word index 0 is a valid
        # answer position and must still be shown. Each field is gated on its
        # own value; the previous code gated everything on start_position.
        if self.start_position is not None:
            s += ", start_position: %d" % (self.start_position)
        if self.end_position is not None:
            s += ", end_position: %d" % (self.end_position)
        # is_impossible is only read from the data together with answer
        # labels, so show it only when labels are present.
        if self.start_position is not None:
            s += ", is_impossible: %r" % (self.is_impossible)
        return s
|
236
|
-
|
237
|
-
|
238
|
-
class InputFeatures(object):
    """Model-ready encoding of one document window of one ``SquadExample``.

    This is a plain attribute holder. Every constructor argument is stored
    unchanged under an attribute of the same name.
    """

    def __init__(
        self,
        unique_id,
        example_index,
        doc_span_index,
        tokens,
        token_to_orig_map,
        token_is_max_context,
        input_ids,
        input_mask,
        segment_ids,
        start_position=None,
        end_position=None,
        is_impossible=None,
    ):
        # Bookkeeping that links the feature back to its example and window.
        self.unique_id, self.example_index, self.doc_span_index = (
            unique_id,
            example_index,
            doc_span_index,
        )
        # Token-level views of the window.
        self.tokens, self.token_to_orig_map, self.token_is_max_context = (
            tokens,
            token_to_orig_map,
            token_is_max_context,
        )
        # Padded model inputs.
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids,
            input_mask,
            segment_ids,
        )
        # Training labels (None when not provided).
        self.start_position, self.end_position, self.is_impossible = (
            start_position,
            end_position,
            is_impossible,
        )
|
268
|
-
|
269
|
-
|
270
|
-
def read_squad_examples(input_file, is_training):
    """Load a SQuAD-format JSON file and return a list of ``SquadExample``.

    Each paragraph context is split into words on whitespace (space, tab, CR,
    LF and U+202F). Every character is mapped to the index of the word it
    belongs to, so character-level answer offsets can be turned into word
    positions.

    Args:
      input_file: path opened with ``tf.gfile.Open``. The JSON is read via its
        top-level ``"data"`` list, each entry holding ``"paragraphs"``.
      is_training: when true, answer annotations are read and converted to
        word positions. Questions whose answer text cannot be found in the
        reconstructed word span are logged and skipped, so not every question
        is guaranteed to produce an example in training mode.

    Returns:
      A list of ``SquadExample`` in file order.

    Raises:
      ValueError: in training mode, if an answerable question does not have
        exactly one answer.
    """
    with tf.gfile.Open(input_file, "r") as reader:
        entries = json.load(reader)["data"]

    separators = (" ", "\t", "\r", "\n")

    def is_whitespace(c):
        # U+202F (narrow no-break space) also separates words.
        return c in separators or ord(c) == 0x202F

    examples = []
    for entry in entries:
        for paragraph in entry["paragraphs"]:
            doc_tokens = []
            char_to_word_offset = []
            starts_new_word = True
            for c in paragraph["context"]:
                if is_whitespace(c):
                    starts_new_word = True
                elif starts_new_word:
                    doc_tokens.append(c)
                    starts_new_word = False
                else:
                    doc_tokens[-1] += c
                # Whitespace maps to the preceding word (-1 before the first word).
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False

                if is_training:
                    if FLAGS.version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    if len(qa["answers"]) != 1 and not is_impossible:
                        raise ValueError(
                            "For training, each question should have exactly 1 answer."
                        )
                    if is_impossible:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""
                    else:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        first_char = answer["answer_start"]
                        last_char = first_char + len(orig_answer_text) - 1
                        start_position = char_to_word_offset[first_char]
                        end_position = char_to_word_offset[last_char]
                        # Keep only answers whose text can be recovered from the
                        # document words. A miss usually means odd Unicode, and
                        # the question is skipped.
                        actual_text = " ".join(
                            doc_tokens[start_position : end_position + 1]
                        )
                        cleaned_answer_text = " ".join(
                            tokenization.whitespace_tokenize(orig_answer_text)
                        )
                        if cleaned_answer_text not in actual_text:
                            tf.logging.warning(
                                "Could not find answer: '%s' vs. '%s'",
                                actual_text,
                                cleaned_answer_text,
                            )
                            continue

                examples.append(
                    SquadExample(
                        qas_id=qas_id,
                        question_text=question_text,
                        doc_tokens=doc_tokens,
                        orig_answer_text=orig_answer_text,
                        start_position=start_position,
                        end_position=end_position,
                        is_impossible=is_impossible,
                    )
                )

    return examples
|
358
|
-
|
359
|
-
|
360
|
-
# One sliding window over a document's sub-tokens: [start, start + length).
_DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
    "DocSpan", ["start", "length"]
)


def _make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    """Cover ``num_doc_tokens`` sub-tokens with overlapping ``_DocSpan`` windows.

    Each window holds at most ``max_tokens_for_doc`` tokens. Consecutive
    windows start ``min(window_length, doc_stride)`` apart, and the last window
    ends on the final token. An empty document yields no windows.

    Raises:
      ValueError: if a window is needed but ``max_tokens_for_doc`` is not
        positive, or a further window is needed but ``doc_stride`` is not
        positive. Both cases previously looped forever.
    """
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        if max_tokens_for_doc <= 0:
            raise ValueError(
                "No room left for document tokens (max_tokens_for_doc=%s); "
                "increase max_seq_length or decrease max_query_length."
                % (max_tokens_for_doc,)
            )
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        if doc_stride <= 0:
            raise ValueError(
                "doc_stride must be positive to split a long document, got %s"
                % (doc_stride,)
            )
        start_offset += min(length, doc_stride)
    return doc_spans


def convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length,
    doc_stride,
    max_query_length,
    is_training,
    output_fn,
):
    """Convert ``SquadExample`` objects into ``InputFeatures``.

    Each example's question is tokenized and truncated to ``max_query_length``
    sub-tokens. Its document is tokenized and split into overlapping windows
    so every window fits ``[CLS] query [SEP] window [SEP]`` within
    ``max_seq_length``. One ``InputFeatures`` is built per window and passed to
    ``output_fn``. Nothing is returned.

    Unique ids start at 1000000000 and increase by one per emitted feature.

    In training mode, windows that do not contain the whole answer, as well as
    impossible questions, get start/end labels of 0 (the ``[CLS]`` position).

    Raises:
      ValueError: if the query leaves no room for document tokens, or if
        ``doc_stride`` is not positive when a document needs several windows.
    """
    unique_id = 1000000000

    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        # Map between whitespace words (orig) and tokenizer sub-tokens (tok).
        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                # Last sub-token of the end word = first sub-token of the next word - 1.
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens,
                tok_start_position,
                tok_end_position,
                tokenizer,
                example.orig_answer_text,
            )

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # Documents longer than the window budget are split into overlapping
        # chunks with a stride of `doc_stride`.
        doc_spans = _make_doc_spans(len(all_doc_tokens), max_tokens_for_doc, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index
                )
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, a window that does not fully contain the answer
                # is kept but labelled with position 0 ([CLS]) for start and end.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (
                    tok_start_position >= doc_start and tok_end_position <= doc_end
                ):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    # [CLS] + query + [SEP] precede the document tokens.
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0

            if example_index < 20:
                tf.logging.info("*** Example ***")
                tf.logging.info("unique_id: %s" % (unique_id))
                tf.logging.info("example_index: %s" % (example_index))
                tf.logging.info("doc_span_index: %s" % (doc_span_index))
                tf.logging.info(
                    "tokens: %s"
                    % " ".join([tokenization.printable_text(x) for x in tokens])
                )
                tf.logging.info(
                    "token_to_orig_map: %s"
                    % " ".join(
                        [
                            "%d:%d" % (x, y)
                            for (x, y) in six.iteritems(token_to_orig_map)
                        ]
                    )
                )
                tf.logging.info(
                    "token_is_max_context: %s"
                    % " ".join(
                        [
                            "%d:%s" % (x, y)
                            for (x, y) in six.iteritems(token_is_max_context)
                        ]
                    )
                )
                tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                tf.logging.info(
                    "input_mask: %s" % " ".join([str(x) for x in input_mask])
                )
                tf.logging.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids])
                )
                if is_training and example.is_impossible:
                    tf.logging.info("impossible example")
                if is_training and not example.is_impossible:
                    answer_text = " ".join(tokens[start_position : (end_position + 1)])
                    tf.logging.info("start_position: %d" % (start_position))
                    tf.logging.info("end_position: %d" % (end_position))
                    tf.logging.info(
                        "answer: %s" % (tokenization.printable_text(answer_text))
                    )

            feature = InputFeatures(
                unique_id=unique_id,
                example_index=example_index,
                doc_span_index=doc_span_index,
                tokens=tokens,
                token_to_orig_map=token_to_orig_map,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                start_position=start_position,
                end_position=end_position,
                is_impossible=example.is_impossible,
            )

            # Run callback
            output_fn(feature)

            unique_id += 1
|
557
|
-
|
558
|
-
|
559
|
-
def _improve_answer_span(
|
560
|
-
doc_tokens, input_start, input_end, tokenizer, orig_answer_text
|
561
|
-
):
|
562
|
-
"""Returns tokenized answer spans that better match the annotated answer."""
|
563
|
-
|
564
|
-
# The SQuAD annotations are character based. We first project them to
|
565
|
-
# whitespace-tokenized words. But then after WordPiece tokenization, we can
|
566
|
-
# often find a "better match". For example:
|
567
|
-
#
|
568
|
-
# Question: What year was John Smith born?
|
569
|
-
# Context: The leader was John Smith (1895-1943).
|
570
|
-
# Answer: 1895
|
571
|
-
#
|
572
|
-
# The original whitespace-tokenized answer will be "(1895-1943).". However
|
573
|
-
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
|
574
|
-
# the exact answer, 1895.
|
575
|
-
#
|
576
|
-
# However, this is not always possible. Consider the following:
|
577
|
-
#
|
578
|
-
# Question: What country is the top exporter of electornics?
|
579
|
-
# Context: The Japanese electronics industry is the lagest in the world.
|
580
|
-
# Answer: Japan
|
581
|
-
#
|
582
|
-
# In this case, the annotator chose "Japan" as a character sub-span of
|
583
|
-
# the word "Japanese". Since our WordPiece tokenizer does not split
|
584
|
-
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
|
585
|
-
# in SQuAD, but does happen.
|
586
|
-
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
|
587
|
-
|
588
|
-
for new_start in range(input_start, input_end + 1):
|
589
|
-
for new_end in range(input_end, new_start - 1, -1):
|
590
|
-
text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
|
591
|
-
if text_span == tok_answer_text:
|
592
|
-
return (new_start, new_end)
|
593
|
-
|
594
|
-
return (input_start, input_end)
|
595
|
-
|
596
|
-
|
597
|
-
def _check_is_max_context(doc_spans, cur_span_index, position):
|
598
|
-
"""Check if this is the 'max context' doc span for the token."""
|
599
|
-
|
600
|
-
# Because of the sliding window approach taken to scoring documents, a single
|
601
|
-
# token can appear in multiple documents. E.g.
|
602
|
-
# Doc: the man went to the store and bought a gallon of milk
|
603
|
-
# Span A: the man went to the
|
604
|
-
# Span B: to the store and bought
|
605
|
-
# Span C: and bought a gallon of
|
606
|
-
# ...
|
607
|
-
#
|
608
|
-
# Now the word 'bought' will have two scores from spans B and C. We only
|
609
|
-
# want to consider the score with "maximum context", which we define as
|
610
|
-
# the *minimum* of its left and right context (the *sum* of left and
|
611
|
-
# right context will always be the same, of course).
|
612
|
-
#
|
613
|
-
# In the example the maximum context for 'bought' would be span C since
|
614
|
-
# it has 1 left context and 3 right context, while span B has 4 left context
|
615
|
-
# and 0 right context.
|
616
|
-
best_score = None
|
617
|
-
best_span_index = None
|
618
|
-
for (span_index, doc_span) in enumerate(doc_spans):
|
619
|
-
end = doc_span.start + doc_span.length - 1
|
620
|
-
if position < doc_span.start:
|
621
|
-
continue
|
622
|
-
if position > end:
|
623
|
-
continue
|
624
|
-
num_left_context = position - doc_span.start
|
625
|
-
num_right_context = end - position
|
626
|
-
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
627
|
-
if best_score is None or score > best_score:
|
628
|
-
best_score = score
|
629
|
-
best_span_index = span_index
|
630
|
-
|
631
|
-
return cur_span_index == best_span_index
|
632
|
-
|
633
|
-
|
634
|
-
def create_model(
    bert_config, is_training, input_ids, input_mask, segment_ids, use_one_hot_embeddings
):
    """Build a BERT encoder with a SQuAD answer-span head.

    A dense projection to 2 channels (weights ``cls/squad/output_weights``,
    bias ``cls/squad/output_bias``, no activation) is applied to every
    position of the BERT sequence output. Channel 0 becomes the start logits
    and channel 1 the end logits.

    Returns:
      ``(start_logits, end_logits)``. Each has shape
      ``[batch_size, seq_length]`` after the reshape/transpose/unstack below.
    """
    bert = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
    )
    sequence_output = bert.get_sequence_output()

    batch_size, seq_length, hidden_size = modeling.get_shape_list(
        sequence_output, expected_rank=3
    )

    span_weights = tf.get_variable(
        "cls/squad/output_weights",
        [2, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02),
    )
    span_bias = tf.get_variable(
        "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()
    )

    # Flatten positions so one matmul scores every token at once.
    flat_hidden = tf.reshape(sequence_output, [batch_size * seq_length, hidden_size])
    flat_logits = tf.nn.bias_add(
        tf.matmul(flat_hidden, span_weights, transpose_b=True), span_bias
    )

    # [batch, seq, 2] -> [2, batch, seq] so the two channels can be split apart.
    per_channel_logits = tf.transpose(
        tf.reshape(flat_logits, [batch_size, seq_length, 2]), [2, 0, 1]
    )
    start_logits, end_logits = tf.unstack(per_channel_logits, axis=0)

    return (start_logits, end_logits)
|
678
|
-
|
679
|
-
|
680
|
-
def model_fn_builder(
    bert_config,
    init_checkpoint,
    learning_rate,
    num_train_steps,
    num_warmup_steps,
    use_tpu,
    use_one_hot_embeddings,
):
    """Return a ``model_fn`` closure for ``TPUEstimator``.

    The returned function builds the span model via ``create_model``. When
    ``init_checkpoint`` is set, it initializes matching variables from that
    checkpoint; on TPU this happens inside a ``scaffold_fn``.

    - TRAIN: the loss is the mean of the start- and end-position
      cross-entropy losses, optimized with ``optimization.create_optimizer``.
    - PREDICT: emits ``unique_ids``, ``start_logits`` and ``end_logits``.
    - Any other mode raises ``ValueError``.
    """

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """TPUEstimator ``model_fn`` supporting TRAIN and PREDICT modes."""
        tf.logging.info("*** Features ***")
        for feature_name in sorted(features.keys()):
            tf.logging.info(
                " name = %s, shape = %s" % (feature_name, features[feature_name].shape)
            )

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        start_logits, end_logits = create_model(
            bert_config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
        )

        tvars = tf.trainable_variables()

        scaffold_fn = None
        initialized_variable_names = {}
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:
                # On TPU the restore must run inside the scaffold.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = (
                ", *INIT_FROM_CKPT*" if var.name in initialized_variable_names else ""
            )
            tf.logging.info(
                " name = %s, shape = %s%s", var.name, var.shape, init_string
            )

        if mode == tf.estimator.ModeKeys.TRAIN:
            seq_length = modeling.get_shape_list(input_ids)[1]

            def position_loss(logits, positions):
                # Mean negative log-likelihood of the gold position per example.
                one_hot_positions = tf.one_hot(
                    positions, depth=seq_length, dtype=tf.float32
                )
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                return -tf.reduce_mean(
                    tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
                )

            start_positions = features["start_positions"]
            end_positions = features["end_positions"]

            start_loss = position_loss(start_logits, start_positions)
            end_loss = position_loss(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2.0

            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
            )
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
            )

        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {
                "unique_ids": unique_ids,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
            )

        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return model_fn
|
786
|
-
|
787
|
-
|
788
|
-
def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
    """Build the `input_fn` closure handed to TPUEstimator.

    Args:
      input_file: path of the TFRecord file to read.
      seq_length: fixed length of the `input_ids`, `input_mask` and
        `segment_ids` features.
      is_training: when true, `start_positions`/`end_positions` are parsed
        too, and the dataset is repeated and shuffled.
      drop_remainder: forwarded to `map_and_batch`.

    Returns:
      A function `input_fn(params)` that reads `params["batch_size"]` and
      returns a batched `tf.data` dataset of decoded examples.
    """
    feature_spec = {"unique_ids": tf.FixedLenFeature([], tf.int64)}
    for sequence_key in ("input_ids", "input_mask", "segment_ids"):
        feature_spec[sequence_key] = tf.FixedLenFeature([seq_length], tf.int64)
    if is_training:
        for label_key in ("start_positions", "end_positions"):
            feature_spec[label_key] = tf.FixedLenFeature([], tf.int64)

    def _decode_record(record, name_to_features):
        """Parse one serialized tf.Example, casting int64 tensors to int32."""
        parsed = tf.parse_single_example(record, name_to_features)
        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        for key in list(parsed):
            tensor = parsed[key]
            parsed[key] = tf.to_int32(tensor) if tensor.dtype == tf.int64 else tensor
        return parsed

    def input_fn(params):
        """Return the batched dataset; the batch size is supplied via `params`."""
        per_call_batch_size = params["batch_size"]

        dataset = tf.data.TFRecordDataset(input_file)
        # Training loops over the file forever with light shuffling; otherwise
        # the file is read once, in order.
        if is_training:
            dataset = dataset.repeat().shuffle(buffer_size=100)

        batcher = tf.contrib.data.map_and_batch(
            lambda record: _decode_record(record, feature_spec),
            batch_size=per_call_batch_size,
            drop_remainder=drop_remainder,
        )
        return dataset.apply(batcher)

    return input_fn
|
838
|
-
|
839
|
-
|
840
|
-
# Raw model output for one feature: its `unique_id` plus the start and end
# logit sequences produced by the estimator.
RawResult = collections.namedtuple("RawResult", "unique_id start_logits end_logits")
|
843
|
-
|
844
|
-
|
845
|
-
def write_predictions(
    all_examples,
    all_features,
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    output_prediction_file,
    output_nbest_file,
    output_null_log_odds_file,
):
    """Write final predictions to the json file and log-odds of null if needed.

    Args:
      all_examples: examples in order; each must expose `qas_id` and
        `doc_tokens`.
      all_features: features built from the examples; each must expose
        `example_index`, `unique_id`, `tokens`, `token_to_orig_map` and
        `token_is_max_context`.
      all_results: `RawResult`s, matched to features through `unique_id`.
      n_best_size: number of start/end candidates per feature, and maximum
        length of each example's n-best list.
      max_answer_length: longest span (in tokens of `feature.tokens`) kept.
      do_lower_case: forwarded to `get_final_text`.
      output_prediction_file: JSON path for `{qas_id: answer_text}`.
      output_nbest_file: JSON path for the per-question n-best lists.
      output_null_log_odds_file: JSON path for the per-question
        "null score minus best non-null score"; only written when
        `FLAGS.version_2_with_negative` is set.
    """
    tf.logging.info("Writing predictions to: %s", output_prediction_file)
    tf.logging.info("Writing nbest to: %s", output_nbest_file)

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {result.unique_id: result for result in all_results}

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
    )
    # Defined once here rather than on every loop iteration.
    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit"]
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if FLAGS.version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )

        if FLAGS.version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True,
        )

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
                tok_text = " ".join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(tok_text, orig_text, do_lower_case)
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit,
                )
            )

        if FLAGS.version_2_with_negative:
            # if we didn't include the empty option in the n-best, include it
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="", start_logit=null_start_logit, end_logit=null_end_logit
                    )
                )
            # If the null answer is the only candidate, there is no non-null
            # entry for the score-difference computation below, and
            # `best_non_null_entry` would stay None and crash. Add a nonce
            # non-null prediction in that case.
            if not any(entry.text for entry in nbest):
                nbest.insert(
                    0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)
                )

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if best_non_null_entry is None and entry.text:
                best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not FLAGS.version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = (
                score_null
                - best_non_null_entry.start_logit
                - (best_non_null_entry.end_logit)
            )
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > FLAGS.null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text

        all_nbest_json[example.qas_id] = nbest_json

    with tf.gfile.GFile(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with tf.gfile.GFile(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if FLAGS.version_2_with_negative:
        with tf.gfile.GFile(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
1049
|
-
|
1050
|
-
|
1051
|
-
def get_final_text(pred_text, orig_text, do_lower_case):
    """Project a normalized, detokenized prediction back onto the original text.

    `orig_text` is the whitespace-tokenized span of the original document that
    the predicted span maps to. It may contain extra characters the
    prediction does not, e.g. pred_text = "steve smith" while
    orig_text = "Steve Smith's". `pred_text` has already been normalized by
    the tokenizer (lower casing, accent stripping), so it should not be
    returned as-is either. The desired answer is "Steve Smith".

    Heuristic: tokenize `orig_text` with `BasicTokenizer` and locate
    `pred_text` inside that tokenized string. If both strings have the same
    length once spaces are removed, assume a one-to-one character alignment
    and map the start/end characters back into `orig_text`.

    Returns:
      The aligned substring of `orig_text`, or `orig_text` unchanged whenever
      a step of the heuristic fails. Failures are logged when
      `FLAGS.verbose_logging` is set.
    """

    def _strip_spaces(text):
        # Remove " " characters, recording for every kept character its
        # index in the original string.
        kept_chars = []
        kept_to_full = collections.OrderedDict()
        for full_index, char in enumerate(text):
            if char != " ":
                kept_to_full[len(kept_chars)] = full_index
                kept_chars.append(char)
        return "".join(kept_chars), kept_to_full

    tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if FLAGS.verbose_logging:
            tf.logging.info(
                "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)
            )
        return orig_text
    end_position = start_position + len(pred_text) - 1

    orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
    tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if FLAGS.verbose_logging:
            tf.logging.info(
                "Length not equal after stripping spaces: '%s' vs '%s'",
                orig_ns_text,
                tok_ns_text,
            )
        return orig_text

    # Invert the tokenized-text map: spaced index -> space-free index.
    tok_s_to_ns_map = {spaced: compact for compact, spaced in tok_ns_to_s_map.items()}

    def _to_orig(tok_position):
        # tokenized (spaced) index -> space-free index -> original (spaced)
        # index; None if either hop has no entry.
        compact = tok_s_to_ns_map.get(tok_position)
        if compact is None:
            return None
        return orig_ns_to_s_map.get(compact)

    orig_start_position = _to_orig(start_position)
    if orig_start_position is None:
        if FLAGS.verbose_logging:
            tf.logging.info("Couldn't map start position")
        return orig_text

    orig_end_position = _to_orig(end_position)
    if orig_end_position is None:
        if FLAGS.verbose_logging:
            tf.logging.info("Couldn't map end position")
        return orig_text

    return orig_text[orig_start_position : orig_end_position + 1]
|
1149
|
-
|
1150
|
-
|
1151
|
-
def _get_best_indexes(logits, n_best_size):
    """Get the indices of the `n_best_size` largest logits, best first.

    Args:
      logits: iterable of comparable scores.
      n_best_size: maximum number of indices to return. Values <= 0 yield an
        empty list; values larger than the number of logits return every
        index.

    Returns:
      List of indices into `logits`, ordered by descending score. Equal
      scores keep their original relative order, because the sort is stable.
    """
    ranked = sorted(enumerate(logits), key=lambda pair: pair[1], reverse=True)
    return [index for index, _ in ranked[: max(n_best_size, 0)]]
|
1161
|
-
|
1162
|
-
|
1163
|
-
def _compute_softmax(scores):
    """Compute softmax probability over raw logits.

    The maximum score is subtracted before exponentiating, so large logits do
    not overflow `math.exp`.

    Args:
      scores: sequence of real-valued logits.

    Returns:
      List of probabilities aligned with `scores`, or `[]` for empty input.
    """
    if not scores:
        return []

    max_score = max(scores)
    exp_scores = [math.exp(score - max_score) for score in scores]
    total_sum = sum(exp_scores)
    return [value / total_sum for value in exp_scores]
|
1184
|
-
|
1185
|
-
|
1186
|
-
class FeatureWriter(object):
    """Writes InputFeature to TF example file.

    Can be used as a context manager, so the underlying TFRecord writer is
    closed even if feature conversion raises:

        with FeatureWriter(path, is_training=True) as writer:
            writer.process_feature(feature)

    Attributes:
      filename: path of the TFRecord file being written.
      is_training: whether `start_positions`, `end_positions` and
        `is_impossible` are written in addition to the model inputs.
      num_features: number of features passed to `process_feature` so far.
    """

    def __init__(self, filename, is_training):
        self.filename = filename
        self.is_training = is_training
        self.num_features = 0
        self._writer = tf.python_io.TFRecordWriter(filename)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()
        return False  # never swallow an in-flight exception

    def process_feature(self, feature):
        """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
        self.num_features += 1

        def create_int_feature(values):
            return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

        features = collections.OrderedDict()
        features["unique_ids"] = create_int_feature([feature.unique_id])
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)

        if self.is_training:
            features["start_positions"] = create_int_feature([feature.start_position])
            features["end_positions"] = create_int_feature([feature.end_position])
            impossible = 1 if feature.is_impossible else 0
            features["is_impossible"] = create_int_feature([impossible])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        self._writer.write(tf_example.SerializeToString())

    def close(self):
        self._writer.close()
|
1224
|
-
|
1225
|
-
|
1226
|
-
def validate_flags_or_throw(bert_config):
    """Check FLAGS for consistency, raising ValueError on the first problem.

    Checks, in order:
      * the lower-casing flag agrees with the checkpoint (delegated to
        `tokenization.validate_case_matches_checkpoint`);
      * at least one of `do_train` / `do_predict` is set;
      * each enabled mode has its input file;
      * `max_seq_length` fits the model's position embeddings and leaves
        room for the query plus 3 extra positions.
    """
    tokenization.validate_case_matches_checkpoint(
        FLAGS.do_lower_case, FLAGS.init_checkpoint
    )

    if not (FLAGS.do_train or FLAGS.do_predict):
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if FLAGS.do_train and not FLAGS.train_file:
        raise ValueError(
            "If `do_train` is True, then `train_file` must be specified."
        )
    if FLAGS.do_predict and not FLAGS.predict_file:
        raise ValueError(
            "If `do_predict` is True, then `predict_file` must be specified."
        )

    seq_len = FLAGS.max_seq_length
    model_limit = bert_config.max_position_embeddings
    if seq_len > model_limit:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" % (seq_len, model_limit)
        )

    if seq_len <= FLAGS.max_query_length + 3:
        raise ValueError(
            "The max_seq_length (%d) must be greater than max_query_length "
            "(%d) + 3" % (seq_len, FLAGS.max_query_length)
        )
|
1258
|
-
|
1259
|
-
|
1260
|
-
def main(_):
    """Run SQuAD-style fine-tuning and/or prediction as selected by FLAGS.

    Training converts `FLAGS.train_file` into `<output_dir>/train.tf_record`
    and trains the estimator. Prediction converts `FLAGS.predict_file` into
    `<output_dir>/eval.tf_record`, runs the estimator over it, and passes the
    raw results to `write_predictions`, with output paths under
    `FLAGS.output_dir`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    logger = tf.get_logger()
    # Keep TF log records from propagating to the root logger (presumably to
    # avoid duplicated output -- confirm if logging setup changes).
    logger.propagate = False

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    validate_flags_or_throw(bert_config)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
    )

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
        )

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host,
        ),
    )

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = read_squad_examples(
            input_file=FLAGS.train_file, is_training=True
        )
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs
        )
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        # Pre-shuffle the input to avoid having to make a very large shuffle
        # buffer in the `input_fn`.
        rng = random.Random(12345)
        rng.shuffle(train_examples)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size,
    )

    if FLAGS.do_train:
        # We write to a temporary file to avoid storing very large constant tensors
        # in memory.
        train_writer = FeatureWriter(
            filename=os.path.join(FLAGS.output_dir, "train.tf_record"), is_training=True
        )
        try:
            convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=FLAGS.max_seq_length,
                doc_stride=FLAGS.doc_stride,
                max_query_length=FLAGS.max_query_length,
                is_training=True,
                output_fn=train_writer.process_feature,
            )
        finally:
            # Close the record file even if conversion fails.
            train_writer.close()

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num orig examples = %d", len(train_examples))
        tf.logging.info("  Num split examples = %d", train_writer.num_features)
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        del train_examples

        train_input_fn = input_fn_builder(
            input_file=train_writer.filename,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
        )
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file, is_training=False
        )

        eval_writer = FeatureWriter(
            filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), is_training=False
        )
        eval_features = []

        def append_feature(feature):
            # Keep features in memory (needed by write_predictions) and also
            # serialize them for the estimator's input pipeline.
            eval_features.append(feature)
            eval_writer.process_feature(feature)

        try:
            convert_examples_to_features(
                examples=eval_examples,
                tokenizer=tokenizer,
                max_seq_length=FLAGS.max_seq_length,
                doc_stride=FLAGS.doc_stride,
                max_query_length=FLAGS.max_query_length,
                is_training=False,
                output_fn=append_feature,
            )
        finally:
            # Close the record file even if conversion fails.
            eval_writer.close()

        tf.logging.info("***** Running predictions *****")
        tf.logging.info("  Num orig examples = %d", len(eval_examples))
        tf.logging.info("  Num split examples = %d", len(eval_features))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = input_fn_builder(
            input_file=eval_writer.filename,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False,
        )

        # If running eval on the TPU, you will need to specify the number of
        # steps.
        all_results = []
        for result in estimator.predict(predict_input_fn, yield_single_examples=True):
            if len(all_results) % 1000 == 0:
                tf.logging.info("Processing example: %d", len(all_results))
            unique_id = int(result["unique_ids"])
            start_logits = [float(x) for x in result["start_logits"].flat]
            end_logits = [float(x) for x in result["end_logits"].flat]
            all_results.append(
                RawResult(
                    unique_id=unique_id,
                    start_logits=start_logits,
                    end_logits=end_logits,
                )
            )

        output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
        output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

        write_predictions(
            eval_examples,
            eval_features,
            all_results,
            FLAGS.n_best_size,
            FLAGS.max_answer_length,
            FLAGS.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
        )
|
1434
|
-
|
1435
|
-
|
1436
|
-
if __name__ == "__main__":
    # These flags have no usable defaults; absl rejects the run if any is unset.
    for required_flag in ("vocab_file", "bert_config_file", "output_dir"):
        flags.mark_flag_as_required(required_flag)
    tf.app.run()
|