keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +15 -0
- keras_hub/models/__init__.py +93 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +28 -16
- keras_hub/src/models/causal_lm.py +37 -0
- keras_hub/src/models/causal_lm_preprocessor.py +14 -0
- keras_hub/src/models/clip/clip_presets.py +8 -8
- keras_hub/src/models/d_fine/__init__.py +5 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_backbone.py +0 -1
- keras_hub/src/models/gemma/gemma_presets.py +30 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/samplers/beam_sampler.py +6 -6
- keras_hub/src/samplers/sampler.py +8 -6
- keras_hub/src/tests/test_case.py +40 -3
- keras_hub/src/tokenizers/tokenizer.py +15 -0
- keras_hub/src/utils/openvino_utils.py +141 -0
- keras_hub/src/utils/preset_utils.py +58 -2
- keras_hub/src/utils/tensor_utils.py +26 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/utils/transformers/export/gemma.py +49 -4
- keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +15 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gemma/gemma_presets.py
CHANGED

@@ -198,4 +198,34 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/2",
     },
+    "vault_gemma_1b_en": {
+        "metadata": {
+            "description": "1 billion parameter, 26-layer, VaultGemma model.",
+            "params": 1038741120,
+            "path": "gemma",
+        },
+        "kaggle_handle": "kaggle://keras/vaultgemma/keras/vault_gemma_1b_en/2",
+    },
+    "c2s_scale_gemma_2_2b_en": {
+        "metadata": {
+            "description": (
+                "A 2 billion parameter, single-cell biology-aware model "
+                "built on the Gemma-2 architecture."
+            ),
+            "params": 2614341888,
+            "path": "gemma",
+        },
+        "kaggle_handle": "kaggle://keras/cell2sentence/keras/c2s_scale_gemma_2_2b_en/1",
+    },
+    "c2s_scale_gemma_2_27b_en": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter, single-cell biology-aware model "
+                "built on the Gemma-2 architecture."
+            ),
+            "params": 27227128320,
+            "path": "gemma",
+        },
+        "kaggle_handle": "kaggle://keras/cell2sentence/keras/c2s_scale_gemma_2_27b_en/1",
+    },
 }
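Once registered, these presets resolve by name through the standard `from_preset` flow. A minimal sketch of what the new entries enable (loading through the generic `Backbone` class is an illustrative choice; a matching task class works the same way):

import keras_hub

# Load the newly added VaultGemma preset by its registered name.
backbone = keras_hub.models.Backbone.from_preset("vault_gemma_1b_en")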
keras_hub/src/models/gemma3/gemma3_attention.py
CHANGED

@@ -46,6 +46,7 @@ class CachedGemma3Attention(keras.layers.Layer):
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -61,6 +62,7 @@ class CachedGemma3Attention(keras.layers.Layer):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self._kernel_initializer = keras.initializers.get(
@@ -240,12 +242,58 @@ class CachedGemma3Attention(keras.layers.Layer):
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _compute_bidirectional_sliding_mask(self, batch_size, sequence_length):
+        """Computes a bidirectional sliding window attention mask.
+
+        A token can attend to any other token if their absolute distance is
+        within half the sliding window size. This mask is used in embedding
+        models like `EmbeddingGemma`.
+
+        Args:
+            batch_size: The batch size for the mask.
+            sequence_length: The length of the sequence.
+
+        Returns:
+            A boolean attention mask with shape
+            `(batch_size, sequence_length, sequence_length)`.
+        """
+        i = keras.ops.expand_dims(
+            keras.ops.arange(sequence_length, dtype="int32"), axis=1
+        )
+        j = keras.ops.arange(sequence_length, dtype="int32")
+
+        # If sliding window size is 4, the token in question attends to 1
+        # token before and 2 tokens after.
+        w_right = self.sliding_window_size // 2
+        w_left = self.sliding_window_size - w_right - 1
+
+        # Calculate the relative distance.
+        distance = i - j
+
+        mask = keras.ops.logical_and(distance <= w_left, distance >= -w_right)
+
+        mask = keras.ops.expand_dims(mask, axis=0)
+        return keras.ops.broadcast_to(
+            mask, (batch_size, sequence_length, sequence_length)
+        )
+
     def _mask_sliding_window(
         self,
         attention_mask,
         cache_update_index=0,
     ):
         batch_size, query_len, key_len = ops.shape(attention_mask)
+
+        if self.use_bidirectional_attention:
+            bidirectional_sliding_mask = (
+                self._compute_bidirectional_sliding_mask(
+                    batch_size=batch_size,
+                    # `query_len = key_len` for embedding models
+                    sequence_length=query_len,
+                )
+            )
+            return ops.logical_and(attention_mask, bidirectional_sliding_mask)
+
         # Compute the sliding window for square attention.
         all_ones = ops.ones((key_len, key_len), "bool")
         if keras.config.backend() == "tensorflow":
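The window-split logic above is easy to verify by hand. A standalone NumPy sketch of the same mask rule (re-implemented outside the layer) for `sliding_window_size=4` and a sequence of 5 tokens:

import numpy as np

sliding_window_size = 4
seq_len = 5
i = np.arange(seq_len)[:, None]
j = np.arange(seq_len)[None, :]
w_right = sliding_window_size // 2          # 2 tokens after
w_left = sliding_window_size - w_right - 1  # 1 token before
distance = i - j
# Token i attends to j when i - j is in [-w_right, w_left].
mask = (distance <= w_left) & (distance >= -w_right)
print(mask.astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [0 1 1 1 1]
#  [0 0 1 1 1]
#  [0 0 0 1 1]]

Each row has at most `sliding_window_size` ones: the token itself, one token before, and two after, matching the comment in the diff.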
keras_hub/src/models/gemma3/gemma3_backbone.py
CHANGED

@@ -196,6 +196,7 @@ class Gemma3Backbone(Backbone):
         global_rope_scaling_factor=1.0,
         vision_encoder=None,
         layer_norm_epsilon=1e-6,
+        use_bidirectional_attention=False,
         dropout=0,
         dtype=None,
         **kwargs,
@@ -209,7 +210,6 @@ class Gemma3Backbone(Backbone):
                 scale=1.0,
                 mode="fan_in",
                 distribution="untruncated_normal",
-                seed=None,
             ),
             dtype=dtype,
             logit_soft_cap=final_logit_soft_cap,
@@ -251,6 +251,7 @@ class Gemma3Backbone(Backbone):
                 sliding_window_size=sliding_window_size,
                 rope_wavelength=rope_wavelength,
                 rope_scaling_factor=rope_scaling_factor,
+                use_bidirectional_attention=use_bidirectional_attention,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -357,6 +358,7 @@ class Gemma3Backbone(Backbone):
         self.sliding_window_size = sliding_window_size
         self.local_rope_scaling_factor = local_rope_scaling_factor
         self.global_rope_scaling_factor = global_rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
 
@@ -396,6 +398,7 @@ class Gemma3Backbone(Backbone):
                 "vision_encoder": None
                 if self.vision_encoder is None
                 else keras.layers.serialize(self.vision_encoder),
+                "use_bidirectional_attention": self.use_bidirectional_attention,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,
             }
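Adding the key to `get_config()` is what lets a saved model round-trip the new flag. A hedged sketch of the effect (the preset name is taken from the handles later in this diff; loading it downloads weights):

import keras_hub

backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_270m")
config = backbone.get_config()
# The new key survives serialization, so `from_config` restores the flag.
restored = keras_hub.models.Gemma3Backbone.from_config(config)
assert (
    restored.use_bidirectional_attention
    == backbone.use_bidirectional_attention
)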
keras_hub/src/models/gemma3/gemma3_decoder_block.py
CHANGED

@@ -45,6 +45,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -66,6 +67,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self.pre_attention_norm = RMSNormalization(
@@ -93,6 +95,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
             rope_wavelength=rope_wavelength,
             rope_scaling_factor=rope_scaling_factor,
             dropout=dropout,
+            use_bidirectional_attention=use_bidirectional_attention,
             dtype=self.dtype_policy,
             name="attention",
         )
@@ -209,6 +212,14 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         if cache is not None:
             input_length = ops.shape(cache)[2]
 
+        if self.use_bidirectional_attention:
+            # `output_length` and `input_length` will be the same in this case
+            # because we use bidirectional attention for models like
+            # `EmbeddingGemma` which aren't used for text generation.
+            mask_1 = decoder_mask
+            mask_2 = ops.transpose(mask_1, (0, 2, 1))
+            return mask_1 * mask_2
+
         causal_mask = compute_causal_mask(
             batch_size=batch_size,
             input_length=input_length,
@@ -304,6 +315,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
                 "dropout": self.dropout,
                 "rope_wavelength": self.rope_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "use_bidirectional_attention": self.use_bidirectional_attention,
             }
         )
         return config
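The `mask_1 * mask_2` trick above turns a padding mask into a bidirectional attention mask. A standalone NumPy sketch, assuming `decoder_mask` is the padding mask broadcast over query positions (which is how the surrounding code builds it):

import numpy as np

padding_mask = np.array([[1, 1, 1, 0]])  # last token is padding
# Broadcast over query positions: shape (batch, query_len, key_len).
mask_1 = np.repeat(padding_mask[:, None, :], 4, axis=1)
mask_2 = np.transpose(mask_1, (0, 2, 1))
# Multiplying by the transpose zeroes out both the rows and the columns
# of padded tokens, so every valid token attends to every valid token.
print((mask_1 * mask_2)[0])
# [[1 1 1 0]
#  [1 1 1 0]
#  [1 1 1 0]
#  [0 0 0 0]]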
keras_hub/src/models/gemma3/gemma3_presets.py
CHANGED

@@ -181,4 +181,43 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_270m/4",
     },
+    "medgemma_instruct_4b": {
+        "metadata": {
+            "description": (
+                "A 4 billion parameter model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "and image comprehension and is optimized for medical "
+                "applications that involve a text generation component."
+            ),
+            "params": 4300079472,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_4b/1",
+    },
+    "medgemma_instruct_27b": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "and image comprehension and is optimized for medical "
+                "applications that involve a text generation component."
+            ),
+            "params": 27432406640,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b/1",
+    },
+    "medgemma_instruct_27b_text": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter text-only model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "comprehension and is optimized for medical applications "
+                "that involve a text generation component."
+            ),
+            "params": 27009002240,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b_text/1",
+    },
 }
keras_hub/src/models/hgnetv2/hgnetv2_backbone.py
CHANGED

@@ -157,7 +157,10 @@ class HGNetV2Backbone(Backbone):
             if stage_name in self.out_features
         }
         super().__init__(
-            inputs=pixel_values,
+            inputs=pixel_values,
+            outputs=feature_maps_output,
+            dtype=dtype,
+            **kwargs,
         )
 
         # === Config ===
keras_hub/src/models/hgnetv2/hgnetv2_encoder.py
CHANGED

@@ -56,9 +56,10 @@ class HGNetV2Encoder(keras.layers.Layer):
         use_learnable_affine_block,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.stage_in_channels = stage_in_channels
         self.stage_mid_channels = stage_mid_channels
         self.stage_out_channels = stage_out_channels
@@ -90,7 +91,7 @@ class HGNetV2Encoder(keras.layers.Layer):
                 name=f"{self.name}_stage_{stage_idx}"
                 if self.name
                 else f"stage_{stage_idx}",
-                dtype=
+                dtype=dtype,
             )
             self.stages_list.append(stage_layer)
 
keras_hub/src/models/hgnetv2/hgnetv2_layers.py
CHANGED

@@ -17,8 +17,8 @@ class HGNetV2LearnableAffineBlock(keras.layers.Layer):
         **kwargs: Additional keyword arguments passed to the parent class.
     """
 
-    def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, scale_value=1.0, bias_value=0.0, dtype=None, **kwargs):
+        super().__init__(dtype=dtype, **kwargs)
         self.scale_value = scale_value
         self.bias_value = bias_value
 
@@ -87,9 +87,10 @@ class HGNetV2ConvLayer(keras.layers.Layer):
         use_learnable_affine_block=False,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
@@ -104,6 +105,7 @@ class HGNetV2ConvLayer(keras.layers.Layer):
             padding=((pad, pad), (pad, pad)),
             data_format=self.data_format,
             name=f"{self.name}_pad" if self.name else None,
+            dtype=self.dtype_policy,
         )
         self.convolution = keras.layers.Conv2D(
             filters=self.out_channels,
@@ -156,7 +158,8 @@ class HGNetV2ConvLayer(keras.layers.Layer):
             )
         else:
             self.lab = keras.layers.Identity(
-                name=f"{self.name}_identity_lab" if self.name else None
+                name=f"{self.name}_identity_lab" if self.name else None,
+                dtype=self.dtype_policy,
             )
 
     def build(self, input_shape):
@@ -230,9 +233,10 @@ class HGNetV2ConvLayerLight(keras.layers.Layer):
         use_learnable_affine_block=False,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
        self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
@@ -327,9 +331,10 @@ class HGNetV2Embeddings(keras.layers.Layer):
         use_learnable_affine_block,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.stem_channels = stem_channels
         self.hidden_act = hidden_act
         self.use_learnable_affine_block = use_learnable_affine_block
@@ -352,6 +357,7 @@ class HGNetV2Embeddings(keras.layers.Layer):
             padding=((0, 1), (0, 1)),
             data_format=self.data_format,
             name=f"{self.name}_padding1" if self.name else "padding1",
+            dtype=self.dtype_policy,
         )
         self.stem2a_layer = HGNetV2ConvLayer(
             in_channels=self.stem_channels[1],
@@ -370,6 +376,7 @@ class HGNetV2Embeddings(keras.layers.Layer):
             padding=((0, 1), (0, 1)),
             data_format=self.data_format,
             name=f"{self.name}_padding2" if self.name else "padding2",
+            dtype=self.dtype_policy,
         )
         self.stem2b_layer = HGNetV2ConvLayer(
             in_channels=self.stem_channels[1] // 2,
@@ -390,10 +397,12 @@ class HGNetV2Embeddings(keras.layers.Layer):
             padding="valid",
             data_format=self.data_format,
             name=f"{self.name}_pool" if self.name else "pool",
+            dtype=self.dtype_policy,
         )
         self.concatenate_layer = keras.layers.Concatenate(
             axis=self.channel_axis,
             name=f"{self.name}_concat" if self.name else "concat",
+            dtype=self.dtype_policy,
         )
         self.stem3_layer = HGNetV2ConvLayer(
             in_channels=self.stem_channels[1] * 2,
@@ -550,9 +559,10 @@ class HGNetV2BasicLayer(keras.layers.Layer):
         use_learnable_affine_block=False,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.in_channels_arg = in_channels
         self.middle_channels = middle_channels
         self.out_channels = out_channels
@@ -635,23 +645,27 @@ class HGNetV2BasicLayer(keras.layers.Layer):
                 self.drop_path_rate,
                 noise_shape=(None, 1, 1, 1),
                 name=f"{self.name}_drop_path" if self.name else "drop_path",
+                dtype=self.dtype_policy,
             )
         else:
             self.drop_path_layer = keras.layers.Identity(
                 name=f"{self.name}_identity_drop_path"
                 if self.name
-                else "identity_drop_path"
+                else "identity_drop_path",
+                dtype=self.dtype_policy,
             )
 
         self.concatenate_layer = keras.layers.Concatenate(
             axis=self.channel_axis,
             name=f"{self.name}_concat" if self.name else "concat",
+            dtype=self.dtype_policy,
         )
         if self.residual:
             self.add_layer = keras.layers.Add(
                 name=f"{self.name}_add_residual"
                 if self.name
-                else "add_residual"
+                else "add_residual",
+                dtype=self.dtype_policy,
             )
 
     def build(self, input_shape):
@@ -794,9 +808,10 @@ class HGNetV2Stage(keras.layers.Layer):
         drop_path: float = 0.0,
         data_format=None,
         channel_axis=None,
+        dtype=None,
         **kwargs,
    ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.stage_in_channels = stage_in_channels
         self.stage_mid_channels = stage_mid_channels
         self.stage_out_channels = stage_out_channels
@@ -842,7 +857,8 @@ class HGNetV2Stage(keras.layers.Layer):
         self.downsample_layer = keras.layers.Identity(
             name=f"{self.name}_identity_downsample"
             if self.name
-            else "identity_downsample"
+            else "identity_downsample",
+            dtype=self.dtype_policy,
         )
 
         self.blocks_list = []
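The common thread in these HGNetV2 hunks is dtype propagation: in Keras 3, a sublayer created in `__init__` falls back to the global dtype policy unless the parent passes its own policy down explicitly. A minimal sketch of the pattern, not the HGNetV2 code itself:

import keras

class Block(keras.layers.Layer):
    def __init__(self, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Without `dtype=self.dtype_policy`, this child would pick up the
        # global policy instead of the parent's.
        self.concat = keras.layers.Concatenate(dtype=self.dtype_policy)

block = Block(dtype="bfloat16")
print(block.concat.dtype_policy.name)  # bfloat16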
keras_hub/src/models/image_to_image.py
CHANGED

@@ -415,3 +415,8 @@ class ImageToImage(Task):
         # Image-to-image.
         outputs = [generate(*x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
+
+    def _post_quantize(self, mode, **kwargs):
+        super()._post_quantize(mode, **kwargs)
+        # Reset the compiled generate function.
+        self.generate_function = None
keras_hub/src/models/inpaint.py
CHANGED

@@ -518,3 +518,8 @@ class Inpaint(Task):
         # Inpaint.
         outputs = [generate(*x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
+
+    def _post_quantize(self, mode, **kwargs):
+        super()._post_quantize(mode, **kwargs)
+        # Reset the compiled generate function.
+        self.generate_function = None
keras_hub/src/models/mobilenetv5/__init__.py
ADDED

@@ -0,0 +1,9 @@
+from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
+    MobileNetV5Backbone,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_presets import (
+    backbone_presets,
+)
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, MobileNetV5Backbone)
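This mirrors the registration pattern every model package in the library uses: importing the package binds its presets to the backbone class, after which they resolve by name. A sketch of the result (the preset name below is hypothetical; the real names live in `mobilenetv5_presets.py`, and this assumes the backbone is exported under `keras_hub.models`):

import keras_hub

backbone = keras_hub.models.MobileNetV5Backbone.from_preset(
    "mobilenetv5_300m"  # hypothetical preset name
)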