keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504050402__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. keras_hub/api/models/__init__.py +5 -20
  2. keras_hub/api/tokenizers/__init__.py +0 -4
  3. keras_hub/src/layers/preprocessing/image_converter.py +26 -16
  4. keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
  5. keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
  6. keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
  7. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
  8. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
  9. keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
  10. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
  11. keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
  12. keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
  13. keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
  14. keras_hub/src/models/qwen/qwen_backbone.py +0 -7
  15. keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
  16. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
  17. keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
  18. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
  19. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
  20. keras_hub/src/models/vit/vit_image_converter.py +8 -3
  21. keras_hub/src/tests/test_case.py +4 -0
  22. keras_hub/src/utils/tensor_utils.py +6 -0
  23. keras_hub/src/version_utils.py +1 -1
  24. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/METADATA +1 -1
  25. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/RECORD +27 -27
  26. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/WHEEL +0 -0
  27. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/top_level.txt +0 -0

keras_hub/src/models/gemma3/gemma3_decoder_block.py
@@ -14,16 +14,17 @@ from keras_hub.src.models.gemma3.rms_normalization import RMSNormalization
 class Gemma3DecoderBlock(keras.layers.Layer):
     """Transformer decoder layer for Gemma3.
 
-    This is different from Gemma and Gemma2 in several ways:
-
-    - `use_query_key_norm`: Applies RMS Norm on query, key.
-    - `rope_wavelength`: RoPE wavelength differs from local to global attention
-      layers.
-    - `rope_scaling_factor`: RoPE scaling factor differs from local to global
-      attention layers.
-    - `gate_dim_reduction`: In the gating layers, Gemma and Gemma2 reduce
-      intermediate dimension by 2. For Gemma3, no such reduction happens.
-    - Uses bidirectional attention for images, and causal for everything else.
+    This decoder layer is the same as the layer used for Gemma and Gemma2.
+    However, there are a few key differences. Firstly, image tokens have
+    bidirectional masking. Additionally, this layer exposes the following args:
+
+    `use_query_key_norm`: bool. If True, apply RMS normalization on query
+        and key. For Gemma3, this is True.
+    `rope_wavelength`: float. Configurable value for RoPE wavelength. Gemma3
+        uses 10K for local attention layers and 1M for global attention layers.
+    `gate_dim_reduction`: int. In the gating layers, the output dimension is
+        `intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
+        value is 2. For Gemma3, it is 1.
     """
 
     def __init__(
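To make the `gate_dim_reduction` arithmetic concrete, here is a minimal sketch of a Gemma-style gated feed-forward block (an illustration only, not the KerasHub layer; the GELU gating and layer names are assumptions): the width of the two gating projections is `intermediate_dim // gate_dim_reduction`, so Gemma/Gemma2 (`2`) halve it while Gemma3 (`1`) keeps the full intermediate width.

```python
import keras
from keras import ops


def gated_ffw(x, hidden_dim, intermediate_dim, gate_dim_reduction):
    # Gemma/Gemma2: gate_dim_reduction=2 -> projections of width
    # intermediate_dim // 2. Gemma3: gate_dim_reduction=1 -> full width.
    gate_dim = intermediate_dim // gate_dim_reduction
    gate = keras.layers.Dense(gate_dim, use_bias=False)(x)
    up = keras.layers.Dense(gate_dim, use_bias=False)(x)
    hidden = ops.gelu(gate, approximate=True) * up
    return keras.layers.Dense(hidden_dim, use_bias=False)(hidden)


x = keras.random.normal((1, 8, 16))
print(gated_ffw(x, hidden_dim=16, intermediate_dim=64, gate_dim_reduction=1).shape)
```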
@@ -160,9 +161,10 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         # Isometric
         return input_shape
 
-    def _compute_image_bidirectional_attention_mask(self, text_mask):
-        # text_mask is True for text, False for images. Shape of (bsz, seq_len).
-        bidirectional_mask = ops.logical_not(text_mask)
+    def _compute_image_bidirectional_attention_mask(self, vision_mask):
+        # vision_mask is False for text, True for images. Shape of
+        # (bsz, seq_len).
+        bidirectional_mask = vision_mask
 
         # Left pad with 0.
         padded_mask = ops.cast(
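Callers that previously built a `text_mask` (True for text tokens) can derive the new `vision_mask` (True for image tokens) by simple negation; a one-line sketch:

```python
from keras import ops

# Old convention: True for text. New convention: True for image tokens.
text_mask = ops.convert_to_tensor([[True, False, False, True]])
vision_mask = ops.logical_not(text_mask)
```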
@@ -194,7 +196,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self,
         x,
         padding_mask,
-        text_mask,
+        vision_mask,
         cache,
         cache_update_index,
     ):
@@ -216,9 +218,9 @@ class Gemma3DecoderBlock(keras.layers.Layer):
 
         # Compute bidirectional mask (image tokens can attend to each other
         # in both directions, within the same image).
-        if text_mask is not None:
+        if vision_mask is not None:
             bidirectional_image_mask = (
-                self._compute_image_bidirectional_attention_mask(text_mask)
+                self._compute_image_bidirectional_attention_mask(vision_mask)
             )
             causal_mask = ops.logical_or(causal_mask, bidirectional_image_mask)
 
@@ -232,14 +234,15 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self,
         x,
         padding_mask=None,
-        text_mask=None,
+        vision_mask=None,
         cache=None,
         cache_update_index=0,
+        cache_update_mask=None,
     ):
-        # Note: `text_mask` is used only for Gemma33.
+        # Note: `vision_mask` is used only for Gemma3.
         normalized_x = self.pre_attention_norm(x)
         attention_mask = self._compute_attention_mask(
-            normalized_x, padding_mask, text_mask, cache, cache_update_index
+            normalized_x, padding_mask, vision_mask, cache, cache_update_index
         )
         if cache is not None:
             attention, new_cache = self.attention(
@@ -247,6 +250,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
                 attention_mask=attention_mask,
                 cache=cache,
                 cache_update_index=cache_update_index,
+                cache_update_mask=cache_update_mask,
             )
         else:
             attention = self.attention(
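To illustrate the bidirectional image masking these hunks implement, here is a simplified sketch (it assumes a single image per sequence and ignores padding and caching, which the real `_compute_image_bidirectional_attention_mask` handles): image positions attend to each other in both directions, everything else stays causal, and the two masks are merged with `ops.logical_or` exactly as in the diff.

```python
import numpy as np
from keras import ops

seq_len = 6
# True at image-token positions, False for text: (batch, seq_len).
vision_mask = np.array([[False, True, True, True, False, False]])

# Standard lower-triangular causal mask: (batch, seq_len, seq_len).
causal_mask = np.tril(np.ones((1, seq_len, seq_len), dtype=bool))

# Image tokens can see each other in both directions.
bidirectional_image_mask = vision_mask[:, :, None] & vision_mask[:, None, :]

attention_mask = ops.logical_or(causal_mask, bidirectional_image_mask)
print(ops.convert_to_numpy(attention_mask)[0].astype(int))
```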

keras_hub/src/models/gemma3/gemma3_image_converter.py
@@ -6,3 +6,9 @@ from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
 @keras_hub_export("keras_hub.layers.Gemma3ImageConverter")
 class Gemma3ImageConverter(ImageConverter):
     backbone_cls = Gemma3Backbone
+
+    def __init__(self, **kwargs):
+        # Always do image preprocessing in float32
+        kwargs.pop("dtype", None)
+        dtype = "float32"
+        super().__init__(dtype=dtype, **kwargs)
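A small usage sketch of the new behavior (the `image_size` value is only illustrative): any `dtype` passed by the caller is discarded and preprocessing runs in float32.

```python
from keras_hub.src.models.gemma3.gemma3_image_converter import Gemma3ImageConverter

# The requested bfloat16 dtype is dropped; the converter stays in float32.
converter = Gemma3ImageConverter(image_size=(896, 896), dtype="bfloat16")
print(converter.dtype)  # float32
```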

keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py
@@ -5,12 +5,17 @@ from keras import ops
 class Gemma3InterleaveEmbeddings(keras.layers.Layer):
     """Places image embeddings in the correct position in an embedding sequence.
 
+    For Gemma3, images can be in any position in the input sequence. In order
+    to do accomplish this, we have image placeholder tokens in the input
+    sequence. We fill up these positions with the image embeddings as returned
+    by the vision encoder.
+
     Args:
         num_vision_tokens_per_image: int. Number of soft tokens per image.
     """
 
-    def __init__(self, num_vision_tokens_per_image, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, num_vision_tokens_per_image, dtype=None, **kwargs):
+        super().__init__(dtype=dtype, **kwargs)
 
         self.num_vision_tokens_per_image = num_vision_tokens_per_image
 
@@ -19,12 +24,17 @@ class Gemma3InterleaveEmbeddings(keras.layers.Layer):
         Integrates image embeddings into a text embedding sequence.
 
         Args:
-            image_embeddings: Tensor of shape
-                `(batch_size * num_images_per_prompt,
-                num_vision_tokens_per_image, embedding_dim)`.
-            text_embeddings: Tensor of shape
-                `(batch_size, seq_length, embedding_dim)`.
-            text_mask: Boolean tensor of shape `(batch_size, seq_length)`.
+            image_embeddings: tensor. Image embeddings as returned by the
+                vision encoder (`Gemma3VisionEncoder`, usually). Shape:
+                `(batch_size * num_images_per_prompt, `
+                `num_vision_tokens_per_image, embedding_dim)`.
+            text_embeddings: tensor. Embeddings returned by the text embedding
+                layer. Shape: `(batch_size, seq_length, embedding_dim)`.
+            vision_indices: tensor. Indexes into `text_embeddings`, used to
+                identify which places are supposed to be replaced by
+                `image_embeddings`. Shape:
+                `(batch_size,`
+                `num_images_per_prompt * num_vision_tokens_per_image)`.
 
         Returns:
             Tensor of shape `(batch_size, seq_length, embedding_dim)`
@@ -32,32 +42,62 @@ class Gemma3InterleaveEmbeddings(keras.layers.Layer):
         """
 
         batch_size, seq_length, embedding_dim = ops.shape(text_embeddings)
+        # `num_images` will be 0 for text only inputs, and
+        # `batch_size * max_images_per_prompt` if images are passed.
+        num_images = ops.shape(image_embeddings)[0]
 
-        # Flatten text embeddings, text mask and image embeddings.
+        # Flatten text embeddings, image embeddings and indices.
         flat_text_embeddings = ops.reshape(
             text_embeddings, (batch_size * seq_length, embedding_dim)
         )
-
-        # The image batch size might be different when we pass only text, i.e,
-        # it will be 0 for text-only.
-        image_batch_size = ops.shape(image_embeddings)[0]
+        # `flat_image_embeddings` is the `updates` tensor and should be of shape
+        # `(num_updates, embedding_dim)`.
         flat_image_embeddings = ops.reshape(
             image_embeddings,
             (
-                image_batch_size * self.num_vision_tokens_per_image,
+                num_images * self.num_vision_tokens_per_image,
                 embedding_dim,
             ),
         )
 
-        # Reconstruct embeddings.
+        # For vision indices, we need to add values such that the indices
+        # index into a flattened `text_embeddings`.
+        to_add = ops.multiply(
+            keras.ops.arange(batch_size, dtype="int32"), seq_length
+        )
+        to_add = ops.expand_dims(to_add, axis=-1)
+        vision_indices = ops.add(vision_indices, to_add)
+
+        # indices should be of shape `(num_updates, 1)`. `num_updates` is
+        # how many vision tokens there are to update.
         vision_indices_shape = ops.shape(vision_indices)
         flat_vision_indices = ops.reshape(
             vision_indices,
             (vision_indices_shape[0] * vision_indices_shape[1], 1),
         )
         indices = ops.cast(flat_vision_indices, "int32")
+
+        # Before reconstructing, store the 0th index so that we can restore it
+        # later.
+        zeroth_index_text_embeddings = ops.take(
+            flat_text_embeddings,
+            indices=ops.squeeze(to_add, axis=-1),
+            axis=0,
+        )
+
+        # Reconstruct embeddings
+        reconstructed_embedding = ops.scatter_update(
+            inputs=flat_text_embeddings,
+            indices=indices,
+            updates=flat_image_embeddings,
+        )
+
+        # Remember that we pad `vision_indices` with the 0th index. We need to
+        # restore the original value in the reconstructed embedding tensor.
         reconstructed_embedding = ops.scatter_update(
-            flat_text_embeddings, indices, flat_image_embeddings
+            inputs=reconstructed_embedding,
+            indices=to_add,
+            updates=zeroth_index_text_embeddings,
         )
 
         # Reshape to original dimensions
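To see what the scatter-based interleaving above does end to end, here is a toy NumPy re-enactment (shapes and values are invented; the real layer performs the same steps with `keras.ops.scatter_update`): offset each sample's vision indices into the flattened `(batch * seq)` axis, scatter the image embeddings over the placeholder rows, then restore each sample's row 0, which is where padded indices point.

```python
import numpy as np

batch_size, seq_length, embedding_dim = 2, 6, 4
num_vision_tokens_per_image = 2

text_embeddings = np.arange(
    batch_size * seq_length * embedding_dim, dtype="float32"
).reshape(batch_size, seq_length, embedding_dim)
# Two "images", each contributing two vision tokens, all marked with -1.
image_embeddings = -np.ones(
    (2, num_vision_tokens_per_image, embedding_dim), dtype="float32"
)

# Sample 0 has image placeholders at positions 2-3; sample 1 is text-only,
# so its indices are padded with 0.
vision_indices = np.array([[2, 3], [0, 0]])

flat_text = text_embeddings.reshape(batch_size * seq_length, embedding_dim)
flat_image = image_embeddings.reshape(-1, embedding_dim)

# Offset per-sample indices so they index the flattened (batch * seq) axis.
to_add = np.arange(batch_size)[:, None] * seq_length
flat_indices = (vision_indices + to_add).reshape(-1)

row_zero = flat_text[to_add[:, 0]].copy()  # save each sample's 0th row
flat_text[flat_indices] = flat_image       # scatter the image embeddings
flat_text[to_add[:, 0]] = row_zero         # undo writes from padded indices

result = flat_text.reshape(batch_size, seq_length, embedding_dim)
print(result[0, 2:4])  # image embeddings landed on the placeholder positions
print(result[1, 0])    # sample 1's row 0 keeps its original values
```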

keras_hub/src/models/gemma3/gemma3_presets.py
@@ -11,7 +11,7 @@ backbone_presets = {
             "params": 999885952,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/1",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/3",
     },
     "gemma3_instruct_1b": {
         "metadata": {
@@ -22,7 +22,7 @@ backbone_presets = {
             "params": 999885952,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/1",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/3",
     },
     "gemma3_4b_text": {
         "metadata": {
@@ -33,7 +33,7 @@ backbone_presets = {
             "params": 3880099328,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/1",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/2",
     },
     "gemma3_instruct_4b_text": {
         "metadata": {
@@ -44,7 +44,7 @@ backbone_presets = {
             "params": 3880099328,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/2",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/3",
     },
     "gemma3_12b_text": {
         "metadata": {
@@ -55,7 +55,7 @@ backbone_presets = {
             "params": 11765788416,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/1",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/2",
     },
     "gemma3_instruct_12b_text": {
         "metadata": {
@@ -66,7 +66,7 @@ backbone_presets = {
             "params": 11765788416,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/1",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/2",
     },
     "gemma3_27b_text": {
         "metadata": {
@@ -77,7 +77,7 @@ backbone_presets = {
             "params": 27009002240,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/1",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/3",
     },
     "gemma3_instruct_27b_text": {
         "metadata": {
@@ -88,6 +88,72 @@ backbone_presets = {
             "params": 27009002240,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/1",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/2",
+    },
+    "gemma3_4b": {
+        "metadata": {
+            "description": (
+                "4 billion parameter, 34-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 4299915632,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b/1",
+    },
+    "gemma3_instruct_4b": {
+        "metadata": {
+            "description": (
+                "4 billion parameter, 34-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 4299915632,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b/1",
+    },
+    "gemma3_12b": {
+        "metadata": {
+            "description": (
+                "12 billion parameter, 48-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 12187079280,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b/1",
+    },
+    "gemma3_instruct_12b": {
+        "metadata": {
+            "description": (
+                "12 billion parameter, 48-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 12187079280,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b/1",
+    },
+    "gemma3_27b": {
+        "metadata": {
+            "description": (
+                "27 billion parameter, 62-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 27432062576,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b/1",
+    },
+    "gemma3_instruct_27b": {
+        "metadata": {
+            "description": (
+                "27 billion parameter, 62-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 27432062576,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b/1",
     },
 }
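The new vision+text checkpoints load through the standard KerasHub preset API. A minimal usage sketch (assumes network access and Kaggle credentials; the checkpoints are several gigabytes):

```python
import keras_hub

# Text-only generation with one of the newly added multimodal presets.
gemma3_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_4b")
print(gemma3_lm.generate("What is Keras?", max_length=64))
```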

keras_hub/src/models/gemma3/gemma3_tokenizer.py
@@ -4,6 +4,10 @@ from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )
 
+START_OF_IMAGE_TOKEN = "<start_of_image>"
+IMAGE_PLACEHOLDER_TOKEN = "<img>"
+END_OF_IMAGE_TOKEN = "<end_of_image>"
+
 
 @keras_hub_export(
     [
@@ -84,4 +88,9 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
         # Image placeholder token.
         self._add_special_token("<img>", "image_placeholder")
 
+        # Some tokens which are used in the preprocessor. We need to keep them
+        # here so that the preprocessor works with `tf.data`.
+        self._add_special_token("<start_of_image>", "start_of_image_token")
+        self._add_special_token("<end_of_image>", "end_of_image_token")
+
         super().__init__(proto=proto, **kwargs)
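A usage sketch for the new special tokens (the `keras_hub.tokenizers` export path is assumed here; the preset name is reused from the presets above): because `<start_of_image>` and `<end_of_image>` are registered as special tokens, prompts containing them survive `tf.data` preprocessing as single token ids instead of being split apart by SentencePiece.

```python
import keras_hub
from keras_hub.src.models.gemma3.gemma3_tokenizer import (
    END_OF_IMAGE_TOKEN,
    START_OF_IMAGE_TOKEN,
)

tokenizer = keras_hub.tokenizers.Gemma3Tokenizer.from_preset("gemma3_instruct_4b")

prompt = f"Describe this image: {START_OF_IMAGE_TOKEN}{END_OF_IMAGE_TOKEN}"
token_ids = tokenizer(prompt)
print(tokenizer.detokenize(token_ids))
```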