keras-hub-nightly 0.19.0.dev202502190348__py3-none-any.whl → 0.19.0.dev202502210346__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -186,6 +186,14 @@ class Backbone(keras.Model):
186
186
  saver = get_preset_saver(preset_dir)
187
187
  saver.save_backbone(self)
188
188
 
189
+ def get_lora_target_names(self):
190
+ """Returns list of layer names which are to be LoRA-fied.
191
+
192
+ Subclasses can override this method if the names of layers to be
193
+ LoRA-fied are different.
194
+ """
195
+ return ["query_dense", "value_dense", "query", "value"]
196
+
189
197
  def enable_lora(self, rank):
190
198
  """Enable Lora on the backbone.
191
199
 
@@ -193,7 +201,8 @@ class Backbone(keras.Model):
193
201
  while enabling Lora on the query & value `EinsumDense` layers
194
202
  of the attention layers.
195
203
  """
196
- target_names = ["query_dense", "value_dense", "query", "value"]
204
+ target_names = self.get_lora_target_names()
205
+
197
206
  self.trainable = True
198
207
  self._lora_enabled_layers = []
199
208
  self._lora_rank = rank
@@ -4,6 +4,7 @@ from keras import ops
4
4
 
5
5
  from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
6
6
  from keras_hub.src.utils.keras_utils import clone_initializer
7
+ from keras_hub.src.utils.keras_utils import has_flash_attention_support
7
8
 
8
9
 
9
10
  class CachedGemmaAttention(keras.layers.Layer):
@@ -117,6 +118,36 @@ class CachedGemmaAttention(keras.layers.Layer):
117
118
  query_normalization = 1 / np.sqrt(
118
119
  self.hidden_dim // self.num_query_heads
119
120
  )
121
+ use_dot_product_attention = not (
122
+ self.dropout > 0.0 or (len(q.shape) != 4)
123
+ )
124
+ if has_flash_attention_support() and use_dot_product_attention:
125
+ if self.dropout > 0.0:
126
+ raise ValueError(
127
+ "Flash attention does not support dropout. "
128
+ "Please set `dropout` to 0.0."
129
+ )
130
+ if attention_mask is not None:
131
+ while len(attention_mask.shape) < 4:
132
+ attention_mask = ops.expand_dims(
133
+ attention_mask, axis=1
134
+ ) # Add dimension for num_heads
135
+ if attention_mask.shape[1] != self.num_query_heads:
136
+ attention_mask = ops.tile(
137
+ attention_mask, [1, self.num_query_heads, 1, 1]
138
+ )
139
+
140
+ attention_output = ops.dot_product_attention(
141
+ query=q,
142
+ key=k,
143
+ value=v,
144
+ bias=None,
145
+ mask=attention_mask,
146
+ scale=query_normalization,
147
+ is_causal=True,
148
+ flash_attention=True,
149
+ )
150
+ return attention_output
120
151
 
121
152
  q *= ops.cast(query_normalization, dtype=q.dtype)
122
153
  q_shape = ops.shape(q)
@@ -131,8 +162,8 @@ class CachedGemmaAttention(keras.layers.Layer):
131
162
  )
132
163
  b, q_len, _, _, h = ops.shape(q)
133
164
 
165
+ # Fall back to standard attention when flash attention is not used
134
166
  attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
135
-
136
167
  if self.logit_soft_cap is not None:
137
168
  attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
138
169
  attention_logits = ops.multiply(
@@ -274,6 +274,13 @@ class PaliGemmaBackbone(Backbone):
274
274
  # Keep the image_sequence_length as a backbone property for easy access.
275
275
  self.image_sequence_length = self.vit_encoder.image_sequence_length
276
276
 
277
+ def get_lora_target_names(self):
278
+ target_names = super().get_lora_target_names()
279
+
280
+ # Add these for `PaliGemmaVITAttention`.
281
+ target_names += ["query_proj", "value_proj"]
282
+ return target_names
283
+
277
284
  def get_config(self):
278
285
  config = super().get_config()
279
286
  config.update(
@@ -83,6 +83,96 @@ backbone_presets = {
83
83
  },
84
84
  "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_ft_docci_10b_448/2",
85
85
  },
86
+ "pali_gemma2_mix_3b_224": {
87
+ "metadata": {
88
+ "description": (
89
+ "3 billion parameter, image size 224, 27-layer for "
90
+ "SigLIP-So400m vision encoder and 26-layer Gemma2 2B language "
91
+ "model. This model has been fine-tuned on a wide range of "
92
+ "vision-language tasks and domains."
93
+ ),
94
+ "params": 3032094960,
95
+ "official_name": "PaliGemma2",
96
+ "path": "pali_gemma2",
97
+ "model_card": "https://www.kaggle.com/models/google/paligemma-2",
98
+ },
99
+ "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224/2",
100
+ },
101
+ "pali_gemma2_mix_3b_448": {
102
+ "metadata": {
103
+ "description": (
104
+ "3 billion parameter, image size 448, 27-layer for "
105
+ "SigLIP-So400m vision encoder and 26-layer Gemma2 2B language "
106
+ "model. This model has been fine-tuned on a wide range of "
107
+ "vision-language tasks and domains."
108
+ ),
109
+ "params": 3032979696,
110
+ "official_name": "PaliGemma2",
111
+ "path": "pali_gemma2",
112
+ "model_card": "https://www.kaggle.com/models/google/paligemma-2",
113
+ },
114
+ "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_448/2",
115
+ },
116
+ "pali_gemma2_mix_10b_224": {
117
+ "metadata": {
118
+ "description": (
119
+ "10 billion parameter, image size 224, 27-layer for "
120
+ "SigLIP-So400m vision encoder and 42-layer Gemma2 9B language "
121
+ "model. This model has been fine-tuned on a wide range of "
122
+ "vision-language tasks and domains."
123
+ ),
124
+ "params": 9662409456,
125
+ "official_name": "PaliGemma2",
126
+ "path": "pali_gemma2",
127
+ "model_card": "https://www.kaggle.com/models/google/paligemma-2",
128
+ },
129
+ "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_224/2",
130
+ },
131
+ "pali_gemma2_mix_10b_448": {
132
+ "metadata": {
133
+ "description": (
134
+ "10 billion parameter, image size 448, 27-layer for "
135
+ "SigLIP-So400m vision encoder and 42-layer Gemma2 9B language "
136
+ "model. This model has been fine-tuned on a wide range of "
137
+ "vision-language tasks and domains."
138
+ ),
139
+ "params": 9663294192,
140
+ "official_name": "PaliGemma2",
141
+ "path": "pali_gemma2",
142
+ "model_card": "https://www.kaggle.com/models/google/paligemma-2",
143
+ },
144
+ "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_448/2",
145
+ },
146
+ "pali_gemma2_mix_28b_224": {
147
+ "metadata": {
148
+ "description": (
149
+ "28 billion parameter, image size 224, 27-layer for "
150
+ "SigLIP-So400m vision encoder and 46-layer Gemma2 27B language "
151
+ "model. This model has been fine-tuned on a wide range of "
152
+ "vision-language tasks and domains."
153
+ ),
154
+ "params": 27650192112,
155
+ "official_name": "PaliGemma2",
156
+ "path": "pali_gemma2",
157
+ "model_card": "https://www.kaggle.com/models/google/paligemma-2",
158
+ },
159
+ "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_224/2",
160
+ },
161
+ "pali_gemma2_mix_28b_448": {
162
+ "metadata": {
163
+ "description": (
164
+ "28 billion parameter, image size 448, 27-layer for "
165
+ "SigLIP-So400m vision encoder and 46-layer Gemma2 27B language "
166
+ "model. This model has been fine-tuned on a wide range of "
167
+ "vision-language tasks and domains."
168
+ ),
169
+ "params": 27650192112,
170
+ "official_name": "PaliGemma2",
171
+ "path": "pali_gemma2",
172
+ "model_card": "https://www.kaggle.com/models/google/paligemma-2",
173
+ },
174
+ "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_448/2",
175
+ },
86
176
  "pali_gemma2_pt_3b_224": {
87
177
  "metadata": {
88
178
  "description": (
@@ -181,7 +271,7 @@ backbone_presets = {
181
271
  "model. This model has been pre-trained on a mixture of "
182
272
  "datasets."
183
273
  ),
184
- "params": 9662409456,
274
+ "params": 27650192112,
185
275
  "official_name": "PaliGemma2",
186
276
  "path": "pali_gemma2",
187
277
  "model_card": "https://www.kaggle.com/models/google/paligemma-2",
@@ -196,7 +286,7 @@ backbone_presets = {
196
286
  "model. This model has been pre-trained on a mixture of "
197
287
  "datasets."
198
288
  ),
199
- "params": 9663294192,
289
+ "params": 27650192112,
200
290
  "official_name": "PaliGemma2",
201
291
  "path": "pali_gemma2",
202
292
  "model_card": "https://www.kaggle.com/models/google/paligemma-2",
@@ -211,7 +301,7 @@ backbone_presets = {
211
301
  "model. This model has been pre-trained on a mixture of "
212
302
  "datasets."
213
303
  ),
214
- "params": 9666833136,
304
+ "params": 27650192112,
215
305
  "official_name": "PaliGemma2",
216
306
  "path": "pali_gemma2",
217
307
  "model_card": "https://www.kaggle.com/models/google/paligemma-2",
@@ -56,7 +56,19 @@ def standardize_data_format(data_format):
56
56
 
57
57
 
58
58
  def has_flash_attention_support():
59
- if hasattr(keras.config, "is_flash_attention_enabled"):
59
+ if (
60
+ hasattr(keras.config, "is_flash_attention_enabled")
61
+ and keras.config.backend() == "jax"
62
+ ):
63
+ try:
64
+ from jax.nn import dot_product_attention as dot_product_attention
65
+ except ImportError:
66
+ logging.warning(
67
+ "Flash attention is not supported in your current JAX version. "
68
+ "Please update it by following the official guide: "
69
+ "https://jax.readthedocs.io/en/latest/installation.html"
70
+ )
71
+ return False
60
72
  return True
61
73
  else:
62
74
  return False
@@ -1,7 +1,7 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.19.0.dev202502190348"
4
+ __version__ = "0.19.0.dev202502210346"
5
5
 
6
6
 
7
7
  @keras_hub_export("keras_hub.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: keras-hub-nightly
3
- Version: 0.19.0.dev202502190348
3
+ Version: 0.19.0.dev202502210346
4
4
  Summary: Industry-strength Natural Language Processing extensions for Keras.
5
5
  Home-page: https://github.com/keras-team/keras-hub
6
6
  Author: Keras team
@@ -27,10 +27,11 @@ Requires-Dist: packaging
27
27
  Requires-Dist: regex
28
28
  Requires-Dist: rich
29
29
  Requires-Dist: kagglehub
30
- Requires-Dist: tensorflow-text
31
30
  Provides-Extra: extras
32
31
  Requires-Dist: rouge-score; extra == "extras"
33
32
  Requires-Dist: sentencepiece; extra == "extras"
33
+ Provides-Extra: nlp
34
+ Requires-Dist: tensorflow-text; extra == "nlp"
34
35
  Dynamic: author
35
36
  Dynamic: author-email
36
37
  Dynamic: classifier
@@ -147,6 +148,13 @@ To install the latest KerasHub release with Keras 3, simply run:
147
148
  pip install --upgrade keras-hub
148
149
  ```
149
150
 
151
+ Our text tokenizers are based on TensorFlow Text. Hence, if you are using any
152
+ model that has language as a modality, you will have to run:
153
+
154
+ ```
155
+ pip install --upgrade keras-hub[nlp]
156
+ ```
157
+
150
158
  To install the latest nightly changes for both KerasHub and Keras, you can use
151
159
  our nightly package.
152
160
 
@@ -8,7 +8,7 @@ keras_hub/api/tokenizers/__init__.py,sha256=mtJgQy1spfQnPAkeLoeinsT_W9iCWHlJXwzc
8
8
  keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
9
9
  keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
11
- keras_hub/src/version_utils.py,sha256=aAL0M_iBZYGjtaDtwjeSo1Y9KpY-xoKEWzmooZygJ_c,222
11
+ keras_hub/src/version_utils.py,sha256=ttkrKvEmHIzmFoB_r1Q4g722HgNujcQsmyjdwbeHz9E,222
12
12
  keras_hub/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  keras_hub/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  keras_hub/src/layers/modeling/alibi_bias.py,sha256=1XBTHI52L_iJDhN_w5ydu_iMhCuTgQAxEPwcLA6BPuk,4411
@@ -44,7 +44,7 @@ keras_hub/src/metrics/rouge_base.py,sha256=Pt2DUznhTTeR-fX1nQ_wSbPtmuTgxQTvrGpu8
44
44
  keras_hub/src/metrics/rouge_l.py,sha256=JlZhMBV6wS_6zMd57pkTc6yxHkEJT9fVQMlPZKekQzQ,2729
45
45
  keras_hub/src/metrics/rouge_n.py,sha256=JoFtmgjF4Ic263ny6bfD6vMHKreH9le3HnOOxemupRc,3620
46
46
  keras_hub/src/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
- keras_hub/src/models/backbone.py,sha256=lOv8id2qCkewrtBOrSObc3_nh_WOfsHsgGlIBsHug7g,10986
47
+ keras_hub/src/models/backbone.py,sha256=ofIqRvSUrdP6rXAP0QTbStwiEfv-JxS7wTzcHxjj6iQ,11254
48
48
  keras_hub/src/models/causal_lm.py,sha256=ReaF-i3SHsCkHh4c28jM72QjMQ8x7yiCwG39FRb-7KE,16786
49
49
  keras_hub/src/models/causal_lm_preprocessor.py,sha256=YY7VJZicdmnjDSWi9g4_pEpd5bdJK166GlWcapvokF0,6663
50
50
  keras_hub/src/models/feature_pyramid_backbone.py,sha256=clEW-TTQSVJ_5qFNdDF0iABkin1p_xlBUFjJrC7T0IA,2247
@@ -183,7 +183,7 @@ keras_hub/src/models/flux/flux_presets.py,sha256=z7C_FbI1_F5YETXuWpc7Yh_0w-5N0eB
183
183
  keras_hub/src/models/flux/flux_text_to_image.py,sha256=Rf5dD2EhG0bE8Gyg9sqaA8YEexS1kdraofIkxiZDjvc,4166
184
184
  keras_hub/src/models/flux/flux_text_to_image_preprocessor.py,sha256=Fs9jr97QtmRUbRRz1kITpkuhDM2GoV3n0XSFC-qQA14,2252
185
185
  keras_hub/src/models/gemma/__init__.py,sha256=rVzOJMJ39bgVlT8UdC0t8PlN2c237GKTBmfHIsbPuOQ,251
186
- keras_hub/src/models/gemma/gemma_attention.py,sha256=1CVN5z9GKoU8TuNMih2_MweDkpd98xSqdic9F8xIBE8,8317
186
+ keras_hub/src/models/gemma/gemma_attention.py,sha256=uvBDwIfv-pEo4IF2LY7vdt2R9W-OQIqOA0hLWVQUluI,9659
187
187
  keras_hub/src/models/gemma/gemma_backbone.py,sha256=GzAUSArw_pN9dtWQzTVhWDbW-XyWt4GyMcFLn9hwmh0,13391
188
188
  keras_hub/src/models/gemma/gemma_causal_lm.py,sha256=3OXaIXlrKqMIuUnBk-bUz-0SYFL-XkkQTWm8qRY2YII,16770
189
189
  keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py,sha256=bpKkEurWIfa6Kp9s4pz84-sBDSA6ZFNHP8nXG1fFQrg,2912
@@ -250,12 +250,12 @@ keras_hub/src/models/opt/opt_causal_lm_preprocessor.py,sha256=xHfslVMOZlAIj2V2jI
250
250
  keras_hub/src/models/opt/opt_presets.py,sha256=LrjgI5gbq4Cvfl_pmeCnKn4hS_V_0GYTeJaDc9tbeZM,1745
251
251
  keras_hub/src/models/opt/opt_tokenizer.py,sha256=oDHeed4xf07tm14hj_C78BkzMuuRwRP2cRHmqYnObrs,2557
252
252
  keras_hub/src/models/pali_gemma/__init__.py,sha256=uODWTlttOOchcTLpiYHCEWMXnDxIz8ZVIeYFQN2bd8o,288
253
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py,sha256=aRsLlgKqqxwtYxYy-D9k37YSJowUlRWfxpyRBFWDRnI,13413
253
+ keras_hub/src/models/pali_gemma/pali_gemma_backbone.py,sha256=_Sa22j4jk_7400h33S22w0S8Dh8Lzzl6A5WeEp55zSk,13637
254
254
  keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py,sha256=AViEs6YltUqWnIVo7J02JkXcanBgLSdwZwF56TVr8gc,11345
255
255
  keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py,sha256=F57y0fZ0wYYxfGIjfrJc1W9uQpViYFx5bvFjj5CqUbI,4814
256
256
  keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py,sha256=24ABQ1vGlppV-KfWh0YqJjzM_Lu2GIwvyJ4X2XXie_A,5616
257
257
  keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py,sha256=5yM_jUtrFsWIieiwfFBoP7mtPmQAwywkeLKbd7fhmzk,371
258
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py,sha256=Ka1ChUUSKw-yY2th3QtmNtkeXt0krYfwhkHrScioMls,8979
258
+ keras_hub/src/models/pali_gemma/pali_gemma_presets.py,sha256=zF04iShXky_c3IfUbmLlBN2FYb6iCWH1DWTgDdTCqrI,13006
259
259
  keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py,sha256=ljTiADHo0Ok88q-jVzwJIle2C8xcxnudLTsBLzIySaM,2415
260
260
  keras_hub/src/models/pali_gemma/pali_gemma_vit.py,sha256=R-W7SCnlLjkgiJ9vrn3ctbBES_yCxJSrCld5dV7nzaY,18235
261
261
  keras_hub/src/models/phi3/__init__.py,sha256=zIbf1MU-ks91mEkjTRJAsk51N3BBnXDF2JM1vO-13PQ,245
@@ -386,7 +386,7 @@ keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py,sha256=hRv_XxoPIPDpHfO0Z
386
386
  keras_hub/src/tokenizers/word_piece_tokenizer.py,sha256=vP6AZgbzsRiuPCt3W_n94nsF7XiERnagWcH_rqJHtVU,19943
387
387
  keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py,sha256=cylrs02ZrYQ1TuZr9oyS3NrVbDwGctA3VXbIh1pFJMQ,6743
388
388
  keras_hub/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
389
- keras_hub/src/utils/keras_utils.py,sha256=ZULqIQylAQen-_pNC96htvLaxSJbfAenNoCo3ZSvY5g,1843
389
+ keras_hub/src/utils/keras_utils.py,sha256=TNgp3ukTiCA0jrGUq2ZV_Xqtzc7CfiFQKyOH5t47z48,2313
390
390
  keras_hub/src/utils/pipeline_model.py,sha256=jgzB6NQPSl0KOu08N-TazfOnXnUJbZjH2EXXhx25Ftg,9084
391
391
  keras_hub/src/utils/preset_utils.py,sha256=ZbSEUSacKlr_mgHyB3ChUohgOQN7nMCkE6E2lGxt2HA,31927
392
392
  keras_hub/src/utils/python_utils.py,sha256=N8nWeO3san4YnGkffRXG3Ix7VEIMTKSN21FX5TuL7G8,202
@@ -413,7 +413,7 @@ keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYum
413
413
  keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
414
414
  keras_hub/src/utils/transformers/preset_loader.py,sha256=DgGJXbTSB9Na8FIR-YWWVqQPOFxHwWrGm41EwcS_EFs,3797
415
415
  keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
416
- keras_hub_nightly-0.19.0.dev202502190348.dist-info/METADATA,sha256=L0fEtVLfSiKpy7fJyO_VUrydFIaVT0Pirw7kPwu3ob8,7498
417
- keras_hub_nightly-0.19.0.dev202502190348.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
418
- keras_hub_nightly-0.19.0.dev202502190348.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
419
- keras_hub_nightly-0.19.0.dev202502190348.dist-info/RECORD,,
416
+ keras_hub_nightly-0.19.0.dev202502210346.dist-info/METADATA,sha256=SFwTUAZFRtgw028VYnTTxCexaXIPDHlfm7BdUqZPW4Q,7721
417
+ keras_hub_nightly-0.19.0.dev202502210346.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
418
+ keras_hub_nightly-0.19.0.dev202502210346.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
419
+ keras_hub_nightly-0.19.0.dev202502210346.dist-info/RECORD,,