keras-hub-nightly 0.20.0.dev202503260356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,352 @@
+ import keras
+ from keras import ops
+ 
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.modeling.reversible_embedding import (
+     ReversibleEmbedding,
+ )
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+ from keras_hub.src.models.gemma3.gemma3_decoder_block import Gemma3DecoderBlock
+ from keras_hub.src.models.gemma3.gemma3_interleave_embeddings import (
+     Gemma3InterleaveEmbeddings,
+ )
+ 
+ 
+ @keras_hub_export("keras_hub.models.Gemma3Backbone")
+ class Gemma3Backbone(Backbone):
+     """Gemma3 core network with hyperparameters.
+ 
+     This backbone implements the Gemma3 model architecture. Gemma3 is a
+     vision-language model (image-text in, text out). The text input is encoded
+     using an embedding layer; images are encoded using a vision transformer.
+     After encoding these two modalities, the image embeddings are placed in
+     the correct position in the text embedding sequence. The mixed sequence of
+     embeddings is then passed through transformer decoder layers.
+ 
+     Currently, this model supports only the `vision_encoder = None` case,
+     i.e., working only with text.
+ 
+     For a higher-level object for text generation, see
+     `keras_hub.models.Gemma3CausalLM`.
+ 
+     The default constructor gives a fully customizable, randomly initialized
+     Gemma3 model with any vision encoder, number of heads, embedding
+     dimensions, and equivalent configuration for the decoder layers. To load
+     preset architectures and weights, use the `from_preset` constructor.
+ 
+     Args:
+         vocabulary_size: int. The size of the token vocabulary.
+         image_size: int. The resolution of the image in both width and height.
+             The input images must be square.
+         num_layers: int. The number of transformer mixed decoder layers.
+         num_query_heads: int. The number of heads for the query projections in
+             the mixed decoder attention layer.
+         num_key_value_heads: int. The number of heads for the key and value
+             projections in the mixed decoder attention layers.
+         hidden_dim: int. The size of the transformer hidden state at the end
+             of each mixed transformer layer.
+         intermediate_dim: int. The output dimension of the first Dense layer
+             in a two-layer feedforward network for each transformer decoder
+             block.
+         head_dim: int. The size of each attention head in the mixed decoder.
+         query_head_dim_normalize: boolean. If `True`, normalize the query
+             before attention with `head_dim`. If `False`, normalize the query
+             with `hidden_dim / num_query_heads`. Defaults to `True`.
+         use_query_key_norm: bool. If `True`, apply an RMS Norm layer to the
+             query and key before projecting them. Defaults to `True`.
+         use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+             block. Defaults to `False`.
+         use_post_attention_norm: boolean. Whether to normalize after the
+             attention block. Defaults to `False`.
+         attention_logit_soft_cap: `None` or int. Soft cap for the attention
+             logits. Defaults to `None`.
+         final_logit_soft_cap: `None` or int. Soft cap for the final logits.
+             Defaults to `None`.
+         use_sliding_window_attention: boolean. Whether to use sliding local
+             window attention. Defaults to `False`.
+         sliding_window_size: int. Size of the sliding local window. Defaults
+             to `4096`.
+         vision_encoder: `keras.Model` or `keras.layers.Layer` instance. Its
+             `call()` takes in images and returns the corresponding sequence of
+             embeddings.
+         layer_norm_epsilon: float. The epsilon value used for every layer norm
+             in all transformer blocks. Defaults to `1e-6`.
+         dropout: float. Dropout probability for the Transformer decoder
+             blocks. Defaults to `0`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for the model's computations and weights. Note that some
+             computations, such as softmax and layer normalization, will always
+             be done in float32 precision regardless of dtype.
+ 
+     Example:
+     ```python
+     import numpy as np
+ 
+     import keras_hub
+ 
+     input_data = {}
+     input_data["token_ids"] = np.ones(shape=(1, 300), dtype="int32")
+     input_data["padding_mask"] = (
+         np.expand_dims(np.array([1] * 280 + [0] * (300 - 280)), axis=0)
+         .astype(bool)
+     )
+ 
+     # Pretrained Gemma3 decoder.
+     model = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_4b")
+     model(input_data)
+ 
+     config = {
+         'vocabulary_size': 262144,
+         'image_size': 896,
+         'num_layers': 34,
+         'num_query_heads': 8,
+         'num_key_value_heads': 4,
+         'hidden_dim': 2560,
+         'intermediate_dim': 10240,
+         'head_dim': 256,
+         'query_head_dim_normalize': True,
+         'use_post_ffw_norm': True,
+         'use_post_attention_norm': True,
+         'final_logit_soft_cap': None,
+         'attention_logit_soft_cap': None,
+         'sliding_window_size': 1024,
+         'use_sliding_window_attention': True,
+         'vision_encoder': None,
+         'layer_norm_epsilon': 1e-06,
+         'dtype': 'bfloat16',
+     }
+ 
+     model = keras_hub.models.Gemma3Backbone(**config)
+     model(input_data)
+     ```
+     """
+ 
+     def __init__(
+         self,
+         vocabulary_size,
+         image_size,
+         num_layers,
+         num_query_heads,
+         num_key_value_heads,
+         hidden_dim,
+         intermediate_dim,
+         head_dim,
+         query_head_dim_normalize=True,
+         use_query_key_norm=True,
+         use_post_ffw_norm=False,
+         use_post_attention_norm=False,
+         attention_logit_soft_cap=None,
+         final_logit_soft_cap=None,
+         use_sliding_window_attention=False,
+         sliding_window_size=1024,
+         vision_encoder=None,
+         layer_norm_epsilon=1e-6,
+         dropout=0,
+         dtype=None,
+         **kwargs,
+     ):
+         if vision_encoder is not None:
+             raise ValueError(
+                 "Currently, only the text version of the Gemma3 model is "
+                 "supported."
+             )
+ 
+         # === Layers ===
+         self.token_embedding = ReversibleEmbedding(
+             input_dim=vocabulary_size,
+             output_dim=hidden_dim,
+             tie_weights=True,
+             embeddings_initializer=keras.initializers.VarianceScaling(
+                 scale=1.0,
+                 mode="fan_in",
+                 distribution="untruncated_normal",
+                 seed=None,
+             ),
+             dtype=dtype,
+             logit_soft_cap=final_logit_soft_cap,
+             name="token_embedding",
+         )
+ 
+         self.vision_encoder = vision_encoder
+         text_only_model = True if vision_encoder is None else False
+         if not text_only_model:
+             self.interleave_embeddings = Gemma3InterleaveEmbeddings(
+                 num_vision_tokens_per_image=self.vision_encoder.num_vision_tokens_per_image,
+                 dtype=dtype,
+                 name="interleave_embeddings",
+             )
+ 
+         self.transformer_layers = []
+         for i in range(num_layers):
+             # 5 local, 1 global
+             sliding_window = use_sliding_window_attention and (i % 6 < 5)
+             rope_wavelength = 10_000.0 if sliding_window else 1_000_000.0
+             rope_scaling_factor = 1.0 if sliding_window else 8.0
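+             # For example, with `use_sliding_window_attention=True`, layers
+             # 0-4 in each block of 6 use local sliding-window attention with
+             # `rope_wavelength=10_000`, and every 6th layer uses global
+             # attention with `rope_wavelength=1_000_000` and
+             # `rope_scaling_factor=8`.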
+             layer = Gemma3DecoderBlock(
+                 hidden_dim=hidden_dim,
+                 intermediate_dim=intermediate_dim,
+                 head_dim=head_dim,
+                 num_query_heads=num_query_heads,
+                 num_key_value_heads=num_key_value_heads,
+                 query_head_dim_normalize=query_head_dim_normalize,
+                 use_query_key_norm=use_query_key_norm,
+                 use_post_ffw_norm=use_post_ffw_norm,
+                 use_post_attention_norm=use_post_attention_norm,
+                 gate_dim_reduction=1,
+                 logit_soft_cap=attention_logit_soft_cap,
+                 use_sliding_window_attention=sliding_window,
+                 sliding_window_size=sliding_window_size,
+                 rope_wavelength=rope_wavelength,
+                 rope_scaling_factor=rope_scaling_factor,
+                 dropout=dropout,
+                 dtype=dtype,
+                 name=f"decoder_block_{i}",
+             )
+             self.transformer_layers.append(layer)
+         self.layer_norm = RMSNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=dtype,
+             name="final_normalization",
+         )
+ 
+         # === Functional Model ===
+ 
+         # == Model inputs ==
+         if not text_only_model:
+             image_input = keras.Input(
+                 shape=(None, image_size, image_size, 3),
+                 name="images",
+             )
+             vision_indices_input = keras.Input(
+                 shape=(None,), dtype="int32", name="vision_indices"
+             )
+             # TODO: Consider removing `text_mask_input` and using
+             # `vision_indices_input` to infer it directly.
+             text_mask_input = keras.Input(
+                 shape=(None,), dtype="int32", name="text_mask"
+             )
+ 
+         token_id_input = keras.Input(
+             shape=(None,), dtype="int32", name="token_ids"
+         )
+         padding_mask_input = keras.Input(
+             shape=(None,), dtype="int32", name="padding_mask"
+         )
+ 
+         # == Text embeddings ==
+         text_embeddings = self.token_embedding(token_id_input)
+ 
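+         # Scale the token embeddings by `sqrt(hidden_dim)` (standard Gemma
+         # embedding scaling).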
+         text_embeddings = text_embeddings * ops.cast(
+             ops.sqrt(hidden_dim), text_embeddings.dtype
+         )
+ 
+         # == Image Embeddings ==
+         if not text_only_model:
+             img_embeddings = self.vision_encoder(image_input)
+ 
+             ## == Interleaving text and images ==
+             # Place image embeddings in the right position in
+             # `text_embeddings`.
+             x = self.interleave_embeddings(
+                 image_embeddings=img_embeddings,
+                 text_embeddings=text_embeddings,
+                 vision_indices=vision_indices_input,
+             )
+         else:
+             x = text_embeddings
+ 
+         # == Decoder layers ==
+         for transformer_layer in self.transformer_layers:
+             x = transformer_layer(
+                 x,
+                 padding_mask=padding_mask_input,
+                 text_mask=None if text_only_model else text_mask_input,
+             )
+         sequence_output = self.layer_norm(x)
+ 
+         inputs = {
+             "token_ids": token_id_input,
+             "padding_mask": padding_mask_input,
+         }
+         if not text_only_model:
+             inputs.update(
+                 {
+                     "images": image_input,
+                     "vision_indices": vision_indices_input,
+                     "text_mask": text_mask_input,
+                 }
+             )
+ 
+         super().__init__(
+             inputs=inputs,
+             outputs=sequence_output,
+             dtype=dtype,
+             **kwargs,
+         )
+ 
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.image_size = image_size
+         self.num_layers = num_layers
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.head_dim = head_dim
+         self.query_head_dim_normalize = query_head_dim_normalize
+         self.use_query_key_norm = use_query_key_norm
+         self.use_post_ffw_norm = use_post_ffw_norm
+         self.use_post_attention_norm = use_post_attention_norm
+         self.attention_logit_soft_cap = attention_logit_soft_cap
+         self.final_logit_soft_cap = final_logit_soft_cap
+         self.use_sliding_window_attention = use_sliding_window_attention
+         self.sliding_window_size = sliding_window_size
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.dropout = dropout
+ 
+         # Keep `num_vision_tokens_per_image` as a backbone property for easy
+         # access.
+         if not text_only_model:
+             self.num_vision_tokens_per_image = (
+                 self.vision_encoder.num_vision_tokens_per_image
+             )
+         # Also, the `text_only_model`.
+         self.text_only_model = text_only_model
+ 
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "image_size": self.image_size,
+                 "num_layers": self.num_layers,
+                 "num_query_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "head_dim": self.head_dim,
+                 "query_head_dim_normalize": self.query_head_dim_normalize,
+                 "use_query_key_norm": self.use_query_key_norm,
+                 "use_post_ffw_norm": self.use_post_ffw_norm,
+                 "use_post_attention_norm": self.use_post_attention_norm,
+                 "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                 "final_logit_soft_cap": self.final_logit_soft_cap,
+                 "use_sliding_window_attention": (
+                     self.use_sliding_window_attention
+                 ),
+                 "sliding_window_size": self.sliding_window_size,
+                 "vision_encoder": None
+                 if self.vision_encoder is None
+                 else keras.layers.serialize(self.vision_encoder),
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "dropout": self.dropout,
+             }
+         )
+         return config
+ 
+     @classmethod
+     def from_config(cls, config):
+         config.update(
+             {
+                 "vision_encoder": None
+                 if config["vision_encoder"] is None
+                 else keras.layers.deserialize(config["vision_encoder"]),
+             }
+         )
+ 
+         return super().from_config(config)
@@ -0,0 +1,306 @@
+ from keras import ops
+ 
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.causal_lm import CausalLM
+ from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+ from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
+     Gemma3CausalLMPreprocessor,
+ )
+ from keras_hub.src.utils.tensor_utils import any_equal
+ 
+ 
+ @keras_hub_export("keras_hub.models.Gemma3CausalLM")
+ class Gemma3CausalLM(CausalLM):
+     """An end-to-end multimodal Gemma3 model for causal language modeling.
+ 
+     A causal language model (LM) predicts the next token based on previous
+     tokens. This task setup can be used to train the model unsupervised on
+     image and plain text input, or to autoregressively generate plain text
+     similar to the data used for training.
+ 
+     This model has a `generate()` method, which generates text based on a
+     prompt. The generation strategy used is controlled by an additional
+     `sampler` argument on `compile()`. You can recompile the model with
+     different `keras_hub.samplers` objects to control the generation. By
+     default, `"greedy"` sampling will be used.
+ 
+     This model can optionally be configured with a `preprocessor` layer, in
+     which case it will automatically apply preprocessing to string inputs
+     during `fit()`, `predict()`, `evaluate()` and `generate()`. This is done
+     by default when creating the model with `from_preset()`.
+ 
+     Args:
+         backbone: A `keras_hub.models.Gemma3Backbone` instance.
+         preprocessor: A `keras_hub.models.Gemma3CausalLMPreprocessor` or
+             `None`. If `None`, this model will not apply preprocessing, and
+             inputs should be preprocessed before calling the model.
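+ 
+     Example:
+     ```python
+     # A minimal generation sketch. The `gemma3_instruct_4b` preset name is
+     # taken from the `Gemma3Backbone` docstring above; `"top_k"` is an
+     # assumed example of a `keras_hub.samplers` sampler name.
+     gemma3_lm = keras_hub.models.Gemma3CausalLM.from_preset(
+         "gemma3_instruct_4b"
+     )
+     gemma3_lm.generate("What is Keras?", max_length=64)
+ 
+     # Recompile with a different sampler to change the generation strategy.
+     gemma3_lm.compile(sampler="top_k")
+     gemma3_lm.generate("What is Keras?", max_length=64)
+     ```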
+     """
+ 
+     backbone_cls = Gemma3Backbone
+     preprocessor_cls = Gemma3CausalLMPreprocessor
+ 
+     def __init__(
+         self,
+         preprocessor,
+         backbone,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.preprocessor = preprocessor
+         self.backbone = backbone
+ 
+         # === Functional Model ===
+         # This must be "backbone.input" i.e. the full input structure,
+         # rather than "backbone.inputs" which is the flattened list of inputs.
+         inputs = backbone.input
+         hidden_state = backbone(inputs=inputs)
+         outputs = backbone.token_embedding(hidden_state, reverse=True)
+ 
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+ 
+     def compile(
+         self,
+         optimizer="auto",
+         loss="auto",
+         *,
+         weighted_metrics="auto",
+         sampler="greedy",
+         **kwargs,
+     ):
+         super().compile(
+             optimizer=optimizer,
+             loss=loss,
+             weighted_metrics=weighted_metrics,
+             sampler=sampler,
+             **kwargs,
+         )
+ 
+     def call_with_cache(
+         self,
+         token_ids,
+         cache,
+         cache_update_index,
+         img_embeddings=None,
+         text_mask=None,
+         padding_mask=None,
+         vision_indices=None,
+     ):
+         """Forward pass of `Gemma3CausalLM` with cache.
+ 
+         `call_with_cache` adds an additional forward pass for the model for
+         autoregressive inference. Unlike calling the model directly, this
+         method allows caching previous key/value Tensors in the multi-head
+         attention layers, and avoids recomputing the outputs of seen tokens.
+ 
+         Args:
+             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+             cache: a dense float Tensor, the key/value cache, with shape
+                 `(batch_size, num_layers, 2, max_length, num_key_value_heads,
+                 head_dim)`.
+             cache_update_index: int, or int Tensor. The index of current
+                 inputs in the whole sequence.
+             img_embeddings: a dense float Tensor with shape
+                 `(batch_size, image_sequence_length, hidden_dim)`.
+             padding_mask: a dense int Tensor with shape
+                 `(batch_size, max_length)`.
+ 
+         Returns:
+             A `(logits, hidden_states, cache)` tuple, where `logits` is the
+             language model logits for the input token_ids, `hidden_states` is
+             the final hidden representation of the input tokens, and `cache`
+             is the decoding cache.
+         """
+ 
+         text_embeddings = self.backbone.token_embedding(token_ids)
+         text_embeddings = text_embeddings * ops.cast(
+             ops.sqrt(self.backbone.hidden_dim), text_embeddings.dtype
+         )
+ 
+         # Interleaving logic.
+         ## == Interleaving text and images ==
+         # Place the image embeddings in the right position in
+         # `text_embeddings`.
+         if img_embeddings is not None:
+             x = self.backbone.interleave_embeddings(
+                 image_embeddings=img_embeddings,
+                 text_embeddings=text_embeddings,
+                 vision_indices=vision_indices,
+             )
+         else:
+             x = text_embeddings
+ 
+         # Each decoder layer has a cache; we update them separately.
+         caches = []
+         for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+             current_cache = cache[:, i, ...]
+             x, next_cache = transformer_layer(
+                 x,
+                 cache=current_cache,
+                 cache_update_index=cache_update_index,
+                 padding_mask=padding_mask,
+                 text_mask=text_mask,
+             )
+             caches.append(next_cache)
+         cache = ops.stack(caches, axis=1)
+         hidden_states = x = self.backbone.layer_norm(x)
+         logits = self.backbone.token_embedding(x, reverse=True)
+         return logits, hidden_states, cache
+ 
+     def _build_cache(
+         self,
+         token_ids,
+         img_embeddings,
+         text_mask,
+         padding_mask,
+         vision_indices,
+     ):
+         """Build an empty cache for use with `call_with_cache()`."""
+         batch_size = ops.shape(token_ids)[0]
+         max_length = (
+             ops.shape(token_ids)[1]
+             # + self.backbone.image_sequence_length
+         )
+         num_layers = self.backbone.num_layers
+         num_heads = self.backbone.num_key_value_heads
+         head_dim = self.backbone.head_dim
+         shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
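+         # One key and one value tensor (the `2` axis) per decoder layer.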
+         cache = ops.zeros(shape, dtype=self.compute_dtype)
+         # Seed the cache.
+         logits, hidden_states, cache = self.call_with_cache(
+             token_ids=token_ids,
+             img_embeddings=img_embeddings,
+             text_mask=text_mask,
+             cache=cache,
+             cache_update_index=0,
+             padding_mask=padding_mask,
+             vision_indices=vision_indices,
+         )
+         return hidden_states, cache
+ 
+     def generate_step(self, inputs, stop_token_ids=[106]):
+         """A compilable generation function for a single batch of inputs.
+ 
+         This function represents the inner, XLA-compilable, generation
+         function for a single batch of inputs. Inputs should have the same
+         structure as model inputs, a dictionary with keys `"token_ids"` and
+         `"padding_mask"`.
+ 
+         Args:
+             inputs: A dictionary with two keys `"token_ids"` and
+                 `"padding_mask"` and batched tensor values.
+             stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                 sequences have produced a new stop token, generation
+                 will stop.
+         """
+ 
+         token_ids, padding_mask, images, text_mask, vision_indices = (
+             inputs["token_ids"],
+             inputs["padding_mask"],
+             inputs.get("images", None),
+             inputs.get("text_mask", None),
+             inputs.get("vision_indices", None),
+         )
+         if not self.backbone.text_only_model:
+             if len(ops.shape(images)) == 3:
+                 # Handle an unbatched image. Unlike `token_ids` and
+                 # `padding_mask`, this will not automatically be upranked.
+                 images = ops.expand_dims(images, axis=0)
+             img_embeddings = self.backbone.vision_encoder(images)
+         else:
+             img_embeddings = None
+             text_mask = None
+             vision_indices = None
+ 
+         # Create and seed the cache with a single forward pass.
+         hidden_states, cache = self._build_cache(
+             token_ids,
+             img_embeddings,
+             text_mask,
+             padding_mask,
+             vision_indices,
+         )
+ 
+         # Compute the lengths of all user-inputted token ids.
+         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+         # Start at the first index that has no user-inputted id.
+         index = ops.min(row_lengths)
+ 
+         def next(prompt, cache, index):
+             # The cache index is the index of our previous token.
+             cache_update_index = index - 1
+             batch_size = ops.shape(prompt)[0]
+             prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
+             logits, hidden_states, cache = self.call_with_cache(
+                 token_ids=prompt,
+                 cache=cache,
+                 cache_update_index=cache_update_index,
+             )
+             return (
+                 ops.squeeze(logits, axis=1),
+                 ops.squeeze(hidden_states, axis=1),
+                 cache,
+             )
+ 
+         token_ids = self.sampler(
+             next=next,
+             prompt=token_ids,
+             cache=cache,
+             index=index,
+             mask=padding_mask,
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+ 
+         # Compute an output padding mask with the token ids we updated.
+         if stop_token_ids is not None:
+             # Build a mask of `stop_token_ids` locations not in the original
+             # prompt (not in locations where `padding_mask` is True).
+             end_locations = any_equal(
+                 token_ids, stop_token_ids, ops.logical_not(padding_mask)
+             )
+ 
+             end_locations = ops.cast(end_locations, "int32")
+             # Use cumsum to get ones in all locations after end_locations.
+             cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+             overflow = cumsum - end_locations
+             # Our padding mask is the inverse of these overflow locations.
+             padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
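+             # Worked example for one row: end_locations=[0, 0, 1, 0] gives
+             # cumsum=[0, 0, 1, 1], overflow=[0, 0, 0, 1], and so
+             # padding_mask=[True, True, True, False], i.e. everything up to
+             # and including the stop token is kept.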
+         else:
+             # Without early stopping, all locations will have been updated.
+             padding_mask = ops.ones_like(token_ids, dtype="bool")
+         return {
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+             "images": images,
+         }
+ 
+     def generate(
+         self,
+         inputs,
+         max_length=None,
+         stop_token_ids="auto",
+         strip_prompt=False,
+     ):
+         # If `auto`, add `<end_of_turn>` as a stop token too.
+         if self.preprocessor is None and stop_token_ids == "auto":
+             raise ValueError(
+                 "A `preprocessor` must be attached to the model if "
+                 '`stop_token_ids="auto"`. Currently `preprocessor=None`. To '
+                 "call `generate()` with preprocessing detached, either pass "
+                 "`stop_token_ids=None` to always generate until `max_length` "
+                 "or pass a tuple of token ids that should terminate generation "
+                 "as `stop_token_ids`."
+             )
+         elif stop_token_ids == "auto":
+             stop_token_ids = [
+                 self.preprocessor.tokenizer.end_token_id,
+                 self.preprocessor.tokenizer.token_to_id("<end_of_turn>"),
+             ]
+ 
+         return super().generate(
+             inputs,
+             max_length=max_length,
+             stop_token_ids=stop_token_ids,
+             strip_prompt=strip_prompt,
+         )