keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504040358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/RECORD +27 -27
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
|
|
1
|
+
import numpy as np
|
1
2
|
from keras import ops
|
2
3
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
@@ -8,15 +9,21 @@ from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
|
|
8
9
|
)
|
9
10
|
from keras_hub.src.utils.tensor_utils import any_equal
|
10
11
|
|
12
|
+
try:
|
13
|
+
import tensorflow as tf
|
14
|
+
except ImportError:
|
15
|
+
tf = None
|
16
|
+
|
11
17
|
|
12
18
|
@keras_hub_export("keras_hub.models.Gemma3CausalLM")
|
13
19
|
class Gemma3CausalLM(CausalLM):
|
14
|
-
"""An end-to-end
|
20
|
+
"""An end-to-end multimodal Gemma3 model for causal language modeling.
|
15
21
|
|
16
22
|
A causal language model (LM) predicts the next token based on previous
|
17
23
|
tokens. This task setup can be used to train the model unsupervised on
|
18
|
-
|
19
|
-
similar to the data used for training.
|
24
|
+
images and plain text inputs, or to autoregressively generate plain text
|
25
|
+
similar to the data used for training. Note that the model is
|
26
|
+
image-text in, text out.
|
20
27
|
|
21
28
|
This model has a `generate()` method, which generates text based on a
|
22
29
|
prompt. The generation strategy used is controlled by an additional
|
@@ -30,10 +37,10 @@ class Gemma3CausalLM(CausalLM):
|
|
30
37
|
when creating the model with `from_preset()`.
|
31
38
|
|
32
39
|
Args:
|
33
|
-
backbone: A `keras_hub.models.Gemma3Backbone` instance.
|
34
40
|
preprocessor: A `keras_hub.models.Gemma3CausalLMPreprocessor` or
|
35
41
|
`None`. If `None`, this model will not apply preprocessing, and
|
36
42
|
inputs should be preprocessed before calling the model.
|
43
|
+
backbone: A `keras_hub.models.Gemma3Backbone` instance.
|
37
44
|
"""
|
38
45
|
|
39
46
|
backbone_cls = Gemma3Backbone
|
@@ -79,15 +86,55 @@ class Gemma3CausalLM(CausalLM):
|
|
79
86
|
**kwargs,
|
80
87
|
)
|
81
88
|
|
89
|
+
def _normalize_generate_inputs(
    self,
    inputs,
):
    """Normalize `generate()` inputs into an iterable of batches.

    Overrides the superclass' method to handle unbatched image inputs.

    Args:
        inputs: One of a `tf.data.Dataset`, a dict holding a `"prompts"`
            entry (optionally with `"images"` and `"responses"`), or a
            bare prompt / batch of prompts.

    Returns:
        A `(batches, input_is_scalar)` tuple. `batches` is an iterable of
        inputs (the dataset's NumPy iterator, or a one-element list).
        `input_is_scalar` is `True` only when the prompt was a Python
        string or a rank-0 `tf.Tensor` that had to be upranked here.

    Raises:
        KeyError: If `inputs` is a dict without a `"prompts"` key.
    """
    if tf and isinstance(inputs, tf.data.Dataset):
        return inputs.as_numpy_iterator(), False

    # Without a preprocessor the inputs are passed through untouched.
    if self.preprocessor is None:
        return [inputs], False

    def normalize(x):
        """Uprank a scalar string/tensor; report whether it was scalar."""
        if isinstance(x, str):
            return [x], True
        if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
            return x[tf.newaxis], True
        return x, False

    if isinstance(inputs, dict):
        # Work on a shallow copy so the caller's dict is not rewritten in
        # place. Otherwise a second `generate()` call with the same dict
        # would see an already-batched prompt and no longer be treated as
        # a scalar input.
        inputs = dict(inputs)
        inputs["prompts"], input_is_scalar = normalize(inputs["prompts"])

        # If prompt is scalar, images can be either a 3D NumPy array/Tensor,
        # or, list of 3D NumPy arrays. Let's uprank images.
        if input_is_scalar and "images" in inputs:
            x = inputs["images"]
            if isinstance(x, np.ndarray) and len(x.shape) == 3:
                inputs["images"] = [x]
            elif tf and isinstance(x, tf.Tensor) and x.shape.rank == 3:
                inputs["images"] = x[tf.newaxis]
            elif isinstance(x, list):
                # A list of images for the single prompt becomes a batch
                # of one.
                inputs["images"] = [x]

        if "responses" in inputs:
            inputs["responses"], _ = normalize(inputs["responses"])
    else:
        inputs, input_is_scalar = normalize(inputs)

    return [inputs], input_is_scalar
|
127
|
+
|
82
128
|
def call_with_cache(
|
83
129
|
self,
|
84
130
|
token_ids,
|
85
131
|
cache,
|
86
132
|
cache_update_index,
|
87
133
|
img_embeddings=None,
|
88
|
-
|
134
|
+
vision_mask=None,
|
89
135
|
padding_mask=None,
|
90
136
|
vision_indices=None,
|
137
|
+
cache_update_mask=None,
|
91
138
|
):
|
92
139
|
"""Forward pass of `Gemma3CausalLM` with cache.
|
93
140
|
|
@@ -139,7 +186,8 @@ class Gemma3CausalLM(CausalLM):
|
|
139
186
|
cache=current_cache,
|
140
187
|
cache_update_index=cache_update_index,
|
141
188
|
padding_mask=padding_mask,
|
142
|
-
|
189
|
+
vision_mask=vision_mask,
|
190
|
+
cache_update_mask=cache_update_mask,
|
143
191
|
)
|
144
192
|
caches.append(next_cache)
|
145
193
|
cache = ops.stack(caches, axis=1)
|
@@ -151,7 +199,7 @@ class Gemma3CausalLM(CausalLM):
|
|
151
199
|
self,
|
152
200
|
token_ids,
|
153
201
|
img_embeddings,
|
154
|
-
|
202
|
+
vision_mask,
|
155
203
|
padding_mask,
|
156
204
|
vision_indices,
|
157
205
|
):
|
@@ -170,11 +218,12 @@ class Gemma3CausalLM(CausalLM):
|
|
170
218
|
logits, hidden_states, cache = self.call_with_cache(
|
171
219
|
token_ids=token_ids,
|
172
220
|
img_embeddings=img_embeddings,
|
173
|
-
|
221
|
+
vision_mask=vision_mask,
|
174
222
|
cache=cache,
|
175
223
|
cache_update_index=0,
|
176
224
|
padding_mask=padding_mask,
|
177
225
|
vision_indices=vision_indices,
|
226
|
+
cache_update_mask=None,
|
178
227
|
)
|
179
228
|
return hidden_states, cache
|
180
229
|
|
@@ -193,29 +242,33 @@ class Gemma3CausalLM(CausalLM):
|
|
193
242
|
will stop.
|
194
243
|
"""
|
195
244
|
|
196
|
-
token_ids, padding_mask, images,
|
245
|
+
token_ids, padding_mask, images, vision_mask, vision_indices = (
|
197
246
|
inputs["token_ids"],
|
198
247
|
inputs["padding_mask"],
|
199
248
|
inputs.get("images", None),
|
200
|
-
inputs.get("
|
249
|
+
inputs.get("vision_mask", None),
|
201
250
|
inputs.get("vision_indices", None),
|
202
251
|
)
|
203
252
|
if not self.backbone.text_only_model:
|
204
|
-
|
205
|
-
|
206
|
-
|
253
|
+
# Handle an unbatched image. Unlike `token_ids` and
|
254
|
+
# `padding_mask`, this will not automatically be upranked.
|
255
|
+
if len(ops.shape(images)) == 4:
|
207
256
|
images = ops.expand_dims(images, axis=0)
|
257
|
+
if len(ops.shape(vision_mask)) == 1:
|
258
|
+
vision_mask = ops.expand_dims(vision_mask, axis=0)
|
259
|
+
if len(ops.shape(vision_indices)) == 1:
|
260
|
+
vision_indices = ops.expand_dims(vision_indices, axis=0)
|
208
261
|
img_embeddings = self.backbone.vision_encoder(images)
|
209
262
|
else:
|
210
263
|
img_embeddings = None
|
211
|
-
|
264
|
+
vision_mask = None
|
212
265
|
vision_indices = None
|
213
266
|
|
214
267
|
# Create and seed cache with a single forward pass.
|
215
268
|
hidden_states, cache = self._build_cache(
|
216
269
|
token_ids,
|
217
270
|
img_embeddings,
|
218
|
-
|
271
|
+
vision_mask,
|
219
272
|
padding_mask,
|
220
273
|
vision_indices,
|
221
274
|
)
|
@@ -230,10 +283,14 @@ class Gemma3CausalLM(CausalLM):
|
|
230
283
|
cache_update_index = index - 1
|
231
284
|
batch_size = ops.shape(prompt)[0]
|
232
285
|
prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
|
286
|
+
sliced_cache_update_mask = ops.slice(
|
287
|
+
~padding_mask, [0, index - 1], [batch_size, 1]
|
288
|
+
)
|
233
289
|
logits, hidden_states, cache = self.call_with_cache(
|
234
290
|
token_ids=prompt,
|
235
291
|
cache=cache,
|
236
292
|
cache_update_index=cache_update_index,
|
293
|
+
cache_update_mask=sliced_cache_update_mask,
|
237
294
|
)
|
238
295
|
return (
|
239
296
|
ops.squeeze(logits, axis=1),
|