keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,273 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.edrec.edrec_backbone import EdRecBackbone
+from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.EdRecSeq2SeqLM")
+class EdRecSeq2SeqLM(Seq2SeqLM):
+    """EdRec Seq2SeqLM.
+
+    Args:
+        backbone: A `keras_hub.models.EdRecBackbone` instance.
+        preprocessor: Optional preprocessor.
+    """
+
+    backbone_cls = EdRecBackbone
+    preprocessor_cls = None
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # LM Head
+        self.lm_head = keras.layers.Dense(
+            backbone.vocab_size, use_bias=False, name="lm_head"
+        )
+
+        # === Functional Model ===
+        encoder_token_ids = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_token_ids"
+        )
+        decoder_token_ids = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_token_ids"
+        )
+        encoder_padding_mask = keras.Input(
+            shape=(None,), dtype="bool", name="encoder_padding_mask"
+        )
+        decoder_padding_mask = keras.Input(
+            shape=(None,), dtype="bool", name="decoder_padding_mask"
+        )
+
+        inputs = {
+            "encoder_token_ids": encoder_token_ids,
+            "decoder_token_ids": decoder_token_ids,
+            "encoder_padding_mask": encoder_padding_mask,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
+
+        backbone_outputs = backbone(inputs)
+        # The backbone returns a dict; we likely want the decoder output for the
+        # LM head if both are present, or just use what makes sense.
+        # For a Seq2Seq model training, we usually consume the decoder output.
+        outputs = self.lm_head(backbone_outputs["decoder_sequence_output"])
+
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def call_decoder_with_cache(
+        self,
+        encoder_hidden_states,
+        encoder_padding_mask,
+        decoder_token_ids,
+        decoder_padding_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        cross_attention_cache=None,
+        cross_attention_cache_update_index=None,
+    ):
+        x = self.backbone.embedding(decoder_token_ids)
+        if decoder_padding_mask is None:
+            decoder_padding_mask = ops.not_equal(decoder_token_ids, 0)
+
+        self_attention_caches = []
+        cross_attention_caches = []
+
+        for i, layer in enumerate(self.backbone.decoder_layers):
+            current_self_cache = (
+                self_attention_cache[:, i, ...]
+                if self_attention_cache is not None
+                else None
+            )
+            current_cross_cache = (
+                cross_attention_cache[:, i, ...]
+                if cross_attention_cache is not None
+                else None
+            )
+
+            x, next_self, next_cross = layer(
+                x,
+                encoder_outputs=encoder_hidden_states,
+                decoder_padding_mask=decoder_padding_mask,
+                encoder_padding_mask=encoder_padding_mask,
+                self_attention_cache=current_self_cache,
+                self_attention_cache_update_index=self_attention_cache_update_index,
+                cross_attention_cache=current_cross_cache,
+                cross_attention_cache_update_index=cross_attention_cache_update_index,
+            )
+
+            if next_self is not None:
+                self_attention_caches.append(next_self)
+            if next_cross is not None:
+                cross_attention_caches.append(next_cross)
+
+        if self_attention_cache_update_index is not None:
+            self_attention_cache = ops.stack(self_attention_caches, axis=1)
+        if cross_attention_cache_update_index is not None:
+            cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
+
+        hidden_states = x
+        logits = self.lm_head(x)
+        return (
+            logits,
+            hidden_states,
+            self_attention_cache,
+            cross_attention_cache,
+        )
+
+    def call_encoder(self, token_ids, padding_mask):
+        x = self.backbone.embedding(token_ids)
+        for layer in self.backbone.encoder_layers:
+            x = layer(x, padding_mask=padding_mask)
+        return x
+
+    def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
+        batch_size = ops.shape(encoder_token_ids)[0]
+        encoder_max_length = ops.shape(encoder_token_ids)[1]
+        decoder_max_length = ops.shape(decoder_token_ids)[1]
+
+        num_layers = self.backbone.num_layers_dec
+        num_heads = self.backbone.num_heads
+        head_dim = self.backbone.hidden_dim // num_heads
+
+        shape = [
+            batch_size,
+            num_layers,
+            2,
+            decoder_max_length,
+            num_heads,
+            head_dim,
+        ]
+        self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
+
+        shape[3] = encoder_max_length
+        cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
+
+        return self_attention_cache, cross_attention_cache
+
+    def generate_step(self, inputs, stop_token_ids=None):
+        encoder_token_ids = inputs["encoder_token_ids"]
+        encoder_padding_mask = inputs["encoder_padding_mask"]
+        decoder_token_ids = inputs.get("decoder_token_ids")
+        if decoder_token_ids is None:
+            batch_size = ops.shape(encoder_token_ids)[0]
+            decoder_token_ids = ops.zeros((batch_size, 1), dtype="int32")
+
+        decoder_padding_mask = inputs.get("decoder_padding_mask")
+        if decoder_padding_mask is None:
+            decoder_padding_mask = ops.ones_like(
+                decoder_token_ids, dtype="bool"
+            )
+
+        batch_size = ops.shape(encoder_token_ids)[0]
+
+        encoder_hidden_states = self.call_encoder(
+            encoder_token_ids, encoder_padding_mask
+        )
+        self_attention_cache, cross_attention_cache = self._initialize_cache(
+            encoder_token_ids, decoder_token_ids
+        )
+
+        row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
+        start_index = ops.min(row_lengths)
+
+        # Init cache logic for step 0
+        token_0 = ops.slice(decoder_token_ids, [0, 0], [batch_size, 1])
+        mask_0 = ops.slice(decoder_padding_mask, [0, 0], [batch_size, 1])
+        _, _, s_cache, c_cache = self.call_decoder_with_cache(
+            encoder_hidden_states,
+            encoder_padding_mask,
+            token_0,
+            mask_0,
+            self_attention_cache,
+            0,
+            cross_attention_cache,
+            0,
+        )
+
+        # We define cache as tuple
+        cache = (s_cache, c_cache)
+        hidden_states = ops.zeros_like(token_0, dtype="float32")
+
+        def next(prompt, cache, index):
+            s_c, c_c = cache
+
+            # Handle beam search replication if needed
+            curr_batch = ops.shape(prompt)[0]
+            enc_batch = ops.shape(encoder_hidden_states)[0]
+
+            enc_states = encoder_hidden_states
+            enc_mask = encoder_padding_mask
+
+            if curr_batch != enc_batch:
+                repeats = curr_batch // enc_batch
+                enc_states = ops.repeat(enc_states, repeats, axis=0)
+                enc_mask = ops.repeat(enc_mask, repeats, axis=0)
+
+            cache_index = index - 1
+            num_samples = ops.shape(prompt)[0]
+            prompt_slice = ops.slice(prompt, [0, cache_index], [num_samples, 1])
+
+            logits, h_states, next_s, next_c = self.call_decoder_with_cache(
+                enc_states,
+                enc_mask,
+                prompt_slice,
+                None,
+                s_c,
+                index - 1,
+                c_c,
+                None,  # Cross cache re-use
+            )
+
+            # If the backbone returns the full sequence, we only need the last
+            # token.
+            if ops.shape(logits)[1] != 1:
+                logits = ops.take(logits, [cache_index], axis=1)
+                h_states = ops.take(h_states, [cache_index], axis=1)
+
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(h_states, axis=1),
+                (next_s, next_c),
+            )
+
+        new_tokens = self.sampler(
+            next=next,
+            prompt=decoder_token_ids,
+            cache=cache,
+            index=start_index,
+            mask=decoder_padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        if stop_token_ids is not None:
+            end_locations = any_equal(
+                new_tokens,
+                stop_token_ids,
+                ops.logical_not(decoder_padding_mask),
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            decoder_padding_mask = ops.ones_like(new_tokens, dtype="bool")
+
+        return {
+            "decoder_token_ids": new_tokens,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
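The hunk above is easier to follow with the cache layout in mind. Below is a small, self-contained sketch (with arbitrary example dimensions) of the tensor layout that `_initialize_cache` builds and that `call_decoder_with_cache` slices per decoder layer; it only illustrates the shapes and is not code from the package.

```python
from keras import ops

# Cache layout used by EdRecSeq2SeqLM above: a single tensor per cache holds
# both key and value states for every decoder layer.
batch_size, num_layers, num_heads, head_dim = 2, 4, 8, 32
decoder_max_length, encoder_max_length = 16, 64

# [batch, layer, key/value, sequence, heads, head_dim]
self_attention_cache = ops.zeros(
    (batch_size, num_layers, 2, decoder_max_length, num_heads, head_dim)
)
cross_attention_cache = ops.zeros(
    (batch_size, num_layers, 2, encoder_max_length, num_heads, head_dim)
)

# Per-layer slice handed to each decoder block, as in call_decoder_with_cache.
layer_0_self_cache = self_attention_cache[:, 0, ...]
print(ops.shape(layer_0_self_cache))  # (2, 2, 16, 8, 32)
```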
@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate

@@ -1,11 +1,9 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
 

@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.falcon.falcon_transformer_decoder import (
     FalconTransformerDecoder,
@@ -38,7 +38,7 @@ class EmbedND(keras.Model):
 
         Returns:
             KerasTensor: Positional embeddings of shape
-            (...,
+            (..., sum(axes_dim) // 2, 2).
         """
         n_axes = ids.shape[-1]
         emb = ops.concatenate(

@@ -46,10 +46,10 @@ class EmbedND(keras.Model):
                 self.rope(ids[..., i], dim=self.axes_dim[i], theta=self.theta)
                 for i in range(n_axes)
             ],
-            axis=-
+            axis=-2,
         )
 
-        return
+        return emb
 
 
 class MLPEmbedder(keras.Model):

@@ -56,10 +56,7 @@ class RotaryPositionalEmbedding(keras.layers.Layer):
         scale = ops.arange(0, dim, 2, dtype="float32") / dim
         omega = 1.0 / (theta**scale)
         out = ops.einsum("...n,d->...nd", pos, omega)
-        out = ops.stack(
-            [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1
-        )
-        out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2))
+        out = ops.stack([ops.cos(out), ops.sin(out)], axis=-1)
         return ops.cast(out, dtype="float32")
 
 

@@ -71,26 +68,43 @@ class ApplyRoPE(keras.layers.Layer):
         xq: KerasTensor. The query tensor of shape (..., L, D).
         xk: KerasTensor. The key tensor of shape (..., L, D).
         freqs_cis: KerasTensor. The frequency complex numbers tensor with shape
-
+            (..., L, D//2, 2).
 
     Returns:
         tuple[KerasTensor, KerasTensor]: The transformed query and key tensors.
     """
 
     def call(self, xq, xk, freqs_cis):
-
-
-
-
-
+        # xq, xk shape (..., num_heads, seq_len, D)
+        # freqs_cis shape (..., seq_len, D//2, 2)
+        # Expand freqs_cis to match num_heads dimension
+        freqs_cis = ops.expand_dims(freqs_cis, axis=-4)
+        # Now freqs_cis shape (..., 1, seq_len, D//2, 2)
+
+        xq_ = ops.reshape(xq, (*ops.shape(xq)[:-1], -1, 2))
+        xk_ = ops.reshape(xk, (*ops.shape(xk)[:-1], -1, 2))
+
+        xq_real = xq_[..., 0]
+        xq_imag = xq_[..., 1]
+        xk_real = xk_[..., 0]
+        xk_imag = xk_[..., 1]
+
+        freqs_cos = freqs_cis[..., 0]
+        freqs_sin = freqs_cis[..., 1]
+
+        xq_out_real = xq_real * freqs_cos - xq_imag * freqs_sin
+        xq_out_imag = xq_real * freqs_sin + xq_imag * freqs_cos
+        xk_out_real = xk_real * freqs_cos - xk_imag * freqs_sin
+        xk_out_imag = xk_real * freqs_sin + xk_imag * freqs_cos
+
+        xq_out = ops.reshape(
+            ops.stack([xq_out_real, xq_out_imag], axis=-1), ops.shape(xq)
         )
-        xk_out = (
-
+        xk_out = ops.reshape(
+            ops.stack([xk_out_real, xk_out_imag], axis=-1), ops.shape(xk)
         )
 
-        return
-            xk_out, ops.shape(xk)
-        )
+        return xq_out, xk_out
 
 
 class FluxRoPEAttention(keras.layers.Layer):
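The rewritten `RotaryPositionalEmbedding` above now emits `(cos, sin)` pairs instead of 2x2 rotation matrices, and `ApplyRoPE.call` applies the complex rotation component-wise. Here is a short standalone sketch of that arithmetic with `keras.ops`; shapes and variable names are illustrative only, not from the package.

```python
import keras
from keras import ops

# Toy shapes: queries/keys are (batch, num_heads, seq_len, head_dim).
x = keras.random.normal((1, 2, 4, 8))
theta = keras.random.uniform((1, 4, 4))  # (batch, seq_len, head_dim // 2)

# As in RotaryPositionalEmbedding above: store (cos, sin) pairs.
freqs_cis = ops.stack([ops.cos(theta), ops.sin(theta)], axis=-1)
freqs_cis = ops.expand_dims(freqs_cis, axis=-4)  # broadcast over heads

# As in ApplyRoPE above: view the last axis as (real, imag) pairs and rotate.
x_pairs = ops.reshape(x, (*ops.shape(x)[:-1], -1, 2))
x_real, x_imag = x_pairs[..., 0], x_pairs[..., 1]
cos, sin = freqs_cis[..., 0], freqs_cis[..., 1]

out_real = x_real * cos - x_imag * sin
out_imag = x_real * sin + x_imag * cos
out = ops.reshape(ops.stack([out_real, out_imag], axis=-1), ops.shape(x))
print(ops.shape(out))  # (1, 2, 4, 8)
```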
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization

@@ -5,7 +5,7 @@ import numpy as np
 from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
-from keras_hub.src.models.
+from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
@@ -1,16 +1,14 @@
 import keras
+from keras import layers
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.models.gemma3.gemma3_decoder_block import Gemma3DecoderBlock
-from keras_hub.src.models.gemma3.
-
-
+from keras_hub.src.models.gemma3.gemma3_layers import Gemma3InterleaveEmbeddings
+from keras_hub.src.models.gemma3.gemma3_layers import Gemma3MeanPooling
+from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization
 
 
 @keras_hub_export("keras_hub.models.Gemma3Backbone")

@@ -27,6 +25,11 @@ class Gemma3Backbone(Backbone):
     For a higher-level object for text-generation, see
     `keras_hub.models.Gemma3CausalLM`.
 
+    This backbone can also function as an end-to-end embedding model by
+    setting the `is_embedding_model` argument to `True`. When configured as an
+    embedding model with bi-directional attention, it matches the
+    `EmbeddingGemma` architecture.
+
     The default constructor gives a fully customizable, randomly initialized
     Gemma3 model with any vision encoder, number of heads, embedding dimensions,
     and equivalent configuration for the decoder layers. To load preset

@@ -70,6 +73,17 @@ class Gemma3Backbone(Backbone):
             in all transformer blocks. Defaults to `1e-6`.
         dropout: float. Dropout probability for the Transformer decoder blocks.
             Defaults to `0`.
+        is_embedding_model (bool, optional): If `True`, the model will function
+            as an embedding model. This adds mean pooling layer and a two-layer
+            dense projection head to the final sequence output. The model output
+            will be a dictionary containing `'sequence_output'` and
+            `'pooled_output'`. Defaults to `False`.
+        pooling_intermediate_dim (int, optional): The intermediate dimension of
+            the first dense layer in the two-layer pooling projection head.
+            Required if `is_embedding_model` is `True`. Defaults to `None`.
+        embedding_dim (int, optional): The dimension of the final projected
+            embedding. Required if `is_embedding_model` is `True`. Defaults to
+            `None`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always

@@ -198,6 +212,9 @@ class Gemma3Backbone(Backbone):
         layer_norm_epsilon=1e-6,
         use_bidirectional_attention=False,
         dropout=0,
+        is_embedding_model=False,
+        pooling_intermediate_dim=None,
+        embedding_dim=None,
         dtype=None,
         **kwargs,
     ):

@@ -319,6 +336,45 @@ class Gemma3Backbone(Backbone):
         )
         sequence_output = self.layer_norm(x)
 
+        if is_embedding_model:
+            if embedding_dim is None or pooling_intermediate_dim is None:
+                raise ValueError(
+                    "Must specify embedding_dim and pooling_intermediate_dim."
+                )
+
+            # 1. Mask-aware Mean Pooling
+            pooled_output = Gemma3MeanPooling(dtype=dtype, name="mean_pooling")(
+                sequence_output, padding_mask=padding_mask_input
+            )
+
+            # 2. First Projection (Non-linear or Linear depending on preset)
+            pooled_output = layers.Dense(
+                pooling_intermediate_dim,
+                dtype=dtype,
+                name="pooling_dense_1",
+                use_bias=False,
+            )(pooled_output)
+
+            # 3. Final Projection
+            pooled_output = layers.Dense(
+                embedding_dim,
+                dtype=dtype,
+                name="embedding_projection",
+                use_bias=False,
+            )(pooled_output)
+
+            # 4. L2 Normalization (Crucial for Retrieval)
+            pooled_output = layers.UnitNormalization(
+                axis=-1, dtype=dtype, name="unit_normalization"
+            )(pooled_output)
+
+            outputs = {
+                "sequence_output": sequence_output,
+                "pooled_output": pooled_output,
+            }
+        else:
+            outputs = sequence_output
+
         inputs = {
             "token_ids": token_id_input,
             "padding_mask": padding_mask_input,
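To make the new `is_embedding_model` branch above concrete, here is a self-contained sketch of the pooled-output computation it wires into the functional model. The masked average stands in for `Gemma3MeanPooling`, whose implementation lives in `gemma3_layers.py` and is not shown in this diff; all dimensions are arbitrary example values.

```python
import keras
from keras import layers, ops

batch, seq_len, hidden_dim = 2, 16, 32
pooling_intermediate_dim, embedding_dim = 64, 24

sequence_output = keras.random.normal((batch, seq_len, hidden_dim))
padding_mask = ops.ones((batch, seq_len), dtype="bool")

# 1. Mask-aware mean pooling (approximation of Gemma3MeanPooling).
mask = ops.cast(padding_mask, "float32")[..., None]
pooled = ops.sum(sequence_output * mask, axis=1) / ops.sum(mask, axis=1)

# 2./3. Two bias-free dense projections, as in the hunk above.
pooled = layers.Dense(pooling_intermediate_dim, use_bias=False)(pooled)
pooled = layers.Dense(embedding_dim, use_bias=False)(pooled)

# 4. L2 normalization so the pooled embedding is ready for cosine-similarity
#    retrieval; this is what the backbone returns as "pooled_output".
pooled = layers.UnitNormalization(axis=-1)(pooled)
print(ops.shape(pooled))  # (2, 24)
```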
@@ -334,7 +390,7 @@ class Gemma3Backbone(Backbone):
 
         super().__init__(
             inputs=inputs,
-            outputs=
+            outputs=outputs,
             dtype=dtype,
             **kwargs,
         )

@@ -361,6 +417,9 @@ class Gemma3Backbone(Backbone):
         self.use_bidirectional_attention = use_bidirectional_attention
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
+        self.is_embedding_model = is_embedding_model
+        self.pooling_intermediate_dim = pooling_intermediate_dim
+        self.embedding_dim = embedding_dim
 
         # Keep `num_vision_tokens_per_image` as a backbone property for easy
         # access.

@@ -401,6 +460,9 @@ class Gemma3Backbone(Backbone):
                 "use_bidirectional_attention": self.use_bidirectional_attention,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,
+                "is_embedding_model": self.is_embedding_model,
+                "pooling_intermediate_dim": self.pooling_intermediate_dim,
+                "embedding_dim": self.embedding_dim,
             }
         )
         return config
|
@@ -249,7 +249,22 @@ class Gemma3CausalLM(CausalLM):
|
|
|
249
249
|
inputs.get("vision_mask", None),
|
|
250
250
|
inputs.get("vision_indices", None),
|
|
251
251
|
)
|
|
252
|
-
|
|
252
|
+
|
|
253
|
+
# Determine if we have actual images to process.
|
|
254
|
+
# After preprocessing, images shape is (batch, num_images, h, w, 3).
|
|
255
|
+
# For text-only input, num_images=0 (static shape).
|
|
256
|
+
# We use static shape check which returns a Python int, not a tensor.
|
|
257
|
+
num_images = 0
|
|
258
|
+
if (
|
|
259
|
+
images is not None
|
|
260
|
+
and hasattr(images, "shape")
|
|
261
|
+
and len(images.shape) > 1
|
|
262
|
+
):
|
|
263
|
+
num_images = images.shape[
|
|
264
|
+
1
|
|
265
|
+
] # Static shape, returns Python int or None
|
|
266
|
+
|
|
267
|
+
if not self.backbone.text_only_model and num_images:
|
|
253
268
|
# Handle an unbatched image. Unlike `token_ids` and
|
|
254
269
|
# `padding_mask`, this will not automatically be upranked.
|
|
255
270
|
if len(ops.shape(images)) == 4:
|
|
@@ -8,7 +8,7 @@ from keras_hub.src.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
 from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention
-from keras_hub.src.models.gemma3.
+from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization
 
 
 class Gemma3DecoderBlock(keras.layers.Layer):
|