keras-hub 0.25.0__py3-none-any.whl → 0.25.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +22 -2
- keras_hub/src/version.py +1 -1
- {keras_hub-0.25.0.dist-info → keras_hub-0.25.1.dist-info}/METADATA +1 -1
- {keras_hub-0.25.0.dist-info → keras_hub-0.25.1.dist-info}/RECORD +6 -6
- {keras_hub-0.25.0.dist-info → keras_hub-0.25.1.dist-info}/WHEEL +0 -0
- {keras_hub-0.25.0.dist-info → keras_hub-0.25.1.dist-info}/top_level.txt +0 -0
|
@@ -251,6 +251,11 @@ class Gemma3DecoderBlock(keras.layers.Layer):
|
|
|
251
251
|
cache_update_mask=None,
|
|
252
252
|
):
|
|
253
253
|
# Note: `vision_mask` is used only for Gemma3.
|
|
254
|
+
# If float16, we clamp the input to avoid overflow.
|
|
255
|
+
is_float16 = keras.backend.standardize_dtype(x.dtype) == "float16"
|
|
256
|
+
if is_float16:
|
|
257
|
+
x = ops.clip(x, -65504, 65504)
|
|
258
|
+
|
|
254
259
|
normalized_x = self.pre_attention_norm(x)
|
|
255
260
|
attention_mask = self._compute_attention_mask(
|
|
256
261
|
normalized_x, padding_mask, vision_mask, cache, cache_update_index
|
|
@@ -275,7 +280,15 @@ class Gemma3DecoderBlock(keras.layers.Layer):
|
|
|
275
280
|
if self.dropout:
|
|
276
281
|
attention = self.attention_dropout(attention)
|
|
277
282
|
|
|
278
|
-
|
|
283
|
+
if is_float16:
|
|
284
|
+
attention_x = ops.add(
|
|
285
|
+
ops.cast(x, "float32"), ops.cast(attention, "float32")
|
|
286
|
+
)
|
|
287
|
+
attention_x = ops.clip(attention_x, -65504, 65504)
|
|
288
|
+
attention_x = ops.cast(attention_x, "float16")
|
|
289
|
+
else:
|
|
290
|
+
attention_x = x + attention
|
|
291
|
+
|
|
279
292
|
normalized_x = self.pre_ffw_norm(attention_x)
|
|
280
293
|
|
|
281
294
|
x1 = self.gating_ffw(normalized_x)
|
|
@@ -286,7 +299,14 @@ class Gemma3DecoderBlock(keras.layers.Layer):
|
|
|
286
299
|
if self.use_post_ffw_norm:
|
|
287
300
|
x = self.post_ffw_norm(x)
|
|
288
301
|
|
|
289
|
-
|
|
302
|
+
if is_float16:
|
|
303
|
+
x = ops.add(
|
|
304
|
+
ops.cast(x, "float32"), ops.cast(attention_x, "float32")
|
|
305
|
+
)
|
|
306
|
+
x = ops.clip(x, -65504, 65504)
|
|
307
|
+
x = ops.cast(x, "float16")
|
|
308
|
+
else:
|
|
309
|
+
x = x + attention_x
|
|
290
310
|
|
|
291
311
|
if cache is not None:
|
|
292
312
|
return x, new_cache
|
{keras_hub-0.25.0.dist-info → keras_hub-0.25.1.dist-info}/RECORD
CHANGED
|
@@ -5,7 +5,7 @@ keras_hub/models/__init__.py,sha256=-RPLKDEOnRJmHyB867IApKj98hBrhUIuGtO15xYKQxw,
|
|
|
5
5
|
keras_hub/samplers/__init__.py,sha256=aFQIkiqbZpi8vjrPp2MVII4QUfE-eQjra5fMeHsoy7k,886
|
|
6
6
|
keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
7
|
keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
|
|
8
|
-
keras_hub/src/version.py,sha256=
|
|
8
|
+
keras_hub/src/version.py,sha256=WSC-QBbLh3MiIyJz2ADMs5o4B5gmc1jUEKu6olCr_hI,206
|
|
9
9
|
keras_hub/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
keras_hub/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
keras_hub/src/layers/modeling/alibi_bias.py,sha256=1XBTHI52L_iJDhN_w5ydu_iMhCuTgQAxEPwcLA6BPuk,4411
|
|
@@ -249,7 +249,7 @@ keras_hub/src/models/gemma3/gemma3_attention.py,sha256=u3RNI8dva5lzzqFNTAe9996s8
|
|
|
249
249
|
keras_hub/src/models/gemma3/gemma3_backbone.py,sha256=HdWDRuF9MMwIzNVZEd1j53ILzptskvCxFiO__nfVQYU,16686
|
|
250
250
|
keras_hub/src/models/gemma3/gemma3_causal_lm.py,sha256=U3C9TWlIz8VefAxQ0wJ6bDz18wqHBie8B26Ub_nFZs4,13843
|
|
251
251
|
keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py,sha256=_gvKPoXqNXpXcsfc8L29wW50MToHIr2D-4Q6MNVfBU0,29790
|
|
252
|
-
keras_hub/src/models/gemma3/gemma3_decoder_block.py,sha256=
|
|
252
|
+
keras_hub/src/models/gemma3/gemma3_decoder_block.py,sha256=IBfi724Vwtq1vjuoShEEy-WpL8zyiUqeHwg1IVCSehU,12191
|
|
253
253
|
keras_hub/src/models/gemma3/gemma3_image_converter.py,sha256=czi5JrTyKiK0nFzvonviBIX8jjvLHqvGNA9RyheB31k,536
|
|
254
254
|
keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py,sha256=CfYdudk5En9iU6vEnrcrEWIztloD1r8VzF2extqAhAM,4616
|
|
255
255
|
keras_hub/src/models/gemma3/gemma3_presets.py,sha256=3jK1OyDKDdSG_lC7yh-8O5BMKJ61knbVrHQgKe0cJiQ,8209
|
|
@@ -638,7 +638,7 @@ keras_hub/src/utils/transformers/export/gemma.py,sha256=xX_vfQwvFZ_-lQX4kgMNOGKL
|
|
|
638
638
|
keras_hub/src/utils/transformers/export/hf_exporter.py,sha256=Qk52c6LIA2eMHUNY9Vy4STJSpnhLMdJ_t-3ljqhSr4k,5081
|
|
639
639
|
keras_hub/tokenizers/__init__.py,sha256=7squHiwAu3KU5rBiupi4pH0zpUg5BwRfAOu0JcJmfA4,4873
|
|
640
640
|
keras_hub/utils/__init__.py,sha256=jXPqVGBpJr_PpYmqD8aDG-fRMlxH-ulqCR2SZMn288Y,646
|
|
641
|
-
keras_hub-0.25.
|
|
642
|
-
keras_hub-0.25.
|
|
643
|
-
keras_hub-0.25.
|
|
644
|
-
keras_hub-0.25.
|
|
641
|
+
keras_hub-0.25.1.dist-info/METADATA,sha256=uzKpDD4OSxVV5X4qPqBWXg11o6G11LhSs_z0F7t7woU,7371
|
|
642
|
+
keras_hub-0.25.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
643
|
+
keras_hub-0.25.1.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
|
|
644
|
+
keras_hub-0.25.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|