keras-hub-nightly 0.22.0.dev202508170419-py3-none-any.whl → 0.24.0.dev202511090424-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of keras-hub-nightly has been flagged by the registry.

Files changed (126)
  1. keras_hub/layers/__init__.py +15 -0
  2. keras_hub/models/__init__.py +93 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
  5. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  6. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  7. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  8. keras_hub/src/models/backbone.py +28 -16
  9. keras_hub/src/models/causal_lm.py +37 -0
  10. keras_hub/src/models/causal_lm_preprocessor.py +14 -0
  11. keras_hub/src/models/clip/clip_presets.py +8 -8
  12. keras_hub/src/models/d_fine/__init__.py +5 -0
  13. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  14. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  15. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  16. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  17. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  18. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  19. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  20. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  21. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  22. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  23. keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
  24. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  25. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
  26. keras_hub/src/models/depth_anything/__init__.py +9 -0
  27. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  28. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  29. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  30. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  31. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  32. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  33. keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
  34. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  35. keras_hub/src/models/depth_estimator.py +239 -0
  36. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  37. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  38. keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
  39. keras_hub/src/models/dinov3/__init__.py +5 -0
  40. keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
  41. keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
  42. keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
  43. keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
  44. keras_hub/src/models/gemma/gemma_backbone.py +0 -1
  45. keras_hub/src/models/gemma/gemma_presets.py +30 -0
  46. keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
  47. keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
  48. keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
  49. keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
  50. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  51. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  52. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  53. keras_hub/src/models/image_to_image.py +5 -0
  54. keras_hub/src/models/inpaint.py +5 -0
  55. keras_hub/src/models/mobilenetv5/__init__.py +9 -0
  56. keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
  57. keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
  58. keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
  59. keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
  60. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
  61. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
  62. keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
  63. keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
  64. keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
  65. keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
  66. keras_hub/src/models/parseq/__init__.py +5 -0
  67. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  68. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  69. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  70. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  71. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  72. keras_hub/src/models/parseq/parseq_presets.py +15 -0
  73. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  74. keras_hub/src/models/qwen3_moe/__init__.py +5 -0
  75. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  76. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  77. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  78. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  79. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  80. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  81. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
  82. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  83. keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
  84. keras_hub/src/models/siglip/siglip_presets.py +15 -0
  85. keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
  86. keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
  87. keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
  88. keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
  89. keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
  90. keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  92. keras_hub/src/models/t5gemma/__init__.py +5 -0
  93. keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
  94. keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
  95. keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
  96. keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
  97. keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
  98. keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
  99. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
  100. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
  101. keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
  102. keras_hub/src/models/text_to_image.py +5 -0
  103. keras_hub/src/samplers/beam_sampler.py +6 -6
  104. keras_hub/src/samplers/sampler.py +8 -6
  105. keras_hub/src/tests/test_case.py +40 -3
  106. keras_hub/src/tokenizers/tokenizer.py +15 -0
  107. keras_hub/src/utils/openvino_utils.py +141 -0
  108. keras_hub/src/utils/preset_utils.py +58 -2
  109. keras_hub/src/utils/tensor_utils.py +26 -2
  110. keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
  111. keras_hub/src/utils/timm/preset_loader.py +8 -4
  112. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  113. keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
  114. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  115. keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
  116. keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
  117. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  118. keras_hub/src/utils/transformers/export/gemma.py +49 -4
  119. keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
  120. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  121. keras_hub/src/version.py +1 -1
  122. keras_hub/tokenizers/__init__.py +15 -0
  123. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
  124. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
  125. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
  126. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/keras_hub/src/models/dinov3/dinov3_presets.py
@@ -0,0 +1,4 @@
+"""DINOV3 model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {}
--- a/keras_hub/src/models/gemma/gemma_backbone.py
+++ b/keras_hub/src/models/gemma/gemma_backbone.py
@@ -114,7 +114,6 @@ class GemmaBackbone(Backbone):
                 scale=1.0,
                 mode="fan_in",
                 distribution="untruncated_normal",
-                seed=None,
             ),
             dtype=dtype,
             logit_soft_cap=final_logit_soft_cap,
--- a/keras_hub/src/models/gemma/gemma_presets.py
+++ b/keras_hub/src/models/gemma/gemma_presets.py
@@ -198,4 +198,34 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/2",
     },
+    "vault_gemma_1b_en": {
+        "metadata": {
+            "description": "1 billion parameter, 26-layer, VaultGemma model.",
+            "params": 1038741120,
+            "path": "gemma",
+        },
+        "kaggle_handle": "kaggle://keras/vaultgemma/keras/vault_gemma_1b_en/2",
+    },
+    "c2s_scale_gemma_2_2b_en": {
+        "metadata": {
+            "description": (
+                "A 2 billion parameter, single-cell biology-aware model "
+                "built on the Gemma-2 architecture."
+            ),
+            "params": 2614341888,
+            "path": "gemma",
+        },
+        "kaggle_handle": "kaggle://keras/cell2sentence/keras/c2s_scale_gemma_2_2b_en/1",
+    },
+    "c2s_scale_gemma_2_27b_en": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter, single-cell biology-aware model "
+                "built on the Gemma-2 architecture."
+            ),
+            "params": 27227128320,
+            "path": "gemma",
+        },
+        "kaggle_handle": "kaggle://keras/cell2sentence/keras/c2s_scale_gemma_2_27b_en/1",
+    },
 }
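The new entries reuse the existing preset schema: a `metadata` block for display plus a `kaggle_handle` pointing at the weights. Once a build containing this diff is installed, any of them should be loadable by name. A minimal sketch (assumes network access to the Kaggle handle above):

```python
import keras_hub

# "vault_gemma_1b_en" is one of the presets registered in this diff.
# from_preset() resolves the name to its kaggle_handle and downloads
# the weights on first use.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("vault_gemma_1b_en")
print(gemma_lm.generate("The VaultGemma model is", max_length=30))
```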
--- a/keras_hub/src/models/gemma3/gemma3_attention.py
+++ b/keras_hub/src/models/gemma3/gemma3_attention.py
@@ -46,6 +46,7 @@ class CachedGemma3Attention(keras.layers.Layer):
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -61,6 +62,7 @@ class CachedGemma3Attention(keras.layers.Layer):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self._kernel_initializer = keras.initializers.get(
@@ -240,12 +242,58 @@ class CachedGemma3Attention(keras.layers.Layer):
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _compute_bidirectional_sliding_mask(self, batch_size, sequence_length):
+        """Computes a bidirectional sliding window attention mask.
+
+        A token can attend to any other token if their absolute distance is
+        within half the sliding window size. This mask is used in embedding
+        models like `EmbeddingGemma`.
+
+        Args:
+            batch_size: The batch size for the mask.
+            sequence_length: The length of the sequence.
+
+        Returns:
+            A boolean attention mask with shape
+            `(batch_size, sequence_length, sequence_length)`.
+        """
+        i = keras.ops.expand_dims(
+            keras.ops.arange(sequence_length, dtype="int32"), axis=1
+        )
+        j = keras.ops.arange(sequence_length, dtype="int32")
+
+        # If sliding window size is 4, the token in question attends to 1
+        # token before and 2 tokens after.
+        w_right = self.sliding_window_size // 2
+        w_left = self.sliding_window_size - w_right - 1
+
+        # Calculate the relative distance.
+        distance = i - j
+
+        mask = keras.ops.logical_and(distance <= w_left, distance >= -w_right)
+
+        mask = keras.ops.expand_dims(mask, axis=0)
+        return keras.ops.broadcast_to(
+            mask, (batch_size, sequence_length, sequence_length)
+        )
+
     def _mask_sliding_window(
         self,
         attention_mask,
         cache_update_index=0,
     ):
         batch_size, query_len, key_len = ops.shape(attention_mask)
+
+        if self.use_bidirectional_attention:
+            bidirectional_sliding_mask = (
+                self._compute_bidirectional_sliding_mask(
+                    batch_size=batch_size,
+                    # `query_len = key_len` for embedding models
+                    sequence_length=query_len,
+                )
+            )
+            return ops.logical_and(attention_mask, bidirectional_sliding_mask)
+
         # Compute the sliding window for square attention.
         all_ones = ops.ones((key_len, key_len), "bool")
         if keras.config.backend() == "tensorflow":
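The window split above is asymmetric: for `sliding_window_size=4`, `w_right = 2` and `w_left = 1`, so a token sees itself, one token back, and two tokens ahead. A small NumPy sketch (not part of the diff) that mirrors the same arithmetic for a 6-token sequence:

```python
import numpy as np

sliding_window_size = 4
sequence_length = 6

# Same computation as _compute_bidirectional_sliding_mask, in NumPy.
i = np.arange(sequence_length, dtype="int32")[:, None]
j = np.arange(sequence_length, dtype="int32")
w_right = sliding_window_size // 2           # 2 tokens after
w_left = sliding_window_size - w_right - 1   # 1 token before
distance = i - j
mask = (distance <= w_left) & (distance >= -w_right)

print(mask.astype(int))
# Row 2 prints [0 1 1 1 1 0]: token 2 attends to tokens 1..4,
# i.e. itself, one to the left, and two to the right.
```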
--- a/keras_hub/src/models/gemma3/gemma3_backbone.py
+++ b/keras_hub/src/models/gemma3/gemma3_backbone.py
@@ -196,6 +196,7 @@ class Gemma3Backbone(Backbone):
         global_rope_scaling_factor=1.0,
         vision_encoder=None,
         layer_norm_epsilon=1e-6,
+        use_bidirectional_attention=False,
         dropout=0,
         dtype=None,
         **kwargs,
@@ -209,7 +210,6 @@ class Gemma3Backbone(Backbone):
                 scale=1.0,
                 mode="fan_in",
                 distribution="untruncated_normal",
-                seed=None,
             ),
             dtype=dtype,
             logit_soft_cap=final_logit_soft_cap,
@@ -251,6 +251,7 @@ class Gemma3Backbone(Backbone):
             sliding_window_size=sliding_window_size,
             rope_wavelength=rope_wavelength,
             rope_scaling_factor=rope_scaling_factor,
+            use_bidirectional_attention=use_bidirectional_attention,
             dropout=dropout,
             dtype=dtype,
             name=f"decoder_block_{i}",
@@ -357,6 +358,7 @@ class Gemma3Backbone(Backbone):
         self.sliding_window_size = sliding_window_size
         self.local_rope_scaling_factor = local_rope_scaling_factor
         self.global_rope_scaling_factor = global_rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
 
@@ -396,6 +398,7 @@ class Gemma3Backbone(Backbone):
                 "vision_encoder": None
                 if self.vision_encoder is None
                 else keras.layers.serialize(self.vision_encoder),
+                "use_bidirectional_attention": self.use_bidirectional_attention,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,
             }
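Since the flag is now written into `get_config()`, it should survive the usual Keras config round trip. A quick sketch of that assumption (`gemma3_instruct_270m` is a preset listed further down in this diff; loading it downloads weights from Kaggle):

```python
from keras_hub.models import Gemma3Backbone

backbone = Gemma3Backbone.from_preset("gemma3_instruct_270m")

# Generative checkpoints default to causal attention.
config = backbone.get_config()
print(config["use_bidirectional_attention"])  # False

# Rebuilding from the config preserves the flag.
restored = Gemma3Backbone.from_config(config)
assert (
    restored.use_bidirectional_attention
    == backbone.use_bidirectional_attention
)
```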
--- a/keras_hub/src/models/gemma3/gemma3_decoder_block.py
+++ b/keras_hub/src/models/gemma3/gemma3_decoder_block.py
@@ -45,6 +45,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -66,6 +67,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self.pre_attention_norm = RMSNormalization(
@@ -93,6 +95,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
             rope_wavelength=rope_wavelength,
             rope_scaling_factor=rope_scaling_factor,
             dropout=dropout,
+            use_bidirectional_attention=use_bidirectional_attention,
             dtype=self.dtype_policy,
             name="attention",
         )
@@ -209,6 +212,14 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         if cache is not None:
             input_length = ops.shape(cache)[2]
 
+        if self.use_bidirectional_attention:
+            # `output_length` and `input_length` will be the same in this case
+            # because we use bidirectional attention for models like
+            # `EmbeddingGemma` which aren't used for text generation.
+            mask_1 = decoder_mask
+            mask_2 = ops.transpose(mask_1, (0, 2, 1))
+            return mask_1 * mask_2
+
         causal_mask = compute_causal_mask(
             batch_size=batch_size,
             input_length=input_length,
@@ -304,6 +315,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
                 "dropout": self.dropout,
                 "rope_wavelength": self.rope_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "use_bidirectional_attention": self.use_bidirectional_attention,
             }
         )
         return config
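The early return above builds the bidirectional mask from the padding mask alone: `decoder_mask` repeats a per-token validity row along the query axis, and multiplying it by its transpose keeps position (t, s) only when both tokens are real. A NumPy sketch of the identity (shapes are illustrative, not from the diff):

```python
import numpy as np

# A batch of one sequence: 3 real tokens followed by 2 pad tokens.
# Upstream, decoder_mask has shape (batch, query_len, key_len) with
# the same padding row repeated for every query position.
padding = np.array([1, 1, 1, 0, 0], dtype=np.int32)
mask_1 = np.broadcast_to(padding, (1, 5, 5))

mask_2 = np.transpose(mask_1, (0, 2, 1))
bidirectional = mask_1 * mask_2

print(bidirectional[0])
# A 3x3 block of ones in the top-left corner: every real token can
# attend to every other real token; padded positions attend nowhere.
```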
--- a/keras_hub/src/models/gemma3/gemma3_presets.py
+++ b/keras_hub/src/models/gemma3/gemma3_presets.py
@@ -181,4 +181,43 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_270m/4",
     },
+    "medgemma_instruct_4b": {
+        "metadata": {
+            "description": (
+                "A 4 billion parameter model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "and image comprehension and is optimized for medical "
+                "applications that involve a text generation component."
+            ),
+            "params": 4300079472,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_4b/1",
+    },
+    "medgemma_instruct_27b": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "and image comprehension and is optimized for medical "
+                "applications that involve a text generation component."
+            ),
+            "params": 27432406640,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b/1",
+    },
+    "medgemma_instruct_27b_text": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter text-only model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "comprehension and is optimized for medical applications "
+                "that involve a text generation component."
+            ),
+            "params": 27009002240,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b_text/1",
+    },
 }
--- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py
+++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py
@@ -157,7 +157,10 @@ class HGNetV2Backbone(Backbone):
             if stage_name in self.out_features
         }
         super().__init__(
-            inputs=pixel_values, outputs=feature_maps_output, **kwargs
+            inputs=pixel_values,
+            outputs=feature_maps_output,
+            dtype=dtype,
+            **kwargs,
         )
 
         # === Config ===
--- a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py
+++ b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py
@@ -56,9 +56,10 @@ class HGNetV2Encoder(keras.layers.Layer):
         use_learnable_affine_block,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.stage_in_channels = stage_in_channels
         self.stage_mid_channels = stage_mid_channels
         self.stage_out_channels = stage_out_channels
@@ -90,7 +91,7 @@ class HGNetV2Encoder(keras.layers.Layer):
                 name=f"{self.name}_stage_{stage_idx}"
                 if self.name
                 else f"stage_{stage_idx}",
-                dtype=self.dtype,
+                dtype=dtype,
             )
             self.stages_list.append(stage_layer)
 
--- a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py
+++ b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py
@@ -17,8 +17,8 @@ class HGNetV2LearnableAffineBlock(keras.layers.Layer):
         **kwargs: Additional keyword arguments passed to the parent class.
     """
 
-    def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, scale_value=1.0, bias_value=0.0, dtype=None, **kwargs):
+        super().__init__(dtype=dtype, **kwargs)
         self.scale_value = scale_value
         self.bias_value = bias_value
 
@@ -87,9 +87,10 @@ class HGNetV2ConvLayer(keras.layers.Layer):
         use_learnable_affine_block=False,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
@@ -104,6 +105,7 @@ class HGNetV2ConvLayer(keras.layers.Layer):
             padding=((pad, pad), (pad, pad)),
             data_format=self.data_format,
             name=f"{self.name}_pad" if self.name else None,
+            dtype=self.dtype_policy,
         )
         self.convolution = keras.layers.Conv2D(
             filters=self.out_channels,
@@ -156,7 +158,8 @@ class HGNetV2ConvLayer(keras.layers.Layer):
             )
         else:
             self.lab = keras.layers.Identity(
-                name=f"{self.name}_identity_lab" if self.name else None
+                name=f"{self.name}_identity_lab" if self.name else None,
+                dtype=self.dtype_policy,
             )
 
     def build(self, input_shape):
@@ -230,9 +233,10 @@ class HGNetV2ConvLayerLight(keras.layers.Layer):
         use_learnable_affine_block=False,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
@@ -327,9 +331,10 @@ class HGNetV2Embeddings(keras.layers.Layer):
         use_learnable_affine_block,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.stem_channels = stem_channels
         self.hidden_act = hidden_act
         self.use_learnable_affine_block = use_learnable_affine_block
@@ -352,6 +357,7 @@ class HGNetV2Embeddings(keras.layers.Layer):
             padding=((0, 1), (0, 1)),
             data_format=self.data_format,
             name=f"{self.name}_padding1" if self.name else "padding1",
+            dtype=self.dtype_policy,
         )
         self.stem2a_layer = HGNetV2ConvLayer(
             in_channels=self.stem_channels[1],
@@ -370,6 +376,7 @@ class HGNetV2Embeddings(keras.layers.Layer):
             padding=((0, 1), (0, 1)),
             data_format=self.data_format,
             name=f"{self.name}_padding2" if self.name else "padding2",
+            dtype=self.dtype_policy,
         )
         self.stem2b_layer = HGNetV2ConvLayer(
             in_channels=self.stem_channels[1] // 2,
@@ -390,10 +397,12 @@ class HGNetV2Embeddings(keras.layers.Layer):
             padding="valid",
             data_format=self.data_format,
             name=f"{self.name}_pool" if self.name else "pool",
+            dtype=self.dtype_policy,
         )
         self.concatenate_layer = keras.layers.Concatenate(
             axis=self.channel_axis,
             name=f"{self.name}_concat" if self.name else "concat",
+            dtype=self.dtype_policy,
         )
         self.stem3_layer = HGNetV2ConvLayer(
             in_channels=self.stem_channels[1] * 2,
@@ -550,9 +559,10 @@ class HGNetV2BasicLayer(keras.layers.Layer):
         use_learnable_affine_block=False,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.in_channels_arg = in_channels
         self.middle_channels = middle_channels
         self.out_channels = out_channels
@@ -635,23 +645,27 @@ class HGNetV2BasicLayer(keras.layers.Layer):
                 self.drop_path_rate,
                 noise_shape=(None, 1, 1, 1),
                 name=f"{self.name}_drop_path" if self.name else "drop_path",
+                dtype=self.dtype_policy,
             )
         else:
             self.drop_path_layer = keras.layers.Identity(
                 name=f"{self.name}_identity_drop_path"
                 if self.name
-                else "identity_drop_path"
+                else "identity_drop_path",
+                dtype=self.dtype_policy,
             )
 
         self.concatenate_layer = keras.layers.Concatenate(
             axis=self.channel_axis,
             name=f"{self.name}_concat" if self.name else "concat",
+            dtype=self.dtype_policy,
         )
         if self.residual:
             self.add_layer = keras.layers.Add(
                 name=f"{self.name}_add_residual"
                 if self.name
-                else "add_residual"
+                else "add_residual",
+                dtype=self.dtype_policy,
             )
 
     def build(self, input_shape):
@@ -794,9 +808,10 @@ class HGNetV2Stage(keras.layers.Layer):
         drop_path: float = 0.0,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
        self.stage_in_channels = stage_in_channels
         self.stage_mid_channels = stage_mid_channels
         self.stage_out_channels = stage_out_channels
@@ -842,7 +857,8 @@ class HGNetV2Stage(keras.layers.Layer):
             self.downsample_layer = keras.layers.Identity(
                 name=f"{self.name}_identity_downsample"
                 if self.name
-                else "identity_downsample"
+                else "identity_downsample",
+                dtype=self.dtype_policy,
             )
 
         self.blocks_list = []
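Every hunk in this file follows one pattern: accept `dtype` in `__init__`, hand it to `super().__init__()`, and then build sublayers with `dtype=self.dtype_policy` so a mixed-precision policy set on the outer layer reaches every child. A minimal sketch of the pattern with a hypothetical composite layer (`MyBlock` is not from the diff):

```python
import keras

class MyBlock(keras.layers.Layer):
    def __init__(self, units, dtype=None, **kwargs):
        # Forward dtype so this layer adopts the requested policy...
        super().__init__(dtype=dtype, **kwargs)
        # ...then propagate the resolved policy to sublayers; otherwise
        # they silently fall back to the global default policy.
        self.dense = keras.layers.Dense(units, dtype=self.dtype_policy)

    def call(self, x):
        return self.dense(x)

block = MyBlock(8, dtype="mixed_float16")
print(block.dense.dtype_policy)  # mixed_float16: fp16 compute, fp32 vars
```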
--- a/keras_hub/src/models/image_to_image.py
+++ b/keras_hub/src/models/image_to_image.py
@@ -415,3 +415,8 @@ class ImageToImage(Task):
         # Image-to-image.
         outputs = [generate(*x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
+
+    def _post_quantize(self, mode, **kwargs):
+        super()._post_quantize(mode, **kwargs)
+        # Reset the compiled generate function.
+        self.generate_function = None
--- a/keras_hub/src/models/inpaint.py
+++ b/keras_hub/src/models/inpaint.py
@@ -518,3 +518,8 @@ class Inpaint(Task):
         # Inpaint.
         outputs = [generate(*x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
+
+    def _post_quantize(self, mode, **kwargs):
+        super()._post_quantize(mode, **kwargs)
+        # Reset the compiled generate function.
+        self.generate_function = None
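Both `ImageToImage` and `Inpaint` cache a compiled `generate()` under `self.generate_function`; quantization swaps out layer weights, so the hook drops the stale trace and the next `generate()` call re-traces against the quantized graph. A sketch of the calling pattern, assuming the installed Keras version exposes `Model.quantize` (the preset name is one that exists in KerasHub's Stable Diffusion 3 presets, not shown in this diff):

```python
from keras_hub.models import ImageToImage

task = ImageToImage.from_preset("stable_diffusion_3_medium")

# quantize() invokes _post_quantize(), which clears
# task.generate_function.
task.quantize("int8")

# The next generate() call recompiles against the quantized weights.
# output = task.generate(...)
```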
--- /dev/null
+++ b/keras_hub/src/models/mobilenetv5/__init__.py
@@ -0,0 +1,9 @@
+from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
+    MobileNetV5Backbone,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_presets import (
+    backbone_presets,
+)
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, MobileNetV5Backbone)
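This `__init__.py` is the standard keras-hub registration shim: importing the module ties the entries in `backbone_presets` to `MobileNetV5Backbone`, which is what lets `from_preset()` resolve them by name. Usage sketch (the preset name below is a placeholder; the real names live in `mobilenetv5_presets.py`, which this diff does not show):

```python
import keras_hub

# Registration happens as a side effect of importing keras_hub, so
# string lookup works. "mobilenetv5_base" is a hypothetical name.
backbone = keras_hub.models.MobileNetV5Backbone.from_preset(
    "mobilenetv5_base"
)
backbone.summary()
```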