keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/edrec/edrec_seq2seq_lm.py
@@ -0,0 +1,273 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.edrec.edrec_backbone import EdRecBackbone
+from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.EdRecSeq2SeqLM")
+class EdRecSeq2SeqLM(Seq2SeqLM):
+    """EdRec Seq2SeqLM.
+
+    Args:
+        backbone: A `keras_hub.models.EdRecBackbone` instance.
+        preprocessor: Optional preprocessor.
+    """
+
+    backbone_cls = EdRecBackbone
+    preprocessor_cls = None
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # LM Head
+        self.lm_head = keras.layers.Dense(
+            backbone.vocab_size, use_bias=False, name="lm_head"
+        )
+
+        # === Functional Model ===
+        encoder_token_ids = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_token_ids"
+        )
+        decoder_token_ids = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_token_ids"
+        )
+        encoder_padding_mask = keras.Input(
+            shape=(None,), dtype="bool", name="encoder_padding_mask"
+        )
+        decoder_padding_mask = keras.Input(
+            shape=(None,), dtype="bool", name="decoder_padding_mask"
+        )
+
+        inputs = {
+            "encoder_token_ids": encoder_token_ids,
+            "decoder_token_ids": decoder_token_ids,
+            "encoder_padding_mask": encoder_padding_mask,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
+
+        backbone_outputs = backbone(inputs)
+        # The backbone returns a dict; we likely want the decoder output for the
+        # LM head if both are present, or just use what makes sense.
+        # For a Seq2Seq model training, we usually consume the decoder output.
+        outputs = self.lm_head(backbone_outputs["decoder_sequence_output"])
+
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def call_decoder_with_cache(
+        self,
+        encoder_hidden_states,
+        encoder_padding_mask,
+        decoder_token_ids,
+        decoder_padding_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        cross_attention_cache=None,
+        cross_attention_cache_update_index=None,
+    ):
+        x = self.backbone.embedding(decoder_token_ids)
+        if decoder_padding_mask is None:
+            decoder_padding_mask = ops.not_equal(decoder_token_ids, 0)
+
+        self_attention_caches = []
+        cross_attention_caches = []
+
+        for i, layer in enumerate(self.backbone.decoder_layers):
+            current_self_cache = (
+                self_attention_cache[:, i, ...]
+                if self_attention_cache is not None
+                else None
+            )
+            current_cross_cache = (
+                cross_attention_cache[:, i, ...]
+                if cross_attention_cache is not None
+                else None
+            )
+
+            x, next_self, next_cross = layer(
+                x,
+                encoder_outputs=encoder_hidden_states,
+                decoder_padding_mask=decoder_padding_mask,
+                encoder_padding_mask=encoder_padding_mask,
+                self_attention_cache=current_self_cache,
+                self_attention_cache_update_index=self_attention_cache_update_index,
+                cross_attention_cache=current_cross_cache,
+                cross_attention_cache_update_index=cross_attention_cache_update_index,
+            )
+
+            if next_self is not None:
+                self_attention_caches.append(next_self)
+            if next_cross is not None:
+                cross_attention_caches.append(next_cross)
+
+        if self_attention_cache_update_index is not None:
+            self_attention_cache = ops.stack(self_attention_caches, axis=1)
+        if cross_attention_cache_update_index is not None:
+            cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
+
+        hidden_states = x
+        logits = self.lm_head(x)
+        return (
+            logits,
+            hidden_states,
+            self_attention_cache,
+            cross_attention_cache,
+        )
+
+    def call_encoder(self, token_ids, padding_mask):
+        x = self.backbone.embedding(token_ids)
+        for layer in self.backbone.encoder_layers:
+            x = layer(x, padding_mask=padding_mask)
+        return x
+
+    def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
+        batch_size = ops.shape(encoder_token_ids)[0]
+        encoder_max_length = ops.shape(encoder_token_ids)[1]
+        decoder_max_length = ops.shape(decoder_token_ids)[1]
+
+        num_layers = self.backbone.num_layers_dec
+        num_heads = self.backbone.num_heads
+        head_dim = self.backbone.hidden_dim // num_heads
+
+        shape = [
+            batch_size,
+            num_layers,
+            2,
+            decoder_max_length,
+            num_heads,
+            head_dim,
+        ]
+        self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
+
+        shape[3] = encoder_max_length
+        cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
+
+        return self_attention_cache, cross_attention_cache
+
+    def generate_step(self, inputs, stop_token_ids=None):
+        encoder_token_ids = inputs["encoder_token_ids"]
+        encoder_padding_mask = inputs["encoder_padding_mask"]
+        decoder_token_ids = inputs.get("decoder_token_ids")
+        if decoder_token_ids is None:
+            batch_size = ops.shape(encoder_token_ids)[0]
+            decoder_token_ids = ops.zeros((batch_size, 1), dtype="int32")
+
+        decoder_padding_mask = inputs.get("decoder_padding_mask")
+        if decoder_padding_mask is None:
+            decoder_padding_mask = ops.ones_like(
+                decoder_token_ids, dtype="bool"
+            )
+
+        batch_size = ops.shape(encoder_token_ids)[0]
+
+        encoder_hidden_states = self.call_encoder(
+            encoder_token_ids, encoder_padding_mask
+        )
+        self_attention_cache, cross_attention_cache = self._initialize_cache(
+            encoder_token_ids, decoder_token_ids
+        )
+
+        row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
+        start_index = ops.min(row_lengths)
+
+        # Init cache logic for step 0
+        token_0 = ops.slice(decoder_token_ids, [0, 0], [batch_size, 1])
+        mask_0 = ops.slice(decoder_padding_mask, [0, 0], [batch_size, 1])
+        _, _, s_cache, c_cache = self.call_decoder_with_cache(
+            encoder_hidden_states,
+            encoder_padding_mask,
+            token_0,
+            mask_0,
+            self_attention_cache,
+            0,
+            cross_attention_cache,
+            0,
+        )
+
+        # We define cache as tuple
+        cache = (s_cache, c_cache)
+        hidden_states = ops.zeros_like(token_0, dtype="float32")
+
+        def next(prompt, cache, index):
+            s_c, c_c = cache
+
+            # Handle beam search replication if needed
+            curr_batch = ops.shape(prompt)[0]
+            enc_batch = ops.shape(encoder_hidden_states)[0]
+
+            enc_states = encoder_hidden_states
+            enc_mask = encoder_padding_mask
+
+            if curr_batch != enc_batch:
+                repeats = curr_batch // enc_batch
+                enc_states = ops.repeat(enc_states, repeats, axis=0)
+                enc_mask = ops.repeat(enc_mask, repeats, axis=0)
+
+            cache_index = index - 1
+            num_samples = ops.shape(prompt)[0]
+            prompt_slice = ops.slice(prompt, [0, cache_index], [num_samples, 1])
+
+            logits, h_states, next_s, next_c = self.call_decoder_with_cache(
+                enc_states,
+                enc_mask,
+                prompt_slice,
+                None,
+                s_c,
+                index - 1,
+                c_c,
+                None,  # Cross cache re-use
+            )
+
+            # If the backbone returns the full sequence, we only need the last
+            # token.
+            if ops.shape(logits)[1] != 1:
+                logits = ops.take(logits, [cache_index], axis=1)
+                h_states = ops.take(h_states, [cache_index], axis=1)
+
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(h_states, axis=1),
+                (next_s, next_c),
+            )
+
+        new_tokens = self.sampler(
+            next=next,
+            prompt=decoder_token_ids,
+            cache=cache,
+            index=start_index,
+            mask=decoder_padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        if stop_token_ids is not None:
+            end_locations = any_equal(
+                new_tokens,
+                stop_token_ids,
+                ops.logical_not(decoder_padding_mask),
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            decoder_padding_mask = ops.ones_like(new_tokens, dtype="bool")
+
+        return {
+            "decoder_token_ids": new_tokens,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
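The key data structure in the hunk above is the single-tensor cache built by `_initialize_cache`: key/value caches for every decoder layer live in one array, indexed by layer on axis 1. A minimal standalone sketch of that layout (all sizes illustrative, not EdRec preset values):

```python
from keras import ops

# Cache layout from `_initialize_cache` above:
# [batch, num_layers, 2 (key/value), max_length, num_heads, head_dim].
batch, num_layers, max_len, num_heads, head_dim = 2, 4, 8, 4, 16
cache = ops.zeros((batch, num_layers, 2, max_len, num_heads, head_dim))

# Per-layer slice, as done inside `call_decoder_with_cache`:
layer_cache = cache[:, 0, ...]  # -> (batch, 2, max_len, num_heads, head_dim)
key_cache, value_cache = ops.unstack(layer_cache, axis=1)
print(ops.shape(key_cache))  # (2, 8, 4, 16)
```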
keras_hub/src/models/electra/electra_backbone.py
@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
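This import swap repeats across most backbones in this release: `ReversibleEmbedding` now ships in Keras core instead of `keras_hub.src.layers.modeling.reversible_embedding`. A quick sketch of the layer's tied-weights behavior (shapes illustrative; assumes a Keras version that exports `keras.layers.ReversibleEmbedding`, as the import above implies):

```python
import keras
from keras.layers import ReversibleEmbedding

# Embed token ids, then reuse the same weight matrix as the output
# projection by calling with reverse=True (a tied LM head).
embedding = ReversibleEmbedding(input_dim=1000, output_dim=64)
token_ids = keras.ops.ones((2, 8), dtype="int32")
hidden = embedding(token_ids)             # -> (2, 8, 64)
logits = embedding(hidden, reverse=True)  # -> (2, 8, 1000)
```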
keras_hub/src/models/f_net/f_net_backbone.py
@@ -1,11 +1,9 @@
 import keras
+from keras.layers import ReversibleEmbedding

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate

keras_hub/src/models/falcon/falcon_backbone.py
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding

 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.falcon.falcon_transformer_decoder import (
     FalconTransformerDecoder,
keras_hub/src/models/flux/flux_layers.py
@@ -38,7 +38,7 @@ class EmbedND(keras.Model):

        Returns:
            KerasTensor: Positional embeddings of shape
-                (..., concatenated_dim, 1, ...).
+                (..., sum(axes_dim) // 2, 2).
        """
        n_axes = ids.shape[-1]
        emb = ops.concatenate(
@@ -46,10 +46,10 @@ class EmbedND(keras.Model):
                self.rope(ids[..., i], dim=self.axes_dim[i], theta=self.theta)
                for i in range(n_axes)
            ],
-            axis=-3,
+            axis=-2,
        )

-        return ops.expand_dims(emb, axis=1)
+        return emb


 class MLPEmbedder(keras.Model):
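The two hunks above change the rotary table layout: the per-axis tables now hold `(cos, sin)` pairs rather than full 2x2 rotation matrices (see the `RotaryPositionalEmbedding` hunk below), and `EmbedND` concatenates them on `axis=-2` with no extra `expand_dims`. A standalone sketch of the resulting shapes (axis sizes illustrative):

```python
from keras import ops

def rope(pos, dim, theta=10000.0):
    # Mirrors the new RotaryPositionalEmbedding: output (..., dim // 2, 2).
    scale = ops.arange(0, dim, 2, dtype="float32") / dim
    omega = 1.0 / (theta**scale)
    out = ops.einsum("...n,d->...nd", pos, omega)
    return ops.stack([ops.cos(out), ops.sin(out)], axis=-1)

ids = ops.ones((1, 6, 2))  # 6 positions, two positional axes
axes_dim = [4, 8]
emb = ops.concatenate(
    [rope(ids[..., i], dim=axes_dim[i]) for i in range(2)], axis=-2
)
print(ops.shape(emb))  # (1, 6, 6, 2), since sum(axes_dim) // 2 == 6
```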
keras_hub/src/models/flux/flux_maths.py
@@ -56,10 +56,7 @@ class RotaryPositionalEmbedding(keras.layers.Layer):
        scale = ops.arange(0, dim, 2, dtype="float32") / dim
        omega = 1.0 / (theta**scale)
        out = ops.einsum("...n,d->...nd", pos, omega)
-        out = ops.stack(
-            [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1
-        )
-        out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2))
+        out = ops.stack([ops.cos(out), ops.sin(out)], axis=-1)
        return ops.cast(out, dtype="float32")


@@ -71,26 +68,43 @@ class ApplyRoPE(keras.layers.Layer):
        xq: KerasTensor. The query tensor of shape (..., L, D).
        xk: KerasTensor. The key tensor of shape (..., L, D).
        freqs_cis: KerasTensor. The frequency complex numbers tensor with shape
-            `(..., 2)`.
+            (..., L, D//2, 2).

    Returns:
        tuple[KerasTensor, KerasTensor]: The transformed query and key tensors.
    """

    def call(self, xq, xk, freqs_cis):
-        xq_ = ops.reshape(xq, (*ops.shape(xq)[:-1], -1, 1, 2))
-        xk_ = ops.reshape(xk, (*ops.shape(xk)[:-1], -1, 1, 2))
-
-        xq_out = (
-            freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+        # xq, xk shape (..., num_heads, seq_len, D)
+        # freqs_cis shape (..., seq_len, D//2, 2)
+        # Expand freqs_cis to match num_heads dimension
+        freqs_cis = ops.expand_dims(freqs_cis, axis=-4)
+        # Now freqs_cis shape (..., 1, seq_len, D//2, 2)
+
+        xq_ = ops.reshape(xq, (*ops.shape(xq)[:-1], -1, 2))
+        xk_ = ops.reshape(xk, (*ops.shape(xk)[:-1], -1, 2))
+
+        xq_real = xq_[..., 0]
+        xq_imag = xq_[..., 1]
+        xk_real = xk_[..., 0]
+        xk_imag = xk_[..., 1]
+
+        freqs_cos = freqs_cis[..., 0]
+        freqs_sin = freqs_cis[..., 1]
+
+        xq_out_real = xq_real * freqs_cos - xq_imag * freqs_sin
+        xq_out_imag = xq_real * freqs_sin + xq_imag * freqs_cos
+        xk_out_real = xk_real * freqs_cos - xk_imag * freqs_sin
+        xk_out_imag = xk_real * freqs_sin + xk_imag * freqs_cos
+
+        xq_out = ops.reshape(
+            ops.stack([xq_out_real, xq_out_imag], axis=-1), ops.shape(xq)
        )
-        xk_out = (
-            freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+        xk_out = ops.reshape(
+            ops.stack([xk_out_real, xk_out_imag], axis=-1), ops.shape(xk)
        )

-        return ops.reshape(xq_out, ops.shape(xq)), ops.reshape(
-            xk_out, ops.shape(xk)
-        )
+        return xq_out, xk_out


 class FluxRoPEAttention(keras.layers.Layer):
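The rewritten `ApplyRoPE.call` is an explicit complex multiply on `(real, imag)` feature pairs, with the `(cos, sin)` table broadcast across heads. The same arithmetic, written out on concrete shapes (values illustrative):

```python
from keras import ops

seq_len, dim = 4, 8
# Frequency table in the new layout: (seq_len, dim // 2, 2) of (cos, sin).
pos = ops.arange(seq_len, dtype="float32")
scale = ops.arange(0, dim, 2, dtype="float32") / dim
out = ops.einsum("n,d->nd", pos, 1.0 / (10000.0**scale))
freqs_cis = ops.stack([ops.cos(out), ops.sin(out)], axis=-1)

xq = ops.ones((1, 2, seq_len, dim))  # (batch, num_heads, seq_len, D)
xq_ = ops.reshape(xq, (1, 2, seq_len, dim // 2, 2))
cos, sin = freqs_cis[..., 0], freqs_cis[..., 1]  # broadcasts over heads
real = xq_[..., 0] * cos - xq_[..., 1] * sin  # complex multiply, real part
imag = xq_[..., 0] * sin + xq_[..., 1] * cos  # complex multiply, imag part
xq_out = ops.reshape(ops.stack([real, imag], axis=-1), ops.shape(xq))
print(ops.shape(xq_out))  # (1, 2, 4, 8), same shape as the input query
```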
keras_hub/src/models/gemma/gemma_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding

 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
keras_hub/src/models/gemma/gemma_causal_lm.py
@@ -433,7 +433,7 @@ class GemmaCausalLM(CausalLM):
         return per_token_loss

     def get_quantization_layer_structure(self, mode):
-        if mode != "gptq":
+        if mode not in ["gptq", "awq"]:
             return None

         # Wrap embedding + scaling
keras_hub/src/models/gemma3/gemma3_attention.py
@@ -5,7 +5,7 @@ import numpy as np
 from keras import ops

 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
-from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
keras_hub/src/models/gemma3/gemma3_backbone.py
@@ -1,16 +1,14 @@
 import keras
+from keras import layers
 from keras import ops
+from keras.layers import ReversibleEmbedding

 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.models.gemma3.gemma3_decoder_block import Gemma3DecoderBlock
-from keras_hub.src.models.gemma3.gemma3_interleave_embeddings import (
-    Gemma3InterleaveEmbeddings,
-)
+from keras_hub.src.models.gemma3.gemma3_layers import Gemma3InterleaveEmbeddings
+from keras_hub.src.models.gemma3.gemma3_layers import Gemma3MeanPooling
+from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization


 @keras_hub_export("keras_hub.models.Gemma3Backbone")
@@ -27,6 +25,11 @@ class Gemma3Backbone(Backbone):
     For a higher-level object for text-generation, see
     `keras_hub.models.Gemma3CausalLM`.

+    This backbone can also function as an end-to-end embedding model by
+    setting the `is_embedding_model` argument to `True`. When configured as an
+    embedding model with bi-directional attention, it matches the
+    `EmbeddingGemma` architecture.
+
     The default constructor gives a fully customizable, randomly initialized
     Gemma3 model with any vision encoder, number of heads, embedding dimensions,
     and equivalent configuration for the decoder layers. To load preset
@@ -70,6 +73,17 @@ class Gemma3Backbone(Backbone):
            in all transformer blocks. Defaults to `1e-6`.
        dropout: float. Dropout probability for the Transformer decoder blocks.
            Defaults to `0`.
+        is_embedding_model (bool, optional): If `True`, the model will function
+            as an embedding model. This adds mean pooling layer and a two-layer
+            dense projection head to the final sequence output. The model output
+            will be a dictionary containing `'sequence_output'` and
+            `'pooled_output'`. Defaults to `False`.
+        pooling_intermediate_dim (int, optional): The intermediate dimension of
+            the first dense layer in the two-layer pooling projection head.
+            Required if `is_embedding_model` is `True`. Defaults to `None`.
+        embedding_dim (int, optional): The dimension of the final projected
+            embedding. Required if `is_embedding_model` is `True`. Defaults to
+            `None`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the models computations and weights. Note that some
            computations, such as softmax and layer normalization will always
@@ -198,6 +212,9 @@ class Gemma3Backbone(Backbone):
        layer_norm_epsilon=1e-6,
        use_bidirectional_attention=False,
        dropout=0,
+        is_embedding_model=False,
+        pooling_intermediate_dim=None,
+        embedding_dim=None,
        dtype=None,
        **kwargs,
    ):
@@ -319,6 +336,45 @@ class Gemma3Backbone(Backbone):
        )
        sequence_output = self.layer_norm(x)

+        if is_embedding_model:
+            if embedding_dim is None or pooling_intermediate_dim is None:
+                raise ValueError(
+                    "Must specify embedding_dim and pooling_intermediate_dim."
+                )
+
+            # 1. Mask-aware Mean Pooling
+            pooled_output = Gemma3MeanPooling(dtype=dtype, name="mean_pooling")(
+                sequence_output, padding_mask=padding_mask_input
+            )
+
+            # 2. First Projection (Non-linear or Linear depending on preset)
+            pooled_output = layers.Dense(
+                pooling_intermediate_dim,
+                dtype=dtype,
+                name="pooling_dense_1",
+                use_bias=False,
+            )(pooled_output)
+
+            # 3. Final Projection
+            pooled_output = layers.Dense(
+                embedding_dim,
+                dtype=dtype,
+                name="embedding_projection",
+                use_bias=False,
+            )(pooled_output)
+
+            # 4. L2 Normalization (Crucial for Retrieval)
+            pooled_output = layers.UnitNormalization(
+                axis=-1, dtype=dtype, name="unit_normalization"
+            )(pooled_output)
+
+            outputs = {
+                "sequence_output": sequence_output,
+                "pooled_output": pooled_output,
+            }
+        else:
+            outputs = sequence_output
+
        inputs = {
            "token_ids": token_id_input,
            "padding_mask": padding_mask_input,
@@ -334,7 +390,7 @@ class Gemma3Backbone(Backbone):

        super().__init__(
            inputs=inputs,
-            outputs=sequence_output,
+            outputs=outputs,
            dtype=dtype,
            **kwargs,
        )
@@ -361,6 +417,9 @@ class Gemma3Backbone(Backbone):
        self.use_bidirectional_attention = use_bidirectional_attention
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout
+        self.is_embedding_model = is_embedding_model
+        self.pooling_intermediate_dim = pooling_intermediate_dim
+        self.embedding_dim = embedding_dim

        # Keep `num_vision_tokens_per_image` as a backbone property for easy
        # access.
@@ -401,6 +460,9 @@ class Gemma3Backbone(Backbone):
                "use_bidirectional_attention": self.use_bidirectional_attention,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
+                "is_embedding_model": self.is_embedding_model,
+                "pooling_intermediate_dim": self.pooling_intermediate_dim,
+                "embedding_dim": self.embedding_dim,
            }
        )
        return config
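Taken together, these hunks add an optional `EmbeddingGemma`-style head: mask-aware mean pooling, two bias-free projections, then unit normalization. `Gemma3MeanPooling`'s implementation is not shown in this diff, so the pooling below is an assumption; the projection and normalization steps follow the hunk (dimensions illustrative):

```python
import keras
from keras import layers, ops

hidden_dim, pooling_intermediate_dim, embedding_dim = 64, 128, 32
sequence_output = keras.random.normal((2, 10, hidden_dim))
padding_mask = ops.ones((2, 10), dtype="float32")

# Assumed mask-aware mean pooling (Gemma3MeanPooling is not shown above):
# sum valid positions, divide by the number of valid positions.
mask = ops.expand_dims(padding_mask, axis=-1)
pooled = ops.sum(sequence_output * mask, axis=1)
pooled = pooled / ops.maximum(ops.sum(mask, axis=1), 1.0)

# Two bias-free projections, then L2 normalization, as in the hunk.
pooled = layers.Dense(pooling_intermediate_dim, use_bias=False)(pooled)
pooled = layers.Dense(embedding_dim, use_bias=False)(pooled)
pooled = layers.UnitNormalization(axis=-1)(pooled)
print(ops.shape(pooled))  # (2, 32): unit-length retrieval embeddings
```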
keras_hub/src/models/gemma3/gemma3_causal_lm.py
@@ -249,7 +249,22 @@ class Gemma3CausalLM(CausalLM):
             inputs.get("vision_mask", None),
             inputs.get("vision_indices", None),
         )
-        if not self.backbone.text_only_model:
+
+        # Determine if we have actual images to process.
+        # After preprocessing, images shape is (batch, num_images, h, w, 3).
+        # For text-only input, num_images=0 (static shape).
+        # We use static shape check which returns a Python int, not a tensor.
+        num_images = 0
+        if (
+            images is not None
+            and hasattr(images, "shape")
+            and len(images.shape) > 1
+        ):
+            num_images = images.shape[
+                1
+            ]  # Static shape, returns Python int or None
+
+        if not self.backbone.text_only_model and num_images:
             # Handle an unbatched image. Unlike `token_ids` and
             # `padding_mask`, this will not automatically be upranked.
             if len(ops.shape(images)) == 4:
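The guard above reads `images.shape` as a static shape, so `num_images` comes back as a plain Python int (or `None` for an unknown dim) and text-only batches skip the vision branch without any tensor ops. A tiny illustration (array contents are placeholders):

```python
import numpy as np

# A text-only batch carries zero images: (batch, num_images=0, h, w, 3).
images = np.zeros((2, 0, 224, 224, 3), dtype="float32")

num_images = 0
if images is not None and hasattr(images, "shape") and len(images.shape) > 1:
    num_images = images.shape[1]  # static shape -> Python int, no graph ops

print(num_images)        # 0
print(bool(num_images))  # False -> the vision encoder is skipped
```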
keras_hub/src/models/gemma3/gemma3_decoder_block.py
@@ -8,7 +8,7 @@ from keras_hub.src.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
 from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention
-from keras_hub.src.models.gemma3.rms_normalization import RMSNormalization
+from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization


 class Gemma3DecoderBlock(keras.layers.Layer):
@@ -251,6 +251,11 @@ class Gemma3DecoderBlock(keras.layers.Layer):
        cache_update_mask=None,
    ):
        # Note: `vision_mask` is used only for Gemma3.
+        # If float16, we clamp the input to avoid overflow.
+        is_float16 = keras.backend.standardize_dtype(x.dtype) == "float16"
+        if is_float16:
+            x = ops.clip(x, -65504, 65504)
+
        normalized_x = self.pre_attention_norm(x)
        attention_mask = self._compute_attention_mask(
            normalized_x, padding_mask, vision_mask, cache, cache_update_index
@@ -275,7 +280,15 @@ class Gemma3DecoderBlock(keras.layers.Layer):
        if self.dropout:
            attention = self.attention_dropout(attention)

-        attention_x = x + attention
+        if is_float16:
+            attention_x = ops.add(
+                ops.cast(x, "float32"), ops.cast(attention, "float32")
+            )
+            attention_x = ops.clip(attention_x, -65504, 65504)
+            attention_x = ops.cast(attention_x, "float16")
+        else:
+            attention_x = x + attention
+
        normalized_x = self.pre_ffw_norm(attention_x)

        x1 = self.gating_ffw(normalized_x)
@@ -286,7 +299,14 @@ class Gemma3DecoderBlock(keras.layers.Layer):
        if self.use_post_ffw_norm:
            x = self.post_ffw_norm(x)

-        x = x + attention_x
+        if is_float16:
+            x = ops.add(
+                ops.cast(x, "float32"), ops.cast(attention_x, "float32")
+            )
+            x = ops.clip(x, -65504, 65504)
+            x = ops.cast(x, "float16")
+        else:
+            x = x + attention_x

        if cache is not None:
            return x, new_cache
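Both residual connections in these hunks use the same pattern: accumulate in float32, clamp to float16's finite range (±65504), then cast back. A standalone sketch:

```python
from keras import ops

F16_MAX = 65504.0  # largest finite float16 value

def safe_residual_add(x, y):
    # Accumulate in float32, clamp, cast back, as in Gemma3DecoderBlock.
    out = ops.add(ops.cast(x, "float32"), ops.cast(y, "float32"))
    out = ops.clip(out, -F16_MAX, F16_MAX)
    return ops.cast(out, "float16")

x = ops.full((2, 4), 60000.0, dtype="float16")
y = ops.full((2, 4), 30000.0, dtype="float16")
print(safe_residual_add(x, y))  # clamped at 65504 instead of overflowing to inf
```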