tf-models-nightly 2.17.0.dev20240327__py2.py3-none-any.whl → 2.17.0.dev20240329__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,6 +69,9 @@ class _StateKeys(object):
   # At the beginning, all of the sequences in FINISHED_SEQ are filler values.
   # True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
   FINISHED_FLAGS = "FINISHED_FLAGS"
+  # For the prefix-matching hack: beam search will only constrain the next
+  # token to positions where the mask is 1.
+  CONSTRAINT_MASK = "CONSTRAINT_MASK"
 
 
 def _expand_to_same_rank(tensor, target):
@@ -150,7 +153,7 @@ class SequenceBeamSearch(tf.Module):
     self.decoding_name = decoding_name
     self.noise_multiplier = noise_multiplier
 
-  def search(self, initial_ids, initial_cache):
+  def search(self, initial_ids, initial_cache, constraint_mask=None):
     """Beam search for sequences with highest scores.
 
     Args:
@@ -158,6 +161,9 @@ class SequenceBeamSearch(tf.Module):
         with shape [batch_size, 1]
       initial_cache: dictionary storing values to be passed into the
         symbols_to_logits_fn.
+      constraint_mask: a [vocab_size] tensor in which 1 marks an allowed
+        prefix token. During autoregressive decoding, the first generated
+        token must be one for which the constraint_mask is 1.
 
     Returns:
       finished_seq and finished_scores.
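For orientation, here is a minimal sketch of how a caller might build such a mask. The vocabulary size and the allowed first-token ids below are illustrative placeholders, not values taken from this package:

import tensorflow as tf

vocab_size = 32000                # assumed vocabulary size (illustrative)
allowed_first_ids = [5, 17, 42]   # hypothetical ids of allowed prefix tokens

# constraint_mask[v] == 1 iff token v may be generated at the first step.
constraint_mask = tf.scatter_nd(
    indices=[[i] for i in allowed_first_ids],
    updates=tf.ones([len(allowed_first_ids)], dtype=tf.int32),
    shape=[vocab_size],
)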
@@ -165,8 +171,9 @@ class SequenceBeamSearch(tf.Module):
     batch_size = (
         initial_ids.shape.as_list()[0]
         if self.padded_decode else tf.shape(initial_ids)[0])
-    state, state_shapes = self._create_initial_state(initial_ids, initial_cache,
-                                                     batch_size)
+    state, state_shapes = self._create_initial_state(
+        initial_ids, initial_cache, batch_size, constraint_mask=constraint_mask
+    )
 
     def _grow_alive_seq(state):
       """Grow alive sequences by one token, collect top 2*beam_size sequences.
@@ -204,6 +211,21 @@ class SequenceBeamSearch(tf.Module):
       flat_logits, flat_cache = self.symbols_to_logits_fn(
           flat_ids, i, flat_cache)
 
+      if _StateKeys.CONSTRAINT_MASK in state:
+        constraint_mask = state[_StateKeys.CONSTRAINT_MASK]
+        constraint_mask = tf.cond(
+            tf.equal(i, 0),
+            lambda: constraint_mask,
+            lambda: tf.ones_like(constraint_mask),
+        )
+        penalty = tf.cast(
+            tf.cast(constraint_mask != 1, tf.int32) * 999_999_999,
+            flat_logits.dtype,
+        )
+        flat_logits = flat_logits - penalty[tf.newaxis, :]
+      else:
+        constraint_mask = None
+
       if self.noise_multiplier > 0:
         noise = tf.random.uniform(flat_logits.shape, dtype=flat_logits.dtype)
         # Generates standard Gumbel(0, 1) noise, GSE Tensors
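The block above implements the constraint as additive masking: at step 0, every token where the mask is not 1 has a huge constant subtracted from its logit, so it can never survive the top-k selection; from step 1 onward, tf.cond swaps in an all-ones mask and decoding is unconstrained. A self-contained sketch of the same technique, restated outside the library for clarity (shapes and values are illustrative, not the library code):

import tensorflow as tf

def apply_prefix_constraint(flat_logits, constraint_mask, step):
  """Penalizes disallowed tokens at step 0; later steps are unconstrained."""
  # After step 0 every token is allowed, so swap in an all-ones mask.
  mask = tf.cond(
      tf.equal(step, 0),
      lambda: constraint_mask,
      lambda: tf.ones_like(constraint_mask),
  )
  # A large additive penalty pushes disallowed logits far below any
  # realistic score, so top-k can never pick them.
  penalty = tf.cast(mask != 1, flat_logits.dtype) * 999_999_999.0
  return flat_logits - penalty[tf.newaxis, :]

# With mask [1, 0, 0], only token 0 survives the first step.
logits = tf.constant([[0.2, 0.7, 0.1]])
masked = apply_prefix_constraint(logits, tf.constant([1, 0, 0]), tf.constant(0))
print(tf.argmax(masked, axis=-1).numpy())  # [0]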
@@ -250,7 +272,7 @@ class SequenceBeamSearch(tf.Module):
       else:
         topk_seq = tf.concat(
             [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
-      return topk_seq, topk_log_probs, topk_ids, new_cache
+      return topk_seq, topk_log_probs, topk_ids, new_cache, constraint_mask
 
     def _get_new_alive_state(new_seq, new_log_probs, new_finished_flags,
                              new_cache):
@@ -363,7 +385,9 @@ class SequenceBeamSearch(tf.Module):
         new state dictionary.
       """
       # Grow alive sequences by one token.
-      new_seq, new_log_probs, topk_ids, new_cache = _grow_alive_seq(state)
+      new_seq, new_log_probs, topk_ids, new_cache, constraint_mask = (
+          _grow_alive_seq(state)
+      )
       new_finished_flags = tf.equal(topk_ids, self.eos_id[0])
       for eos_id in self.eos_id[1:]:
         one_finished_flags = tf.equal(topk_ids, eos_id)
@@ -383,6 +407,8 @@ class SequenceBeamSearch(tf.Module):
       new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
       new_state.update(alive_state)
       new_state.update(finished_state)
+      if constraint_mask is not None:
+        new_state[_StateKeys.CONSTRAINT_MASK] = constraint_mask
       return [new_state]
 
     finished_state = tf.nest.map_structure(
@@ -415,7 +441,9 @@ class SequenceBeamSearch(tf.Module):
     finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
     return finished_seq, finished_scores
 
-  def _create_initial_state(self, initial_ids, initial_cache, batch_size):
+  def _create_initial_state(
+      self, initial_ids, initial_cache, batch_size, constraint_mask=None
+  ):
     """Return initial state dictionary and its shape invariants."""
     for key, value in initial_cache.items():
       for inner_value in tf.nest.flatten(value):
@@ -466,6 +494,8 @@ class SequenceBeamSearch(tf.Module):
         _StateKeys.FINISHED_SCORES: finished_scores,
         _StateKeys.FINISHED_FLAGS: finished_flags
     }
+    if constraint_mask is not None:
+      state[_StateKeys.CONSTRAINT_MASK] = constraint_mask
 
     # Create state invariants for each value in the state dictionary. Each
     # dimension must be a constant or None. A None dimension means either:
@@ -509,6 +539,10 @@ class SequenceBeamSearch(tf.Module):
         _StateKeys.FINISHED_FLAGS:
             tf.TensorShape([None, self.beam_size])
     }
+    if constraint_mask is not None:
+      state_shape_invariants[_StateKeys.CONSTRAINT_MASK] = tf.TensorShape(
+          [self.vocab_size]
+      )
 
     return state, state_shape_invariants
 
@@ -614,6 +648,7 @@ def sequence_beam_search(
     dtype="float32",
     noise_multiplier: float = 0.0,
     decoding_name=None,
+    constraint_mask=None,
 ):
   """Search for sequence of subtoken ids with the largest probability.
 
@@ -641,6 +676,8 @@ def sequence_beam_search(
       tf.float32.
     noise_multiplier: The amount of noise.
    decoding_name: an optional name for the decoding loop tensors.
+    constraint_mask: an optional [vocab_size] tensor; beam search will only
+      allow first tokens at positions where the mask is 1.
 
   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
@@ -658,7 +695,7 @@ def sequence_beam_search(
       noise_multiplier,
       decoding_name,
   )
-  return sbs.search(initial_ids, initial_cache)
+  return sbs.search(initial_ids, initial_cache, constraint_mask=constraint_mask)
 
 
 def _log_prob_from_logits(logits):
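Putting the plumbing together, callers only need to pass the mask through the public entry point. A hedged usage sketch with a toy model function and a placeholder cache, mirroring the shape conventions of the test below (all values are illustrative):

import tensorflow as tf
from official.nlp.modeling.ops import beam_search

def symbols_to_logits_fn(ids, i, cache):
  # Toy model: always favors token 2; returns [batch*beam, vocab_size] logits.
  logits = tf.tile(tf.constant([[0.1, 0.2, 3.0]]), [tf.shape(ids)[0], 1])
  return logits, cache

x = tf.zeros([1, 2, 2, 4], dtype=tf.float32)  # placeholder attention cache
predictions, scores = beam_search.sequence_beam_search(
    symbols_to_logits_fn=symbols_to_logits_fn,
    initial_ids=tf.zeros([1], dtype=tf.int32),
    initial_cache={'layer_0': {'k': x, 'v': x}},
    vocab_size=3,
    beam_size=2,
    alpha=0.6,
    max_decode_length=2,
    eos_id=[9, 10],                # out-of-vocab: decoding never stops early
    constraint_mask=tf.constant([1, 1, 0]),  # token 2 disallowed at step 0
)
# The first decoded position avoids token 2 even though the model prefers it.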
@@ -150,6 +150,54 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
     else:
       self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
 
+  @parameterized.named_parameters([
+      ('padded_decode_true_with_name', True, 0.0, 'decoding'),
+      ('padded_decode_false_with_name', False, 0.0, 'decoding'),
+      ('padded_decode_true_without_name', True, 0.0, None),
+      ('padded_decode_false_without_name', False, 0.0, None),
+      ('padded_decode_false_with_noise', False, 0.5, 'decoding'),
+  ])
+  def test_sequence_beam_search_with_prefix_constraint(
+      self, padded_decode, noise_multiplier, name
+  ):
+    # [batch_size * beam_size, max_decode_length, vocab_size]
+    probabilities = tf.constant([
+        [[0.2, 0.7, 0.1], [0.5, 0.3, 0.2], [0.1, 0.8, 0.1]],
+        [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.2, 0.1, 0.7]],
+    ])
+    # [batch_size, max_decode_length, num_heads, embed_size per head]
+    x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
+    cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
+
+    def _get_test_symbols_to_logits_fn():
+      """Test function that returns logits for the next token."""
+
+      def symbols_to_logits_fn(_, i, cache):
+        logits = tf.cast(probabilities[:, i, :], tf.float32)
+        return logits, cache
+
+      return symbols_to_logits_fn
+
+    predictions, _ = beam_search.sequence_beam_search(
+        symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
+        initial_ids=tf.zeros([1], dtype=tf.int32),
+        initial_cache=cache,
+        vocab_size=3,
+        beam_size=2,
+        alpha=0.6,
+        max_decode_length=3,
+        eos_id=[9, 10],
+        padded_decode=padded_decode,
+        dtype=tf.float32,
+        noise_multiplier=noise_multiplier,
+        decoding_name=name,
+        constraint_mask=tf.constant([1, 0, 0]),
+    )
+    if noise_multiplier > 0:
+      self.assertAllEqual([[[0, 0, 0, 1], [0, 0, 0, 2]]], predictions)
+    else:
+      self.assertAllEqual([[[0, 0, 0, 1], [0, 0, 1, 2]]], predictions)
+
 
 if __name__ == '__main__':
   tf.test.main()
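The expected outputs follow directly from the mask semantics: constraint_mask=[1, 0, 0] forces token 0 at the first decoded position even though token 1 carries the highest first-step probability in both rows (0.7 and 0.8), and all later steps are unconstrained. A quick check of the first step, applying the same penalty arithmetic to the test's own numbers:

import tensorflow as tf

# First-step logits for the two beam rows of the probabilities tensor above.
logits = tf.constant([[0.2, 0.7, 0.1], [0.1, 0.8, 0.1]])
mask = tf.constant([1, 0, 0])
penalty = tf.cast(tf.cast(mask != 1, tf.int32) * 999_999_999, logits.dtype)
masked = logits - penalty[tf.newaxis, :]
print(tf.argmax(masked, axis=-1).numpy())  # [0 0]: token 0 is forced, not 1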
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tf-models-nightly
-Version: 2.17.0.dev20240327
+Version: 2.17.0.dev20240329
 Summary: TensorFlow Official Models
 Home-page: https://github.com/tensorflow/models
 Author: Google Inc.
@@ -414,8 +414,8 @@ official/nlp/modeling/networks/sparse_mixer_test.py,sha256=9AY4gelHc-rrtUexr33-j
 official/nlp/modeling/networks/xlnet_base.py,sha256=ditE18dFpJQ87U1-vC3VzgFpx0aK2Hyy6b4HgOO8De4,25867
 official/nlp/modeling/networks/xlnet_base_test.py,sha256=zt8hLCpKy5wKWsbCizyq8mLGJc32OXIqbhWp0ysQGKc,14788
 official/nlp/modeling/ops/__init__.py,sha256=VnA497WiK08ukev1d5Tjqc283YGQx6MnGyPAPk_jW7s,1011
-official/nlp/modeling/ops/beam_search.py,sha256=rfQluf94mAEUuwCGkKeTUTZhpllYIPysDzuJiHe4OnU,29830
-official/nlp/modeling/ops/beam_search_test.py,sha256=BEWDGlOIxJQaocSCZ58LHvI_OFH69SXhsEJCLZMoNuY,5749
+official/nlp/modeling/ops/beam_search.py,sha256=1kwoD3SF1BiWbxeN4u77CjJXJ2hCEzOWJP295_BeWuU,31255
+official/nlp/modeling/ops/beam_search_test.py,sha256=Sz1sirBnYktqQ82NbyLefVpkmLVr7BPVApVxW8DRuoI,7589
 official/nlp/modeling/ops/decoding_module.py,sha256=-Aw_A2dUbRu7jd-DY4a7iWme-yNSvfng9g_XWdCGwXI,11279
 official/nlp/modeling/ops/decoding_module_test.py,sha256=VTYYaZxihkDz1FkkwUIyc3EuCqGIW9fJS-3mYw3c4-8,2623
 official/nlp/modeling/ops/sampling_module.py,sha256=gyUoOnNdh6TJGebce5BMUxTrhk79HzPM3whuEu5BP9A,19250
@@ -1203,9 +1203,9 @@ tensorflow_models/__init__.py,sha256=etxw45SHxuwFCRX5qGxGMP83II0JfJulzNl5GSNJvhw
 tensorflow_models/tensorflow_models_test.py,sha256=AxUYUdiQn416UR7jg0h6rmv688esvlKDfpyDCIQkF18,1395
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.17.0.dev20240327.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
-tf_models_nightly-2.17.0.dev20240327.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
-tf_models_nightly-2.17.0.dev20240327.dist-info/METADATA,sha256=fIhpn38AKOT9brIxR3wssnpAjVfCTpFQe-BMXLhRxh8,1432
-tf_models_nightly-2.17.0.dev20240327.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
-tf_models_nightly-2.17.0.dev20240327.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
-tf_models_nightly-2.17.0.dev20240327.dist-info/RECORD,,
+tf_models_nightly-2.17.0.dev20240329.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.17.0.dev20240329.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.17.0.dev20240329.dist-info/METADATA,sha256=MJ2HcBGugMGyza9VmjXOGjsHsThDRt01z6v5PQmxTOE,1432
+tf_models_nightly-2.17.0.dev20240329.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.17.0.dev20240329.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.17.0.dev20240329.dist-info/RECORD,,