tf-models-nightly 2.17.0.dev20240327__py2.py3-none-any.whl → 2.17.0.dev20240329__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/nlp/modeling/ops/beam_search.py +44 -7
- official/nlp/modeling/ops/beam_search_test.py +48 -0
- {tf_models_nightly-2.17.0.dev20240327.dist-info → tf_models_nightly-2.17.0.dev20240329.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.17.0.dev20240327.dist-info → tf_models_nightly-2.17.0.dev20240329.dist-info}/RECORD +8 -8
- {tf_models_nightly-2.17.0.dev20240327.dist-info → tf_models_nightly-2.17.0.dev20240329.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.17.0.dev20240327.dist-info → tf_models_nightly-2.17.0.dev20240329.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.17.0.dev20240327.dist-info → tf_models_nightly-2.17.0.dev20240329.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.17.0.dev20240327.dist-info → tf_models_nightly-2.17.0.dev20240329.dist-info}/top_level.txt +0 -0
@@ -69,6 +69,9 @@ class _StateKeys(object):
|
|
69
69
|
# At the beginning, all of the sequences in FINISHED_SEQ are filler values.
|
70
70
|
# True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
|
71
71
|
FINISHED_FLAGS = "FINISHED_FLAGS"
|
72
|
+
# For the prefix-matching hack: beam search will only constrain the next token to
|
73
|
+
# where the mask is 1.
|
74
|
+
CONSTRAINT_MASK = "CONSTRAINT_MASK"
|
72
75
|
|
73
76
|
|
74
77
|
def _expand_to_same_rank(tensor, target):
|
@@ -150,7 +153,7 @@ class SequenceBeamSearch(tf.Module):
|
|
150
153
|
self.decoding_name = decoding_name
|
151
154
|
self.noise_multiplier = noise_multiplier
|
152
155
|
|
153
|
-
def search(self, initial_ids, initial_cache):
|
156
|
+
def search(self, initial_ids, initial_cache, constraint_mask=None):
|
154
157
|
"""Beam search for sequences with highest scores.
|
155
158
|
|
156
159
|
Args:
|
@@ -158,6 +161,9 @@ class SequenceBeamSearch(tf.Module):
|
|
158
161
|
with shape [batch_size, 1]
|
159
162
|
initial_cache: dictionary storing values to be passed into the
|
160
163
|
symbols_to_logits_fn.
|
164
|
+
constraint_mask: a [vocab_size] tensor in which 1 marks an allowed prefix token. During
|
165
|
+
autoregressive decoding, the first token must be one of the positions where the
|
166
|
+
constraint_mask is 1.
|
161
167
|
|
162
168
|
Returns:
|
163
169
|
finished_seq and finished_scores.
|
@@ -165,8 +171,9 @@ class SequenceBeamSearch(tf.Module):
|
|
165
171
|
batch_size = (
|
166
172
|
initial_ids.shape.as_list()[0]
|
167
173
|
if self.padded_decode else tf.shape(initial_ids)[0])
|
168
|
-
state, state_shapes = self._create_initial_state(
|
169
|
-
|
174
|
+
state, state_shapes = self._create_initial_state(
|
175
|
+
initial_ids, initial_cache, batch_size, constraint_mask=constraint_mask
|
176
|
+
)
|
170
177
|
|
171
178
|
def _grow_alive_seq(state):
|
172
179
|
"""Grow alive sequences by one token, collect top 2*beam_size sequences.
|
@@ -204,6 +211,21 @@ class SequenceBeamSearch(tf.Module):
|
|
204
211
|
flat_logits, flat_cache = self.symbols_to_logits_fn(
|
205
212
|
flat_ids, i, flat_cache)
|
206
213
|
|
214
|
+
if _StateKeys.CONSTRAINT_MASK in state:
|
215
|
+
constraint_mask = state[_StateKeys.CONSTRAINT_MASK]
|
216
|
+
constraint_mask = tf.cond(
|
217
|
+
tf.equal(i, 0),
|
218
|
+
lambda: constraint_mask,
|
219
|
+
lambda: tf.ones_like(constraint_mask),
|
220
|
+
)
|
221
|
+
penalty = tf.cast(
|
222
|
+
tf.cast(constraint_mask != 1, tf.int32) * 999_999_999,
|
223
|
+
flat_logits.dtype,
|
224
|
+
)
|
225
|
+
flat_logits = flat_logits - penalty[tf.newaxis, :]
|
226
|
+
else:
|
227
|
+
constraint_mask = None
|
228
|
+
|
207
229
|
if self.noise_multiplier > 0:
|
208
230
|
noise = tf.random.uniform(flat_logits.shape, dtype=flat_logits.dtype)
|
209
231
|
# Generates standard Gumbel(0, 1) noise, GSE Tensors
|
@@ -250,7 +272,7 @@ class SequenceBeamSearch(tf.Module):
|
|
250
272
|
else:
|
251
273
|
topk_seq = tf.concat(
|
252
274
|
[topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
|
253
|
-
return topk_seq, topk_log_probs, topk_ids, new_cache
|
275
|
+
return topk_seq, topk_log_probs, topk_ids, new_cache, constraint_mask
|
254
276
|
|
255
277
|
def _get_new_alive_state(new_seq, new_log_probs, new_finished_flags,
|
256
278
|
new_cache):
|
@@ -363,7 +385,9 @@ class SequenceBeamSearch(tf.Module):
|
|
363
385
|
new state dictionary.
|
364
386
|
"""
|
365
387
|
# Grow alive sequences by one token.
|
366
|
-
new_seq, new_log_probs, topk_ids, new_cache = _grow_alive_seq(state)
|
388
|
+
new_seq, new_log_probs, topk_ids, new_cache, constraint_mask = (
|
389
|
+
_grow_alive_seq(state)
|
390
|
+
)
|
367
391
|
new_finished_flags = tf.equal(topk_ids, self.eos_id[0])
|
368
392
|
for eos_id in self.eos_id[1:]:
|
369
393
|
one_finished_flags = tf.equal(topk_ids, eos_id)
|
@@ -383,6 +407,8 @@ class SequenceBeamSearch(tf.Module):
|
|
383
407
|
new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
|
384
408
|
new_state.update(alive_state)
|
385
409
|
new_state.update(finished_state)
|
410
|
+
if constraint_mask is not None:
|
411
|
+
new_state[_StateKeys.CONSTRAINT_MASK] = constraint_mask
|
386
412
|
return [new_state]
|
387
413
|
|
388
414
|
finished_state = tf.nest.map_structure(
|
@@ -415,7 +441,9 @@ class SequenceBeamSearch(tf.Module):
|
|
415
441
|
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
|
416
442
|
return finished_seq, finished_scores
|
417
443
|
|
418
|
-
def _create_initial_state(
|
444
|
+
def _create_initial_state(
|
445
|
+
self, initial_ids, initial_cache, batch_size, constraint_mask=None
|
446
|
+
):
|
419
447
|
"""Return initial state dictionary and its shape invariants."""
|
420
448
|
for key, value in initial_cache.items():
|
421
449
|
for inner_value in tf.nest.flatten(value):
|
@@ -466,6 +494,8 @@ class SequenceBeamSearch(tf.Module):
|
|
466
494
|
_StateKeys.FINISHED_SCORES: finished_scores,
|
467
495
|
_StateKeys.FINISHED_FLAGS: finished_flags
|
468
496
|
}
|
497
|
+
if constraint_mask is not None:
|
498
|
+
state[_StateKeys.CONSTRAINT_MASK] = constraint_mask
|
469
499
|
|
470
500
|
# Create state invariants for each value in the state dictionary. Each
|
471
501
|
# dimension must be a constant or None. A None dimension means either:
|
@@ -509,6 +539,10 @@ class SequenceBeamSearch(tf.Module):
|
|
509
539
|
_StateKeys.FINISHED_FLAGS:
|
510
540
|
tf.TensorShape([None, self.beam_size])
|
511
541
|
}
|
542
|
+
if constraint_mask is not None:
|
543
|
+
state_shape_invariants[_StateKeys.CONSTRAINT_MASK] = tf.TensorShape(
|
544
|
+
[self.vocab_size]
|
545
|
+
)
|
512
546
|
|
513
547
|
return state, state_shape_invariants
|
514
548
|
|
@@ -614,6 +648,7 @@ def sequence_beam_search(
|
|
614
648
|
dtype="float32",
|
615
649
|
noise_multiplier: float = 0.0,
|
616
650
|
decoding_name=None,
|
651
|
+
constraint_mask=None,
|
617
652
|
):
|
618
653
|
"""Search for sequence of subtoken ids with the largest probability.
|
619
654
|
|
@@ -641,6 +676,8 @@ def sequence_beam_search(
|
|
641
676
|
tf.float32.
|
642
677
|
noise_multiplier: The amount of noise.
|
643
678
|
decoding_name: an optional name for the decoding loop tensors.
|
679
|
+
constraint_mask: Beam search will only constrain the first decoded token to where the
|
680
|
+
mask is 1.
|
644
681
|
|
645
682
|
Returns:
|
646
683
|
Top decoded sequences [batch_size, beam_size, max_decode_length]
|
@@ -658,7 +695,7 @@ def sequence_beam_search(
|
|
658
695
|
noise_multiplier,
|
659
696
|
decoding_name,
|
660
697
|
)
|
661
|
-
return sbs.search(initial_ids, initial_cache)
|
698
|
+
return sbs.search(initial_ids, initial_cache, constraint_mask=constraint_mask)
|
662
699
|
|
663
700
|
|
664
701
|
def _log_prob_from_logits(logits):
|
@@ -150,6 +150,54 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
|
|
150
150
|
else:
|
151
151
|
self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
|
152
152
|
|
153
|
+
@parameterized.named_parameters([
|
154
|
+
('padded_decode_true_with_name', True, 0.0, 'decoding'),
|
155
|
+
('padded_decode_false_with_name', False, 0.0, 'decoding'),
|
156
|
+
('padded_decode_true_without_name', True, 0.0, None),
|
157
|
+
('padded_decode_false_without_name', False, 0.0, None),
|
158
|
+
('padded_decode_false_with_noise', False, 0.5, 'decoding'),
|
159
|
+
])
|
160
|
+
def test_sequence_beam_search_with_prefix_constraint(
|
161
|
+
self, padded_decode, noise_multiplier, name
|
162
|
+
):
|
163
|
+
# batch_size*beam_size, max_decode_length, vocab_size
|
164
|
+
probabilities = tf.constant([
|
165
|
+
[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2], [0.1, 0.8, 0.1]],
|
166
|
+
[[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.2, 0.1, 0.7]],
|
167
|
+
])
|
168
|
+
# batch_size, max_decode_length, num_heads, embed_size per head
|
169
|
+
x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
|
170
|
+
cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
|
171
|
+
|
172
|
+
def _get_test_symbols_to_logits_fn():
|
173
|
+
"""Test function that returns logits for next token."""
|
174
|
+
|
175
|
+
def symbols_to_logits_fn(_, i, cache):
|
176
|
+
logits = tf.cast(probabilities[:, i, :], tf.float32)
|
177
|
+
return logits, cache
|
178
|
+
|
179
|
+
return symbols_to_logits_fn
|
180
|
+
|
181
|
+
predictions, _ = beam_search.sequence_beam_search(
|
182
|
+
symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
|
183
|
+
initial_ids=tf.zeros([1], dtype=tf.int32),
|
184
|
+
initial_cache=cache,
|
185
|
+
vocab_size=3,
|
186
|
+
beam_size=2,
|
187
|
+
alpha=0.6,
|
188
|
+
max_decode_length=3,
|
189
|
+
eos_id=[9, 10],
|
190
|
+
padded_decode=padded_decode,
|
191
|
+
dtype=tf.float32,
|
192
|
+
noise_multiplier=noise_multiplier,
|
193
|
+
decoding_name=name,
|
194
|
+
constraint_mask=tf.constant([1, 0, 0]),
|
195
|
+
)
|
196
|
+
if noise_multiplier > 0:
|
197
|
+
self.assertAllEqual([[[0, 0, 0, 1], [0, 0, 0, 2]]], predictions)
|
198
|
+
else:
|
199
|
+
self.assertAllEqual([[[0, 0, 0, 1], [0, 0, 1, 2]]], predictions)
|
200
|
+
|
153
201
|
|
154
202
|
if __name__ == '__main__':
|
155
203
|
tf.test.main()
|
@@ -414,8 +414,8 @@ official/nlp/modeling/networks/sparse_mixer_test.py,sha256=9AY4gelHc-rrtUexr33-j
|
|
414
414
|
official/nlp/modeling/networks/xlnet_base.py,sha256=ditE18dFpJQ87U1-vC3VzgFpx0aK2Hyy6b4HgOO8De4,25867
|
415
415
|
official/nlp/modeling/networks/xlnet_base_test.py,sha256=zt8hLCpKy5wKWsbCizyq8mLGJc32OXIqbhWp0ysQGKc,14788
|
416
416
|
official/nlp/modeling/ops/__init__.py,sha256=VnA497WiK08ukev1d5Tjqc283YGQx6MnGyPAPk_jW7s,1011
|
417
|
-
official/nlp/modeling/ops/beam_search.py,sha256=
|
418
|
-
official/nlp/modeling/ops/beam_search_test.py,sha256=
|
417
|
+
official/nlp/modeling/ops/beam_search.py,sha256=1kwoD3SF1BiWbxeN4u77CjJXJ2hCEzOWJP295_BeWuU,31255
|
418
|
+
official/nlp/modeling/ops/beam_search_test.py,sha256=Sz1sirBnYktqQ82NbyLefVpkmLVr7BPVApVxW8DRuoI,7589
|
419
419
|
official/nlp/modeling/ops/decoding_module.py,sha256=-Aw_A2dUbRu7jd-DY4a7iWme-yNSvfng9g_XWdCGwXI,11279
|
420
420
|
official/nlp/modeling/ops/decoding_module_test.py,sha256=VTYYaZxihkDz1FkkwUIyc3EuCqGIW9fJS-3mYw3c4-8,2623
|
421
421
|
official/nlp/modeling/ops/sampling_module.py,sha256=gyUoOnNdh6TJGebce5BMUxTrhk79HzPM3whuEu5BP9A,19250
|
@@ -1203,9 +1203,9 @@ tensorflow_models/__init__.py,sha256=etxw45SHxuwFCRX5qGxGMP83II0JfJulzNl5GSNJvhw
|
|
1203
1203
|
tensorflow_models/tensorflow_models_test.py,sha256=AxUYUdiQn416UR7jg0h6rmv688esvlKDfpyDCIQkF18,1395
|
1204
1204
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1205
1205
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1206
|
-
tf_models_nightly-2.17.0.
|
1207
|
-
tf_models_nightly-2.17.0.
|
1208
|
-
tf_models_nightly-2.17.0.
|
1209
|
-
tf_models_nightly-2.17.0.
|
1210
|
-
tf_models_nightly-2.17.0.
|
1211
|
-
tf_models_nightly-2.17.0.
|
1206
|
+
tf_models_nightly-2.17.0.dev20240329.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1207
|
+
tf_models_nightly-2.17.0.dev20240329.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1208
|
+
tf_models_nightly-2.17.0.dev20240329.dist-info/METADATA,sha256=MJ2HcBGugMGyza9VmjXOGjsHsThDRt01z6v5PQmxTOE,1432
|
1209
|
+
tf_models_nightly-2.17.0.dev20240329.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1210
|
+
tf_models_nightly-2.17.0.dev20240329.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1211
|
+
tf_models_nightly-2.17.0.dev20240329.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|