keras-hub-nightly 0.20.0.dev202503260356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry, as they appear in that registry. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +6 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/gemma3/__init__.py +5 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +315 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +352 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +306 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +691 -0
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +305 -0
- keras_hub/src/models/gemma3/gemma3_image_converter.py +8 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +79 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +93 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +87 -0
- keras_hub/src/models/gemma3/gemma3_vit.py +608 -0
- keras_hub/src/models/gemma3/rms_normalization.py +26 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/RECORD +20 -8
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/top_level.txt +0 -0
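Taken together, these files add a new Gemma3 model family (backbone, causal LM task, preprocessor, tokenizer, image converter, and vision tower) plus its API exports. As a rough orientation, the sketch below shows how the new classes would typically be reached once this nightly is installed; the preset handle is a placeholder and the exact exported names are inferred from the file list above rather than confirmed by this diff.

# Hypothetical usage sketch for the Gemma3 additions in this nightly.
# Assumes the new classes are exported as `keras_hub.models.Gemma3*`;
# "gemma3_preset_placeholder" is a stand-in for a real handle defined in
# keras_hub/src/models/gemma3/gemma3_presets.py.
import keras_hub

causal_lm = keras_hub.models.Gemma3CausalLM.from_preset(
    "gemma3_preset_placeholder"
)
print(causal_lm.generate("Describe Keras in one sentence.", max_length=64))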
keras_hub/src/models/gemma3/gemma3_vit.py
ADDED
@@ -0,0 +1,608 @@
import keras
from keras import ops

from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
from keras_hub.src.utils.keras_utils import clone_initializer


class Gemma3VitEmbeddings(keras.layers.Layer):
    def __init__(
        self,
        image_size,
        patch_size,
        hidden_dim,
        num_channels=3,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = hidden_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.patch_embedding = keras.layers.Conv2D(
            filters=self.hidden_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
            activation=None,
            dtype=dtype,
            name="embedding_conv",
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = keras.layers.Embedding(
            self.num_positions,
            self.hidden_dim,
            dtype=dtype,
            name="position_embedding",
        )

        self.position_ids = ops.expand_dims(
            ops.arange(self.num_positions), axis=0
        )

    def build(self, input_shape):
        self.patch_embedding.build(input_shape)
        self.position_embedding.build([1, self.num_positions])
        self.built = True

    def call(self, input_tokens):
        x = self.patch_embedding(input_tokens)
        input_shape = ops.shape(x)
        x = ops.reshape(x, [input_shape[0], self.num_patches, self.hidden_dim])
        x = x + self.position_embedding(self.position_ids)
        return x

    def compute_output_shape(self, input_shape):
        return (
            input_shape[0],
            self.num_patches,
            self.hidden_dim,
        )


class Gemma3VitAttention(keras.layers.Layer):
    """
    Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        dropout=0.0,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = self.hidden_dim // self.num_heads
        if self.head_dim * self.num_heads != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`"
                f": {self.hidden_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.dropout_layer = keras.layers.Dropout(
            self.dropout,
            dtype=dtype,
            name="dropout",
        )
        self.scale = self.head_dim**-0.5
        self.query_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="query_proj",
        )
        self.key_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="key_proj",
        )
        self.value_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="value_proj",
        )
        self.out_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="out_proj",
        )

    def build(self, input_shape):
        self.query_proj.build([None, None, self.hidden_dim])
        self.key_proj.build([None, None, self.hidden_dim])
        self.value_proj.build([None, None, self.hidden_dim])
        self.out_proj.build([None, None, self.hidden_dim])
        self.built = True

    def _transpose_for_scores(self, tensor, batch_size):
        """
        Adapted from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252
        """
        # [batch_size, seq_len, all_head_dim] ->
        # [batch_size, seq_len, num_heads, head_dim]
        seq_len = ops.shape(tensor)[1]
        tensor = ops.reshape(
            tensor, (batch_size, seq_len, self.num_heads, self.head_dim)
        )
        # [batch_size, seq_len, num_heads, head_dim] ->
        # [batch_size, num_heads, seq_len, head_dim]
        return ops.transpose(tensor, axes=[0, 2, 1, 3])

    def call(
        self,
        x,
        attention_mask=None,
        return_attention_scores=None,
        training=False,
    ):
        batch_size = ops.shape(x)[0]
        mixed_query_layer = self.query_proj(inputs=x)
        mixed_key_layer = self.key_proj(inputs=x)
        mixed_value_layer = self.value_proj(inputs=x)
        query_layer = self._transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self._transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self._transpose_for_scores(mixed_value_layer, batch_size)

        # Scaled dot product between key and query = raw attention scores.
        attention_scores = ops.matmul(
            query_layer, ops.transpose(key_layer, axes=[0, 1, 3, 2])
        )
        dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_scores.dtype)
        attention_scores = ops.divide(
            attention_scores, dk
        )  # (batch_size, num_heads, seq_len_q, seq_len_k)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the
            # call() function)
            attention_scores = ops.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = ops.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        dropout_attention_probs = self.dropout_layer(
            inputs=attention_probs, training=training
        )

        attn_output = ops.matmul(dropout_attention_probs, value_layer)
        attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3])

        # (batch_size, seq_len_q, hidden_dim)
        seq_len_q = ops.shape(attn_output)[1]
        attn_output = ops.reshape(
            attn_output, (batch_size, seq_len_q, self.hidden_dim)
        )

        attn_output = self.out_proj(attn_output, training=training)
        return (attn_output, attention_probs)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "dropout": self.dropout,
            }
        )
        return config


class Gemma3VitEncoderBlock(keras.layers.Layer):
    def __init__(
        self,
        num_heads,
        intermediate_dim,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        self.layer_norm_epsilon = layer_norm_epsilon

    def compute_attention(self, x, mask=None):
        if mask is not None:
            mask = ops.cast(mask, dtype=x.dtype)
        return self.attn(x, attention_mask=mask)[0]

    def build(self, input_shape):
        hidden_dim = input_shape[-1]
        self.attn = Gemma3VitAttention(
            hidden_dim,
            self.num_heads,
            dtype=self.dtype_policy,
            name="multi_head_attention",
        )
        self.layer_norm_1 = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_1",
        )
        self.mlp_dense_1 = keras.layers.Dense(
            self.intermediate_dim,
            dtype=self.dtype_policy,
            name="mlp_dense_1",
        )
        self.mlp_dense_2 = keras.layers.Dense(
            hidden_dim,
            dtype=self.dtype_policy,
            name="mlp_dense_2",
        )
        self.layer_norm_2 = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_2",
        )
        self.attn.build(None)
        self.layer_norm_1.build([None, None, hidden_dim])
        self.mlp_dense_1.build([None, None, hidden_dim])
        self.mlp_dense_2.build([None, None, self.intermediate_dim])
        self.layer_norm_2.build([None, None, hidden_dim])
        self.built = True

    def call(self, x, mask=None):
        residual = x
        x = self.layer_norm_1(x)
        # mask = ops.ones_like(x) if mask is None else mask
        x = self.compute_attention(x, mask)
        x = x + residual
        residual = x
        x = self.mlp_dense_1(self.layer_norm_2(residual))
        x = keras.activations.gelu(x, approximate=True)
        x = self.mlp_dense_2(x)
        return residual + x

    def compute_output_shape(self, inputs_shape):
        return inputs_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config


class Gemma3VitEncoder(keras.layers.Layer):
    def __init__(
        self,
        patch_size,
        image_size,
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        layer_norm_epsilon=1e-6,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.patch_size = patch_size
        self.image_size = image_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        self.layer_norm_epsilon = layer_norm_epsilon
        self.encoder_layer_norm = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="encoder_layer_norm",
        )
        self.vision_embeddings = Gemma3VitEmbeddings(
            hidden_dim=hidden_dim,
            patch_size=patch_size,
            image_size=image_size,
            dtype=dtype,
            name="encoder_embeddings",
        )
        self.resblocks = [
            Gemma3VitEncoderBlock(
                self.num_heads,
                self.intermediate_dim,
                dtype=dtype,
                name=f"encoder_block_{i}",
            )
            for i in range(self.num_layers)
        ]

    def build(self, inputs_shape):
        # Collapse `batch_size`, dummy axis, `image_max_length` into one.
        inputs_shape = [None] + list(inputs_shape[2:])
        self.vision_embeddings.build(inputs_shape)
        for block in self.resblocks:
            block.build([None, None, self.hidden_dim])
        self.encoder_layer_norm.build([None, None, self.hidden_dim])
        self.built = True

    def call(self, inputs, mask=None):
        inputs_shape = ops.shape(inputs)

        # Collapse `batch_size`, dummy axis, `image_max_length` into one.
        inputs = ops.reshape(
            inputs,
            [inputs_shape[0] * inputs_shape[1]] + list(inputs_shape[2:]),
        )

        x = self.vision_embeddings(inputs)
        for block in self.resblocks:
            x = block(x, mask=mask)
        x = self.encoder_layer_norm(x)
        return x

    def compute_output_shape(self, inputs_shape):
        if inputs_shape is None:
            # Fix the compatibility issue with Keras 3.1 where
            # `compute_output_spec` fails to propagate `inputs_shape`
            # correctly, causing it to be `None`.
            inputs_shape = [None, None, None]
        return [
            None,
            (inputs_shape[2] // self.patch_size) ** 2,
            self.hidden_dim,
        ]

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "patch_size": self.patch_size,
                "image_size": self.image_size,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config


class AveragePooling(keras.layers.Layer):
    def __init__(self, image_size, patch_size, pool_size, **kwargs):
        super().__init__(**kwargs)

        self.width = image_size // patch_size
        # `reduced_width` is the same as `num_vision_tokens_per_image`.
        self.reduced_width = self.width // pool_size

        # Attributes.
        self.image_size = image_size
        self.patch_size = patch_size
        self.pool_size = pool_size

    def build(self, input_shape):
        self.average_pooling = keras.layers.AveragePooling2D(
            pool_size=self.pool_size,
            strides=self.pool_size,
            padding="valid",
            dtype=self.dtype_policy,
            name="average_pooling",
        )

    def call(self, x):
        # reshape `(bsz, height*width, emb_dim)` to
        # `(bsz, width, width, emb_dim)`. `height` should be equal to
        # `width`.
        batch_size, _, hidden_dim = ops.shape(x)
        x = ops.reshape(x, (batch_size, self.width, self.width, hidden_dim))
        x = self.average_pooling(x)
        output = ops.reshape(
            x, (batch_size, self.reduced_width * self.reduced_width, hidden_dim)
        )
        return output

    def compute_output_shape(self, input_shape):
        return (
            input_shape[0],
            self.reduced_width * self.reduced_width,
            input_shape[-1],
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_size": self.image_size,
                "patch_size": self.patch_size,
                "pool_size": self.pool_size,
            }
        )
        return config


class Gemma3VisionOutputEncoder(keras.layers.Layer):
    def __init__(
        self,
        output_dim,
        layer_norm_epsilon=1e-6,
        kernel_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.layer_norm_epsilon = layer_norm_epsilon
        self.output_dim = output_dim

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

    def build(self, input_shape):
        self.vision_soft_embedding_norm = RMSNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="vision_soft_embedding_norm",
        )
        self.vision_soft_embedding_norm.build(input_shape)

        self.vision_input_projection = keras.layers.Dense(
            units=self.output_dim,
            use_bias=False,
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="vision_input_projection",
        )
        self.vision_input_projection.build(input_shape)

    def call(self, inputs):
        x = self.vision_soft_embedding_norm(inputs)
        x = self.vision_input_projection(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "output_dim": self.output_dim,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (self.output_dim,)


class Gemma3Vit(keras.Model):
    """Vision Transformer (ViT) model for Gemma3.

    Args:
        image_size: int. The height/width of the image. Both height and width
            are expected to be the same.
        patch_size: int. The size of each square patch in the input image.
        num_heads: int. The number of attention heads for the vision (image)
            transformer encoder.
        hidden_dim: int. The size of the transformer hidden state at the end
            of each vision transformer layer.
        num_layers: int. The number of transformer layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for transformer.
        pool_size: int. Factor by which to downscale `(dim1, dim2)` in the
            average pooling layer. The same value is used for `"strides"`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the model's computations and weights. Note that some
            computations, such as softmax and layer normalization, will always
            be done in float32 precision regardless of dtype.

    Example:
    ```python
    image = np.random.rand(224, 224, 3)
    vit_model = Gemma3Vit(image_size=224)
    # The output will be of shape:
    # [batch_size, num_vision_tokens_per_image, hidden_dim]
    output = vit_model([image])
    ```
    """

    def __init__(
        self,
        image_size,
        patch_size,
        num_heads,
        hidden_dim,
        num_layers,
        intermediate_dim,
        output_dim,
        pool_size=14,
        layer_norm_epsilon=1e-6,
        dtype=None,
        **kwargs,
    ):
        # === Functional Model ===
        image_input = keras.Input(
            shape=(None, image_size, image_size, 3),
            name="images",
        )
        x = image_input  # Intermediate result.
        x = Gemma3VitEncoder(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            intermediate_dim=intermediate_dim,
            patch_size=patch_size,
            image_size=image_size,
            dtype=dtype,
            name="image_encoder",
        )(x)

        x = AveragePooling(
            image_size=image_size,
            patch_size=patch_size,
            pool_size=pool_size,
            dtype=dtype,
            name="pooling",
        )(x)

        x = Gemma3VisionOutputEncoder(
            output_dim=output_dim,
            layer_norm_epsilon=layer_norm_epsilon,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01
            ),
            dtype=dtype,
            name="vision_output_encoder",
        )(x)

        outputs = x
        super().__init__(
            inputs=image_input,
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.intermediate_dim = intermediate_dim
        self.output_dim = output_dim
        self.pool_size = pool_size
        self.layer_norm_epsilon = layer_norm_epsilon
        self.num_vision_tokens_per_image = (
            (image_size // patch_size) ** 2
        ) // (pool_size**2)

        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
        if hasattr(keras.dtype_policies, "get"):
            self.dtype_policy = keras.dtype_policies.get(dtype)
        else:
            if isinstance(dtype, keras.dtype_policies.DTypePolicy):
                dtype = dtype.name
            dtype = dtype or keras.config.dtype_policy().name
            self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "num_layers": self.num_layers,
                "intermediate_dim": self.intermediate_dim,
                "output_dim": self.output_dim,
                "pool_size": self.pool_size,
                "image_size": self.image_size,
                "patch_size": self.patch_size,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
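The docstring example above elides most constructor arguments; a minimal, self-contained sketch exercising the vision tower with small, hypothetical dimensions (not a released preset configuration) is shown below. Note that the encoder collapses the batch and per-sample image axes, so the output has shape (batch * num_images, num_vision_tokens_per_image, output_dim).

# Sketch only: tiny dimensions chosen for speed, imported from the
# package-internal path shown in this diff.
import numpy as np
from keras_hub.src.models.gemma3.gemma3_vit import Gemma3Vit

vit = Gemma3Vit(
    image_size=112,
    patch_size=14,
    num_heads=4,
    hidden_dim=32,
    num_layers=2,
    intermediate_dim=64,
    output_dim=16,
    pool_size=4,
)
# Inputs carry an extra "images per sample" axis: (batch, num_images, H, W, 3).
images = np.random.rand(2, 1, 112, 112, 3).astype("float32")
outputs = vit(images)
# (112 // 14) ** 2 = 64 patches, pooled by 4 ** 2 -> 4 vision tokens per image,
# so `outputs.shape` is (2, 4, 16).
print(outputs.shape)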
keras_hub/src/models/gemma3/rms_normalization.py
ADDED
@@ -0,0 +1,26 @@
import keras
from keras import ops


class RMSNormalization(keras.layers.Layer):
    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale",
            trainable=True,
            shape=(input_shape[-1],),
            initializer="zeros",
        )
        self.built = True

    def call(self, x):
        # Always compute normalization in float32.
        x = ops.cast(x, "float32")
        scale = ops.cast(self.scale, "float32")
        var = ops.mean(ops.square(x), axis=-1, keepdims=True)
        normed_inputs = x * ops.reciprocal(ops.sqrt(var + self.epsilon))
        normed_inputs = normed_inputs * (1 + scale)
        return ops.cast(normed_inputs, self.compute_dtype)
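With the zero-initialized `scale` weight, the `(1 + scale)` multiplier is exactly 1, so a freshly built layer reduces to plain RMS scaling, x / sqrt(mean(x**2) + epsilon), computed in float32 and cast back to the compute dtype. A small sanity-check sketch of that behavior:

# Sketch only: verifies the zero-init behavior of the layer added above.
import numpy as np
from keras import ops
from keras_hub.src.models.gemma3.rms_normalization import RMSNormalization

layer = RMSNormalization(epsilon=1e-6)
x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype="float32")
y = layer(x)

rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + 1e-6)
np.testing.assert_allclose(ops.convert_to_numpy(y), x / rms, rtol=1e-5)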
keras_hub/src/version_utils.py
CHANGED
{keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-hub-nightly
-Version: 0.20.0.dev202503260356
+Version: 0.20.0.dev202503270400
 Summary: Industry-strength Natural Language Processing extensions for Keras.
 Home-page: https://github.com/keras-team/keras-hub
 Author: Keras team