keras-hub 0.25.0__py3-none-any.whl → 0.25.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -251,6 +251,11 @@ class Gemma3DecoderBlock(keras.layers.Layer):
251
251
  cache_update_mask=None,
252
252
  ):
253
253
  # Note: `vision_mask` is used only for Gemma3.
254
+ # If float16, we clamp the input to avoid overflow.
255
+ is_float16 = keras.backend.standardize_dtype(x.dtype) == "float16"
256
+ if is_float16:
257
+ x = ops.clip(x, -65504, 65504)
258
+
254
259
  normalized_x = self.pre_attention_norm(x)
255
260
  attention_mask = self._compute_attention_mask(
256
261
  normalized_x, padding_mask, vision_mask, cache, cache_update_index
@@ -275,7 +280,15 @@ class Gemma3DecoderBlock(keras.layers.Layer):
275
280
  if self.dropout:
276
281
  attention = self.attention_dropout(attention)
277
282
 
278
- attention_x = x + attention
283
+ if is_float16:
284
+ attention_x = ops.add(
285
+ ops.cast(x, "float32"), ops.cast(attention, "float32")
286
+ )
287
+ attention_x = ops.clip(attention_x, -65504, 65504)
288
+ attention_x = ops.cast(attention_x, "float16")
289
+ else:
290
+ attention_x = x + attention
291
+
279
292
  normalized_x = self.pre_ffw_norm(attention_x)
280
293
 
281
294
  x1 = self.gating_ffw(normalized_x)
@@ -286,7 +299,14 @@ class Gemma3DecoderBlock(keras.layers.Layer):
286
299
  if self.use_post_ffw_norm:
287
300
  x = self.post_ffw_norm(x)
288
301
 
289
- x = x + attention_x
302
+ if is_float16:
303
+ x = ops.add(
304
+ ops.cast(x, "float32"), ops.cast(attention_x, "float32")
305
+ )
306
+ x = ops.clip(x, -65504, 65504)
307
+ x = ops.cast(x, "float16")
308
+ else:
309
+ x = x + attention_x
290
310
 
291
311
  if cache is not None:
292
312
  return x, new_cache
keras_hub/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.25.0"
4
+ __version__ = "0.25.1"
5
5
 
6
6
 
7
7
  @keras_hub_export("keras_hub.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-hub
3
- Version: 0.25.0
3
+ Version: 0.25.1
4
4
  Summary: Pretrained models for Keras.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License-Expression: Apache-2.0
@@ -5,7 +5,7 @@ keras_hub/models/__init__.py,sha256=-RPLKDEOnRJmHyB867IApKj98hBrhUIuGtO15xYKQxw,
5
5
  keras_hub/samplers/__init__.py,sha256=aFQIkiqbZpi8vjrPp2MVII4QUfE-eQjra5fMeHsoy7k,886
6
6
  keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
8
- keras_hub/src/version.py,sha256=rshx76XoLXH7RCuRIo37RPFdGafttvDLFxgeDxJX22E,206
8
+ keras_hub/src/version.py,sha256=WSC-QBbLh3MiIyJz2ADMs5o4B5gmc1jUEKu6olCr_hI,206
9
9
  keras_hub/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_hub/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_hub/src/layers/modeling/alibi_bias.py,sha256=1XBTHI52L_iJDhN_w5ydu_iMhCuTgQAxEPwcLA6BPuk,4411
@@ -249,7 +249,7 @@ keras_hub/src/models/gemma3/gemma3_attention.py,sha256=u3RNI8dva5lzzqFNTAe9996s8
249
249
  keras_hub/src/models/gemma3/gemma3_backbone.py,sha256=HdWDRuF9MMwIzNVZEd1j53ILzptskvCxFiO__nfVQYU,16686
250
250
  keras_hub/src/models/gemma3/gemma3_causal_lm.py,sha256=U3C9TWlIz8VefAxQ0wJ6bDz18wqHBie8B26Ub_nFZs4,13843
251
251
  keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py,sha256=_gvKPoXqNXpXcsfc8L29wW50MToHIr2D-4Q6MNVfBU0,29790
252
- keras_hub/src/models/gemma3/gemma3_decoder_block.py,sha256=CYwYazqwakLNfhOLBl_8Q2TVZcMcOxMtiZtuVlk_hoo,11470
252
+ keras_hub/src/models/gemma3/gemma3_decoder_block.py,sha256=IBfi724Vwtq1vjuoShEEy-WpL8zyiUqeHwg1IVCSehU,12191
253
253
  keras_hub/src/models/gemma3/gemma3_image_converter.py,sha256=czi5JrTyKiK0nFzvonviBIX8jjvLHqvGNA9RyheB31k,536
254
254
  keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py,sha256=CfYdudk5En9iU6vEnrcrEWIztloD1r8VzF2extqAhAM,4616
255
255
  keras_hub/src/models/gemma3/gemma3_presets.py,sha256=3jK1OyDKDdSG_lC7yh-8O5BMKJ61knbVrHQgKe0cJiQ,8209
@@ -638,7 +638,7 @@ keras_hub/src/utils/transformers/export/gemma.py,sha256=xX_vfQwvFZ_-lQX4kgMNOGKL
638
638
  keras_hub/src/utils/transformers/export/hf_exporter.py,sha256=Qk52c6LIA2eMHUNY9Vy4STJSpnhLMdJ_t-3ljqhSr4k,5081
639
639
  keras_hub/tokenizers/__init__.py,sha256=7squHiwAu3KU5rBiupi4pH0zpUg5BwRfAOu0JcJmfA4,4873
640
640
  keras_hub/utils/__init__.py,sha256=jXPqVGBpJr_PpYmqD8aDG-fRMlxH-ulqCR2SZMn288Y,646
641
- keras_hub-0.25.0.dist-info/METADATA,sha256=XcP71H-itOfJNo_jV_WxUacUFlRLho27ZzLzltHwuns,7371
642
- keras_hub-0.25.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
643
- keras_hub-0.25.0.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
644
- keras_hub-0.25.0.dist-info/RECORD,,
641
+ keras_hub-0.25.1.dist-info/METADATA,sha256=uzKpDD4OSxVV5X4qPqBWXg11o6G11LhSs_z0F7t7woU,7371
642
+ keras_hub-0.25.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
643
+ keras_hub-0.25.1.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
644
+ keras_hub-0.25.1.dist-info/RECORD,,