SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
- SinaTools-1.0.1.dist-info/RECORD +73 -0
- sinatools/VERSION +1 -1
- sinatools/ner/__init__.py +5 -7
- sinatools/ner/trainers/BertNestedTrainer.py +203 -203
- sinatools/ner/trainers/BertTrainer.py +163 -163
- sinatools/ner/trainers/__init__.py +2 -2
- SinaTools-0.1.40.dist-info/RECORD +0 -123
- sinatools/arabert/arabert/__init__.py +0 -14
- sinatools/arabert/arabert/create_classification_data.py +0 -260
- sinatools/arabert/arabert/create_pretraining_data.py +0 -534
- sinatools/arabert/arabert/extract_features.py +0 -444
- sinatools/arabert/arabert/lamb_optimizer.py +0 -158
- sinatools/arabert/arabert/modeling.py +0 -1027
- sinatools/arabert/arabert/optimization.py +0 -202
- sinatools/arabert/arabert/run_classifier.py +0 -1078
- sinatools/arabert/arabert/run_pretraining.py +0 -593
- sinatools/arabert/arabert/run_squad.py +0 -1440
- sinatools/arabert/arabert/tokenization.py +0 -414
- sinatools/arabert/araelectra/__init__.py +0 -1
- sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
- sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
- sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
- sinatools/arabert/araelectra/configure_finetuning.py +0 -172
- sinatools/arabert/araelectra/configure_pretraining.py +0 -143
- sinatools/arabert/araelectra/finetune/__init__.py +0 -14
- sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
- sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
- sinatools/arabert/araelectra/finetune/scorer.py +0 -54
- sinatools/arabert/araelectra/finetune/task.py +0 -74
- sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
- sinatools/arabert/araelectra/flops_computation.py +0 -215
- sinatools/arabert/araelectra/model/__init__.py +0 -14
- sinatools/arabert/araelectra/model/modeling.py +0 -1029
- sinatools/arabert/araelectra/model/optimization.py +0 -193
- sinatools/arabert/araelectra/model/tokenization.py +0 -355
- sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
- sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
- sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
- sinatools/arabert/araelectra/run_finetuning.py +0 -323
- sinatools/arabert/araelectra/run_pretraining.py +0 -469
- sinatools/arabert/araelectra/util/__init__.py +0 -14
- sinatools/arabert/araelectra/util/training_utils.py +0 -112
- sinatools/arabert/araelectra/util/utils.py +0 -109
- sinatools/arabert/aragpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
- sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
- sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
- sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
- sinatools/arabert/aragpt2/grover/__init__.py +0 -0
- sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
- sinatools/arabert/aragpt2/grover/modeling.py +0 -803
- sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
- sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
- sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
- sinatools/arabert/aragpt2/grover/utils.py +0 -234
- sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
- {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
sinatools/arabert/aragpt2/grover/modeling.py
@@ -1,803 +0,0 @@
-# Original work Copyright 2018 The Google AI Language Team Authors.
-# Modified work Copyright 2019 Rowan Zellers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-import json
-import math
-
-import six
-import tensorflow as tf
-
-from grover import optimization_adafactor
-from grover.utils import get_assignment_map_from_checkpoint, get_shape_list, get_attention_mask, gelu, layer_norm, dropout, \
-    construct_scalar_host_call
-
-class GroverConfig(object):
-    """Configuration for `GroverModel`"""
-
-    def __init__(self,
-                 vocab_size,
-                 hidden_size=768,
-                 num_hidden_layers=12,
-                 num_attention_heads=12,
-                 intermediate_size=3072,
-                 hidden_act="gelu",
-                 hidden_dropout_prob=0.1,
-                 attention_probs_dropout_prob=0.1,
-                 max_position_embeddings=512,
-                 initializer_range=0.02):
-        """Constructs NewsConfig.
-
-        Args:
-            vocab_size: Vocabulary size of `inputs_ids` in `GroverModel`.
-            hidden_size: Size of the layers
-            num_hidden_layers: Number of hidden layers in the Transformer encoder.
-            num_attention_heads: Number of attention heads for each attention layer in
-                the Transformer encoder.
-            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
-                layer in the Transformer encoder.
-            hidden_act: The non-linear activation function (function or string) in the
-                encoder and pooler.
-            hidden_dropout_prob: The dropout probability for all fully connected
-                layers in the embeddings, encoder, and pooler.
-            attention_probs_dropout_prob: The dropout ratio for the attention
-                probabilities.
-            max_position_embeddings: The maximum sequence length that this model might
-                ever be used with. Typically set this to something large just in case
-                (e.g., 512 or 1024 or 2048).
-            initializer_range: The stdev of the truncated_normal_initializer for
-                initializing all weight matrices.
-        """
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.hidden_act = hidden_act
-        self.intermediate_size = intermediate_size
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.initializer_range = initializer_range
-        self.pad_token_id = 0
-
-    @classmethod
-    def from_dict(cls, json_object):
-        """Constructs a `NewsConfig` from a Python dictionary of parameters."""
-        config = GroverConfig(vocab_size=None)
-        for (key, value) in six.iteritems(json_object):
-            config.__dict__[key] = value
-        return config
-
-    @classmethod
-    def from_json_file(cls, json_file):
-        """Constructs a `NewsConfig` from a json file of parameters."""
-        with tf.gfile.GFile(json_file, "r") as reader:
-            text = reader.read()
-        return cls.from_dict(json.loads(text))
-
-    def to_dict(self):
-        """Serializes this instance to a Python dictionary."""
-        output = copy.deepcopy(self.__dict__)
-        return output
-
-    def to_json_string(self):
-        """Serializes this instance to a JSON string."""
-        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
-
-
-def mask_attention_for_ltr(attention_scores, attention_mask):
-    """
-    Mask attention so that we're only predicting going forward
-    :param attention_scores: [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
-    :param attention_mask [query_length, key_length]
-    :return: masked attention
-    """
-    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-    # masked positions, this operation will create a tensor which is 0.0 for
-    # positions we want to attend and -10000.0 for masked positions.
-    mask = attention_mask[None, None]
-    return attention_scores * mask - tf.cast(1e10, attention_scores.dtype) * (1 - mask)
-
-
-def create_initializer(initializer_range=0.02):
-    """Creates a `truncated_normal_initializer` with the given range."""
-    return tf.truncated_normal_initializer(stddev=initializer_range)
-
-
-def _attention_projection_and_transpose(x_flat, batch_size, seq_length, num_attention_heads, size_per_head,
                                        name, initializer_range=0.02):
-    """
-    :param x_flat: [batch_size*seq_length, width]
-    :return: A fixed up tensor of size [batch_size, num_attention_heads, seq_length, size_per_head]
-    """
-    batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2)
-
-    if dim != size_per_head * num_attention_heads:
-        raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format(
-            (batch_size_seq_length, dim), size_per_head, num_attention_heads
-        ))
-
-    projected = tf.layers.dense(
-        x_flat,
-        num_attention_heads * size_per_head,
-        name=name,
-        kernel_initializer=create_initializer(initializer_range))
-
-    projected = tf.reshape(
-        projected, [batch_size, seq_length, num_attention_heads, size_per_head])
-    output_tensor = tf.transpose(projected, [0, 2, 1, 3])
-    return output_tensor
-
-
-def attention_layer(x_flat, attention_mask, batch_size, seq_length, size_per_head=512, num_attention_heads=1, *,
-                    cache=None,
-                    initializer_range=0.02, hidden_dropout_prob=0.1,
-                    attention_probs_dropout_prob=0.1, do_cache=False):
-    """
-
-    :param x_flat: Tensor input, should be [batch_size*seq_length, dim]
-    :param attention_mask: Attention mask to use of size [seq_length, seq_length+cached_length]
-    :param size_per_head: dim = size_per_head * num_attention_heads
-    :param num_attention_heads: dim = size_per_head * num_attention_heads
-    :param cache: Optionally some past (cached) things of size
-        [batch, 2, heads, sequence, features], where 2 is [k, v]
-    :param do_cache: True if we should return cache
-    :return: A new tensor of shape [batch_size, seq_length, dim]
-        as well as a new cache "cached_keys_and_values" that will be of size
-        [batch_size, 2, num_attention_heads, seq_length, dim]
-    """
-    batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2)
-
-    if dim != size_per_head * num_attention_heads:
-        raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format(
-            (batch_size_seq_length, dim), size_per_head, num_attention_heads
-        ))
-
-    query = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
-                                                num_attention_heads=num_attention_heads, size_per_head=size_per_head,
-                                                name='query_layer',
-                                                initializer_range=initializer_range)
-    key = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
-                                              num_attention_heads=num_attention_heads, size_per_head=size_per_head,
-                                              name='key_layer',
-                                              initializer_range=initializer_range)
-
-    value = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
-                                                num_attention_heads=num_attention_heads, size_per_head=size_per_head,
-                                                name='value_layer',
-                                                initializer_range=initializer_range)
-
-    # Add to cache
-    cached_keys_and_values = tf.stack([key, value], axis=1) if do_cache else None
-
-    # Things that were relevant from the cache
-    if cache is not None:
-        pk, pv = tf.unstack(cache, axis=1)
-        key = tf.concat([pk, key], axis=-2)
-        value = tf.concat([pv, value], axis=-2)
-
-    # Multiply [batch_size, num_attention_heads, seq_length, size_per_head] with
-    # [batch_size, num_attention_heads, size_per_head, seq_length+cached_length] ->
-    # [batch_size, num_attention_heads, seq_length, seq_length+cached_length]
-    attention_scores = tf.matmul(query, key, transpose_b=True)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(size_per_head)))
-    attention_scores = mask_attention_for_ltr(attention_scores, attention_mask)
-    attention_probs = tf.nn.softmax(attention_scores)
-
-    # This is actually dropping out entire tokens to attend to, which might
-    # seem a bit unusual, but is taken from the original Transformer paper.
-    # NOPENOPENOPENOPE
-    # attention_probs = factoreddropout(attention_probs, attention_probs_dropout_prob)
-
-    # Multiply [batch_size, num_attention_heads, seq_length, seq_length+cached_length] with
-    # [batch_size, num_attention_heads, seq_length+cached_length, size_per_head] ->
-    # [batch_size, num_attention_heads, seq_length, size_per_head] ->
-    context_layer = tf.matmul(attention_probs, value)
-
-    # `context_layer` = [batch_size, seq_length, num_attention_heads, size_per_head]
-    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
-    context_layer = tf.reshape(context_layer, [batch_size * seq_length, num_attention_heads * size_per_head])
-
-    context_layer_projected = tf.layers.dense(
-        context_layer,
-        num_attention_heads * size_per_head,
-        kernel_initializer=create_initializer(initializer_range),
-        name='context_projection_layer'
-    )
-    context_layer_projected = dropout(context_layer_projected, hidden_dropout_prob)
-
-    return context_layer_projected, cached_keys_and_values
-
-
-def residual_mlp_layer(x_flat, intermediate_size, initializer_range=0.02, hidden_dropout_prob=0.1):
-    """
-    :param x: The attention output. It should be [batch_size*seq_length, dim]
-    :param intermediate_size: the hidden projection. By default this is the input_dim * 4.
-
-    in the original GPT we would return layer_norm(x_norm + h1) rather than layer_norm(x + h1)
-
-    :return:
-    """
-    batch_size_seq_length, hidden_size = get_shape_list(x_flat, expected_rank=2)
-    x_norm = layer_norm(x_flat, name='mlp_ln0')
-
-    intermediate_output = tf.layers.dense(
-        x_norm,
-        intermediate_size,
-        activation=gelu,
-        kernel_initializer=create_initializer(initializer_range),
-        name='intermediate',
-    )
-
-    output_for_residual = tf.layers.dense(
-        intermediate_output,
-        hidden_size,
-        name='output',
-        kernel_initializer=create_initializer(initializer_range))
-    output_for_residual = dropout(output_for_residual, hidden_dropout_prob)
-
-    layer_output = layer_norm(x_flat + output_for_residual, name='mlp_ln1')
-    return layer_output
-
-
-def embed(input_ids,
-          vocab_size,
-          embedding_size,
-          position_offset=0,
-          initializer_range=0.02,
-          max_position_embeddings=512,
-          use_one_hot_embeddings=True):
-    """reur and position embeddings
-    :param input_ids: int Tensor of shape [batch_size, seq_length].
-    :param vocab_size: number of words in vocab
-    :param embedding_size: dimensionality of the embedding
-    :param position_offset: aka number of cached tokens.
-    :param initializer_range: float. Range of the weight initialization.
-    :param max_position_embeddings: int. Maximum sequence length.
-    :param use_one_hot_embeddings: probably want this to be true
-    :return: [batch_size, seq_length, embedding_size] embedded tensor
-    """
-    (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2)
-
-    embedding_table = tf.get_variable(
-        name='word_embed',
-        shape=[vocab_size, embedding_size],
-        initializer=create_initializer(initializer_range),
-    )
-
-    assert_op = tf.assert_less_equal(tf.reduce_max(input_ids), vocab_size - 1)
-    with tf.control_dependencies([assert_op]):
-        if use_one_hot_embeddings:
-            flat_input_ids = tf.reshape(input_ids, [-1])
-            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
-            output_flat = tf.matmul(one_hot_input_ids, embedding_table)
-        else:
-            output_flat = tf.nn.embedding_lookup(embedding_table, input_ids)
-
-        embedded_input = tf.reshape(output_flat, [batch_size, seq_length, embedding_size])
-
-    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
-
-    with tf.control_dependencies([assert_op]):
-        full_position_embeddings = tf.get_variable(
-            name='pos_embed',
-            shape=[max_position_embeddings, embedding_size],
-            initializer=create_initializer(initializer_range),
-        )
-        # Since the position embedding table is a learned variable, we create it
-        # using a (long) sequence length `max_position_embeddings`. The actual
-        # sequence length might be shorter than this, for faster training of
-        # tasks that do not have long sequences.
-        #
-        # So `full_position_embeddings` is effectively an embedding table
-        # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
-        # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
-        # perform a slice.
-        if position_offset == 0:
-            embedded_input += tf.slice(full_position_embeddings, [0, 0], [seq_length, -1])[None]
-        else:
-            # Tensorflow is too stupid to allow slicing
-            flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) + position_offset)
-            one_hot_pos_ids = tf.one_hot(flat_pos_ids, depth=max_position_embeddings)
-
-            # [seq_length, full_position_embeddings], [full_position_embeddings, dim]
-            seq_embeds = tf.matmul(one_hot_pos_ids, full_position_embeddings)
-            embedded_input += seq_embeds[None]
-
-            # embedded_input += tf.slice(full_position_embeddings[position_offset:], [0, 0], [seq_length, -1])[None]
-
-    return layer_norm(embedded_input, name='embed_norm'), embedding_table
-
-
-def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
-    """
-    Does top-p sampling. if ignore_ids is on, then we will zero out those logits.
-    :param logits: [batch_size, vocab_size] tensor
-    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
-        like padding maybe
-    :param p: topp threshold to use, either a float or a [batch_size] vector
-    :return: [batch_size, num_samples] samples
-
-    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
-    """
-    with tf.variable_scope('top_p_sample'):
-        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)
-
-        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
-                              axis=-1)
-
-        if isinstance(p, float) and p > 0.999999:
-            # Don't do top-p sampling in this case
-            print("Top-p sampling DISABLED", flush=True)
-            return {
-                'probs': probs,
-                'sample': tf.random.categorical(
-                    logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
-                    num_samples=num_samples, dtype=tf.int32),
-            }
-
-        # [batch_size, vocab_perm]
-        indices = tf.argsort(probs, direction='DESCENDING')
-        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)
-
-        # find the top pth index to cut off. careful we don't want to cutoff everything!
-        # result will be [batch_size, vocab_perm]
-        p_expanded = p if isinstance(p, float) else p[:, None]
-        exclude_mask = tf.logical_not(
-            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))
-
-        # OPTION A - sample in the sorted space, then unsort.
-        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
-        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
-        sample = tf.batch_gather(indices, sample_perm)
-
-        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
-        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
-        # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
-        # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
-        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)
-
-        return {
-            'probs': probs,
-            'sample': sample,
-        }
-
-
-def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
-    """
-    Does top-k sampling. if ignore_ids is on, then we will zero out those logits.
-    :param logits: [batch_size, vocab_size] tensor
-    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
-        like padding maybe
-    :param p: topp threshold to use, either a float or a [batch_size] vector
-    :return: [batch_size, num_samples] samples
-
-    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
-    """
-    with tf.variable_scope('top_p_sample'):
-        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)
-
-        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
-                              axis=-1)
-        # [batch_size, vocab_perm]
-        indices = tf.argsort(probs, direction='DESCENDING')
-
-        # find the top pth index to cut off. careful we don't want to cutoff everything!
-        # result will be [batch_size, vocab_perm]
-        k_expanded = k if isinstance(k, int) else k[:, None]
-        exclude_mask = tf.range(vocab_size)[None] >= k_expanded
-
-        # OPTION A - sample in the sorted space, then unsort.
-        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
-        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
-        sample = tf.batch_gather(indices, sample_perm)
-
-        return {
-            'probs': probs,
-            'sample': sample,
-        }
-
-
-class GroverModel(object):
-    def __init__(self,
-                 config: GroverConfig,
-                 is_training,
-                 input_ids,
-                 cache=None,
-                 do_cache=False,
-                 pad_token_id=0,
-                 chop_off_last_token=True,
-                 scope=None,
-                 reuse=False):
-        """
-        :param config:
-        :param is_training:
-        :param input_ids: Tensor thats of size [batch_size, seq_length]
-        :param cache: Optionally, a tensor to use that will contain cached information of the size
-            [batch_size, num_layers, 2, num_heads, cache_length, features]
-        :param do_cache: Whether to cache again.
-        :param pad_token_id: Which token will be used for padding (probably 0.)
-        :param chop_off_last_token: True if we will end up using this for TRAINING only. False if we want to generate.
-            it means the last token in input_ids will not be processed by the model as input
-        :param scope: scope to run this on
-        """
-        self.config = copy.deepcopy(config)
-        self.is_training = is_training
-        self.pad_token_id = pad_token_id
-
-        if not is_training:
-            self.config.hidden_dropout_prob = 0.0
-            self.config.attention_probs_dropout_prob = 0.0
-
-        if chop_off_last_token:
-            self.target_ids = input_ids[:, 1:]
-            self.input_ids = input_ids[:, :-1]
-        else:
-            self.input_ids = input_ids
-            self.target_ids = tf.concat((input_ids[:, 1:],
-                                         tf.constant(self.pad_token_id, dtype=self.input_ids.dtype,
-                                                     shape=[get_shape_list(self.input_ids, 2)[0], 1])), 1)
-
-        self.batch_size, self.seq_length = get_shape_list(self.input_ids, 2)
-
-        if cache is None:
-            caches = [None] * config.num_hidden_layers
-            self.cache_length = 0
-        else:
-            batch_size_, num_layers_, two_, num_heads_, self.cache_length, features_ = get_shape_list(
-                cache, expected_rank=6)
-            assert batch_size_ == self.batch_size
-            assert num_layers_ == config.num_hidden_layers
-            assert two_ == 2
-            assert num_heads_ == config.num_attention_heads
-            assert features_ == (config.hidden_size // config.num_attention_heads)
-            caches = tf.unstack(cache, axis=1)
-
-        with tf.variable_scope(scope, default_name='newslm', reuse=reuse):
-            with tf.variable_scope("embeddings"):
-                embeddings, self.embedding_table = embed(self.input_ids, config.vocab_size,
-                                                         config.hidden_size,
-                                                         position_offset=self.cache_length,
-                                                         initializer_range=config.initializer_range,
-                                                         max_position_embeddings=config.max_position_embeddings,
-                                                         use_one_hot_embeddings=True)
-
-            mask = get_attention_mask(self.seq_length, self.seq_length + self.cache_length, dtype=embeddings.dtype)
-
-            # We keep the representation as a 2D tensor to avoid re-shaping it back and
-            # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
-            # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
-            # help the optimizer.
-            hidden_state = tf.reshape(embeddings, [self.batch_size * self.seq_length, self.config.hidden_size])
-            new_kvs = []
-            for layer_idx, layer_cache in enumerate(caches):
-                with tf.variable_scope('layer{:02d}'.format(layer_idx)):
-                    # [batch_size * seq_length, hidden_size]
-                    attention_output, new_kv = attention_layer(
-                        hidden_state,
-                        mask,
-                        batch_size=self.batch_size,
-                        seq_length=self.seq_length,
-                        size_per_head=config.hidden_size // config.num_attention_heads,
-                        num_attention_heads=config.num_attention_heads,
-                        initializer_range=config.initializer_range,
-                        hidden_dropout_prob=self.config.hidden_dropout_prob,
-                        attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
-                        do_cache=do_cache,
-                        cache=layer_cache,
-                    )
-                    new_kvs.append(new_kv)
-
-                    # [batch_size * seq_length, hidden_size]
-                    hidden_state = residual_mlp_layer(hidden_state + attention_output,
-                                                      intermediate_size=config.intermediate_size,
-                                                      hidden_dropout_prob=self.config.hidden_dropout_prob)
-            self.hidden_state = hidden_state
-
-            self.new_kvs = tf.stack(new_kvs, axis=1) if do_cache else None
-
-            # Note that the hidden state is still flat (batch_size*hidden_size)
-            self.logits_flat = tf.matmul(self.hidden_state, self.embedding_table, transpose_b=True)
-
-            # THE OUTPUT BIAS DOES NOT SPARK JOY
-            # output_bias = tf.get_variable('output_bias', shape=[config.vocab_size], initializer=tf.zeros_initializer())
-            # self.logits_flat = tf.nn.bias_add(self.logits_flat, output_bias)
-
-    @property
-    def log_probs(self):
-        logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1)
-        return tf.reshape(logprobs_flat, [self.batch_size, self.seq_length, -1])
-
-    def lm_loss(self):
-        """
-        :return: stuff
-        """
-        target_ids_flat = tf.reshape(self.target_ids, [-1])
-
-        # 1 if it's valid and 0 otherwise.
-        label_weights = tf.cast(tf.not_equal(target_ids_flat, self.pad_token_id), dtype=self.logits_flat.dtype)
-
-        # [batch_size * seq_length, vocab_size]
-        one_hot_labels = tf.one_hot(target_ids_flat,
-                                    depth=self.config.vocab_size,
-                                    dtype=self.logits_flat.dtype)
-
-        # [batch_size * seq_length, vocab_size]
-        logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1)
-
-        per_example_loss = -tf.reduce_sum(logprobs_flat * one_hot_labels, axis=[-1])
-
-        # per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits_flat, labels=target_ids_flat)
-
-        numerator = tf.reduce_sum(label_weights * per_example_loss)
-        denominator = tf.reduce_sum(label_weights) + 1e-5
-        loss = numerator / denominator
-        return loss
-
-    def pooled_output(self, clf_token):
-        """
-        Extract pooled output given a token that says where we should look
-        :param clf_token:
-        :return:
-        """
-        pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(self.input_ids, clf_token), tf.float32), 1), tf.int32)
-        return tf.gather(self.hidden_state, tf.range(self.batch_size, dtype=tf.int32) * self.seq_length + pool_idx)
-
-
-def model_fn_builder(config: GroverConfig, init_checkpoint, learning_rate,
-                     num_train_steps, num_warmup_steps, use_tpu, num_tpu_cores=8 , eval_batch_size=1):
-    """Returns `model_fn` closure for TPUEstimator."""
-
-    def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
-        """The `model_fn` for TPUEstimator."""
-
-        tf.logging.info("*** Features ***")
-        for name in sorted(features.keys()):
-            tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
-
-        input_ids = features["input_ids"]
-
-        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
-
-        model = GroverModel(
-            config=config,
-            is_training=is_training,
-            input_ids=input_ids,
-            pad_token_id=config.pad_token_id,
-            chop_off_last_token=True,
-        )
-
-        total_loss = model.lm_loss()
-
-        if is_training:
-            train_op, train_metrics = optimization_adafactor.create_optimizer(
-                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
-            tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
-        else:
-            train_op = None
-            train_metrics = {}
-            tvars = tf.trainable_variables()
-
-        initialized_variable_names = {}
-        scaffold_fn = None
-        if init_checkpoint:
-            (assignment_map, initialized_variable_names
-             ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
-            if use_tpu:
-                def tpu_scaffold():
-                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-                    return tf.train.Scaffold()
-
-                scaffold_fn = tpu_scaffold
-            else:
-                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-
-        tf.logging.info("**** Trainable Variables ****")
-        for var in tvars:
-            init_string = ""
-            if var.name in initialized_variable_names:
-                init_string = ", *INIT_FROM_CKPT*"
-            tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
-                            init_string)
-
-        output_spec = None
-        if mode == tf.estimator.ModeKeys.TRAIN:
-            if use_tpu:
-                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                    mode=mode,
-                    loss=total_loss,
-                    train_op=train_op,
-                    host_call=construct_scalar_host_call(metric_dict=train_metrics, model_dir=params['model_dir'],
-                                                         prefix='training/'),
-                    scaffold_fn=scaffold_fn)
-            else:
-                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                    mode=mode,
-                    loss=total_loss,
-                    train_op=train_op,
-                    training_hooks=[
-                        tf.train.LoggingTensorHook({'loss': tf.metrics.mean(total_loss)[1]}, every_n_iter=100)],
-                    scaffold_fn=scaffold_fn)
-
-        elif mode == tf.estimator.ModeKeys.EVAL:
-            def metric_fn(loss):
-                """Evaluation metric Fn which runs on CPU."""
-                perplexity = tf.exp(tf.reduce_mean(loss))
-                bpc = tf.reduce_mean(loss) / tf.constant(math.log(2))
-                return {
-                    "perplexity": tf.metrics.mean(perplexity),
-                    "bpc": tf.metrics.mean(bpc),
-                }
-
-            if use_tpu:
-                with tf.colocate_with(total_loss):
-                    total_loss = tf.contrib.tpu.cross_replica_sum(total_loss) \
-                                 / num_tpu_cores
-                metric_loss = tf.tile(tf.reshape(total_loss, [1, 1]), [eval_batch_size, 1])
-                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                    mode=mode,
-                    loss=total_loss,
-                    eval_metrics=(metric_fn, [metric_loss]),
-                    scaffold_fn=scaffold_fn)
-            else:
-                gt_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, model.target_ids[:, :, None]), axis=2)
-
-                # Need top-p required under topp sampling!
-                better_than_gt = model.log_probs > gt_logprobs[:, :, None]
-                top_p_required = tf.reduce_sum(tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs), axis=2)
-
-                # No top-p sampling for now, since this seems to be too slow on TPUs
-                if use_tpu:
-                    predictions = tf.reshape(
-                        tf.random.categorical(logits=model.logits_flat, num_samples=1),
-                        get_shape_list(model.target_ids),
-                    )
-                else:
-                    # Argmax
-                    # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32)
-                    predictions = tf.reshape(
-                        _top_p_sample(model.logits_flat, num_samples=1, p=0.99)['sample'],
-                        get_shape_list(model.target_ids),
-                    )
-                pred_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, predictions[:, :, None]), axis=2)
-
-                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-                    mode=mode,
-                    predictions={'gt_logprobs': gt_logprobs,
-                                 'top_p_required': top_p_required,
-                                 'predictions': predictions,
-                                 'pred_logprobs': pred_logprobs,
-                                 'labels': input_ids},
-                    scaffold_fn=scaffold_fn)
-        return output_spec
-
-    return model_fn
-
-
-def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
-    """
-    Helper function that samples from grover for a single step
-    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
-    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
-    :param news_config: config for the GroverModel
-    :param batch_size: batch size to use
-    :param p_for_topp: top-p or top-k threshold
-    :param cache: [batch_size, news_config.num_hidden_layers, 2,
-        news_config.num_attention_heads, n_ctx_a,
-        news_config.hidden_size // news_config.num_attention_heads] OR, None
-    :return: new_tokens, size [batch_size]
-        new_probs, also size [batch_size]
-        new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
-        news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
-    """
-    model = GroverModel(
-        config=news_config,
-        is_training=False,
-        input_ids=tokens,
-        reuse=tf.AUTO_REUSE,
-        scope='newslm',
-        chop_off_last_token=False,
-        do_cache=True,
-        cache=cache,
-    )
-
-    # Extract the FINAL SEQ LENGTH
-    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
-    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]
-
-    if do_topk:
-        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
-    else:
-        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)
-
-    new_tokens = tf.squeeze(sample_info['sample'], 1)
-    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
-    return {
-        'new_tokens': new_tokens,
-        'new_probs': new_probs,
-        'new_cache': model.new_kvs,
-    }
-
-
-def initialize_from_context(initial_context, ignore_ids, news_config, p_for_topp=0.95, do_topk=False):
-    """ same signature as sample_step"""
-    batch_size, _ = get_shape_list(initial_context, expected_rank=2)
-
-    context_output = sample_step(tokens=initial_context, ignore_ids=ignore_ids, news_config=news_config,
-                                 batch_size=batch_size, p_for_topp=p_for_topp, cache=None, do_topk=do_topk)
-    return {
-        'tokens': tf.concat([initial_context, context_output['new_tokens'][:, None]], 1),
-        'cache': context_output['new_cache'],
-        'probs': context_output['new_probs'][:, None]
-    }
-
-
-def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
-           do_topk=False):
-    """
-    V1 version of: sample outputs from a model, and do it all at once
-    :param news_config: Configuration used to construct the model
-    :param initial_context: [batch_size, seq_length] that we'll start generating with
-    :param eos_token: Stop generating if you see this (tf scalar)
-    :param min_len: min length of sample
-    :param ignore_ids: NEVER GENERATE THESE [vocab_size]
-    :return:
-    """
-    batch_size, _ = get_shape_list(initial_context, expected_rank=2)
-
-    if ignore_ids is None:
-        ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool)
-
-    with tf.name_scope('sample_sequence'):
-        # Initial call to get cache
-        context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config,
-                                                 p_for_topp=p_for_topp,
-                                                 do_topk=do_topk)
-        ctx = context_output['tokens']
-        cache = context_output['cache']
-        probs = context_output['probs']
-
-        def body(ctx, cache, probs):
-            """ for whatever reason this didn't work when I ran it on more than one at once... ugh."""
-            next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config,
-                                       batch_size=batch_size, p_for_topp=p_for_topp, cache=cache,
-                                       do_topk=do_topk)
-
-            # Update everything
-            new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2)
-            new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1)
-            new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1)
-            return [new_ids, new_cache, new_probs]
-
-        def cond(ctx, cache, probs):
-            # ctx = tf.Print(ctx,[tf.shape(ctx)])
-            is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1))
-            is_len = tf.greater(get_shape_list(ctx)[1], min_len)
-            return tf.logical_not(tf.logical_and(is_eos, is_len))
-
-        tokens, cache, probs = tf.while_loop(
-            cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1],
-            loop_vars=[ctx, cache, probs],
-            shape_invariants=[tf.TensorShape([batch_size, None]),
-                              tf.TensorShape(
-                                  [batch_size, news_config.num_hidden_layers, 2,
-                                   news_config.num_attention_heads,
-                                   None, news_config.hidden_size // news_config.num_attention_heads]),
-                              tf.TensorShape([batch_size, None]),
-                              ],
-            back_prop=False,
-        )
-    return tokens, probs