keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl

This diff compares the contents of two publicly released package versions as published to a supported public registry. It is provided for informational purposes only.

Note: this version of keras-hub-nightly has been flagged as a potentially problematic release.
Files changed (126)
  1. keras_hub/layers/__init__.py +15 -0
  2. keras_hub/models/__init__.py +93 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
  5. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  6. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  7. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  8. keras_hub/src/models/backbone.py +28 -16
  9. keras_hub/src/models/causal_lm.py +37 -0
  10. keras_hub/src/models/causal_lm_preprocessor.py +14 -0
  11. keras_hub/src/models/clip/clip_presets.py +8 -8
  12. keras_hub/src/models/d_fine/__init__.py +5 -0
  13. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  14. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  15. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  16. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  17. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  18. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  19. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  20. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  21. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  22. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  23. keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
  24. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  25. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
  26. keras_hub/src/models/depth_anything/__init__.py +9 -0
  27. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  28. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  29. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  30. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  31. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  32. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  33. keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
  34. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  35. keras_hub/src/models/depth_estimator.py +239 -0
  36. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  37. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  38. keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
  39. keras_hub/src/models/dinov3/__init__.py +5 -0
  40. keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
  41. keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
  42. keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
  43. keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
  44. keras_hub/src/models/gemma/gemma_backbone.py +0 -1
  45. keras_hub/src/models/gemma/gemma_presets.py +30 -0
  46. keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
  47. keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
  48. keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
  49. keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
  50. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  51. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  52. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  53. keras_hub/src/models/image_to_image.py +5 -0
  54. keras_hub/src/models/inpaint.py +5 -0
  55. keras_hub/src/models/mobilenetv5/__init__.py +9 -0
  56. keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
  57. keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
  58. keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
  59. keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
  60. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
  61. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
  62. keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
  63. keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
  64. keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
  65. keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
  66. keras_hub/src/models/parseq/__init__.py +5 -0
  67. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  68. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  69. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  70. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  71. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  72. keras_hub/src/models/parseq/parseq_presets.py +15 -0
  73. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  74. keras_hub/src/models/qwen3_moe/__init__.py +5 -0
  75. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  76. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  77. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  78. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  79. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  80. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  81. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
  82. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  83. keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
  84. keras_hub/src/models/siglip/siglip_presets.py +15 -0
  85. keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
  86. keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
  87. keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
  88. keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
  89. keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
  90. keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  92. keras_hub/src/models/t5gemma/__init__.py +5 -0
  93. keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
  94. keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
  95. keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
  96. keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
  97. keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
  98. keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
  99. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
  100. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
  101. keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
  102. keras_hub/src/models/text_to_image.py +5 -0
  103. keras_hub/src/samplers/beam_sampler.py +6 -6
  104. keras_hub/src/samplers/sampler.py +8 -6
  105. keras_hub/src/tests/test_case.py +40 -3
  106. keras_hub/src/tokenizers/tokenizer.py +15 -0
  107. keras_hub/src/utils/openvino_utils.py +141 -0
  108. keras_hub/src/utils/preset_utils.py +58 -2
  109. keras_hub/src/utils/tensor_utils.py +26 -2
  110. keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
  111. keras_hub/src/utils/timm/preset_loader.py +8 -4
  112. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  113. keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
  114. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  115. keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
  116. keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
  117. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  118. keras_hub/src/utils/transformers/export/gemma.py +49 -4
  119. keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
  120. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  121. keras_hub/src/version.py +1 -1
  122. keras_hub/tokenizers/__init__.py +15 -0
  123. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
  124. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
  125. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
  126. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0

keras_hub/src/models/t5gemma/t5gemma_tokenizer.py
@@ -0,0 +1,84 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
+    SentencePieceTokenizer,
+)
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.T5GemmaTokenizer",
+        "keras_hub.models.T5GemmaTokenizer",
+    ]
+)
+class T5GemmaTokenizer(SentencePieceTokenizer):
+    """T5Gemma tokenizer layer based on SentencePiece.
+
+    This tokenizer class will tokenize raw strings into integer sequences and
+    is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the
+    underlying tokenizer, it will check for all special tokens needed by
+    T5Gemma models and provides a `from_preset()` method to automatically
+    download a matching vocabulary for a T5Gemma preset.
+
+    If input is a batch of strings (rank > 0), the layer will output a
+    `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+    If input is a scalar string (rank == 0), the layer will output a dense
+    `tf.Tensor` with static shape `[None]`.
+
+    Args:
+        proto: Either a `string` path to a SentencePiece proto file, or a
+            `bytes` object with a serialized SentencePiece proto. See the
+            [SentencePiece repository](https://github.com/google/sentencepiece)
+            for more details on the format.
+
+    Examples:
+
+    ```python
+    import io
+    import tensorflow as tf
+    import sentencepiece
+
+    # Unbatched input.
+    tokenizer = keras_hub.models.T5GemmaTokenizer.from_preset(
+        "t5gemma_b_b_prefixlm_it"
+    )
+    tokenizer("The quick brown fox jumped.")
+
+    # Batched input.
+    tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+    # Detokenization.
+    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+
+    # Custom vocabulary.
+    bytes_io = io.BytesIO()
+    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+    sentencepiece.SentencePieceTrainer.train(
+        sentence_iterator=ds.as_numpy_iterator(),
+        model_writer=bytes_io,
+        vocab_size=8,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+    tokenizer = keras_hub.models.T5GemmaTokenizer(
+        proto=bytes_io.getvalue(),
+    )
+    tokenizer("The quick brown fox jumped.")
+    ```
+    """
+
+    backbone_cls = T5GemmaBackbone
+
+    def __init__(self, proto, **kwargs):
+        self._add_special_token("<bos>", "start_token")
+        self._add_special_token("<eos>", "end_token")
+        self._add_special_token("<pad>", "pad_token")
+        super().__init__(proto=proto, **kwargs)

keras_hub/src/models/text_to_image.py
@@ -345,3 +345,8 @@ class TextToImage(Task):
         # Text-to-image.
         outputs = [generate(x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
+
+    def _post_quantize(self, mode, **kwargs):
+        super()._post_quantize(mode, **kwargs)
+        # Reset the compiled generate function.
+        self.generate_function = None
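
The `_post_quantize` hook above matters because `TextToImage` caches a compiled `generate_function`; without the reset, a model quantized after loading would keep generating through a function traced against the full-precision weights. A minimal usage sketch, assuming a Stable Diffusion 3 preset (the preset name is illustrative, not taken from this diff):

```python
import keras_hub

# Preset name assumed for illustration; any TextToImage preset follows
# the same flow.
text_to_image = keras_hub.models.TextToImage.from_preset(
    "stable_diffusion_3_medium"
)
# `quantize()` triggers `_post_quantize`, which clears the cached
# `generate_function` so the next call re-traces against the quantized
# weights instead of reusing the full-precision trace.
text_to_image.quantize("int8")
image = text_to_image.generate("a photograph of an astronaut")
```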

keras_hub/src/samplers/beam_sampler.py
@@ -95,15 +95,15 @@ class BeamSampler(Sampler):
         )
         log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0))
 
-        def cond(prompt, cache, index, log_probs):
+        def cond(prompt, cache, index, mask, log_probs):
             if stop_token_ids is None:
-                return True
+                return ops.convert_to_tensor(True, dtype="bool")
             # Stop if all sequences have produced a *new* stop token.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))
 
-        def body(prompt, cache, index, log_probs):
+        def body(prompt, cache, index, mask, log_probs):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             vocab_size = ops.shape(logits)[-1]
@@ -150,12 +150,12 @@ class BeamSampler(Sampler):
             next_token = next_token[:, None]
             prompt = ops.slice_update(prompt, [0, index], next_token)
             # Return the iteration of the loop state.
-            return (prompt, cache, index + 1, log_probs)
+            return (prompt, cache, index + 1, mask, log_probs)
 
-        prompt, _, _, log_probs = self.run_loop(
+        prompt, _, _, _, log_probs = self.run_loop(
             cond=cond,
             body=body,
-            loop_vars=(prompt, cache, index, log_probs),
+            loop_vars=(prompt, cache, index, mask, log_probs),
             maximum_iterations=(max_length - index),
             model=model,
         )

keras_hub/src/samplers/sampler.py
@@ -92,16 +92,18 @@ class Sampler:
         # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
         cache = () if cache is None else cache
 
-        def cond(prompt, cache, index):
+        # OpenVINO requires all parameters to be passed in the body.
+        # So we pass `mask` as well.
+        def cond(prompt, cache, index, mask):
             if stop_token_ids is None:
-                return True
+                return ops.convert_to_tensor(True, dtype="bool")
             # Stop if all sequences have produced a *new* id from
             # stop_token_ids.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))
 
-        def body(prompt, cache, index):
+        def body(prompt, cache, index, mask):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             probabilities = self.compute_probabilities(logits)
@@ -115,12 +117,12 @@ class Sampler:
             prompt = ops.slice_update(prompt, [0, index], next_token)
 
             # Return the next prompt, cache and incremented index.
-            return (prompt, cache, index + 1)
+            return (prompt, cache, index + 1, mask)
 
-        prompt, _, _ = self.run_loop(
+        prompt, _, _, _ = self.run_loop(
             cond,
             body,
-            loop_vars=(prompt, cache, index),
+            loop_vars=(prompt, cache, index, mask),
             maximum_iterations=(max_length - index),
             model=model,
         )
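
Both sampler changes follow the same pattern: `mask` is threaded through `loop_vars` instead of being closed over, and `cond` returns a boolean tensor rather than a Python `True`, because a graph-compiling backend such as OpenVINO needs every tensor read by `cond`/`body` to be a formal loop parameter and the loop condition to stay inside the graph. A standalone sketch of the pattern with `keras.ops.while_loop` (simplified, not the actual sampler code):

```python
from keras import ops

# Simplified sketch of the loop-variable pattern the diff adopts: any
# tensor read inside `cond` or `body` (here `mask`) is threaded through
# `loop_vars` rather than captured from the enclosing Python scope, and
# the condition returns a boolean tensor instead of a Python `True`.
prompt = ops.zeros((2, 8), dtype="int32")
index = ops.convert_to_tensor(1)
mask = ops.convert_to_tensor([[True] + [False] * 7] * 2)

def cond(prompt, index, mask):
    return ops.convert_to_tensor(index < 8, dtype="bool")

def body(prompt, index, mask):
    next_token = ops.ones((2, 1), dtype="int32")
    prompt = ops.slice_update(prompt, [0, index], next_token)
    # `mask` is returned unchanged; it is a loop variable only so graph
    # backends see it as part of the loop signature.
    return (prompt, index + 1, mask)

prompt, index, mask = ops.while_loop(
    cond, body, (prompt, index, mask), maximum_iterations=8
)
print(ops.convert_to_numpy(prompt))
```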

keras_hub/src/tests/test_case.py
@@ -499,6 +499,7 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
         init_kwargs,
         input_data,
         expected_output_shape,
+        spatial_output_keys=None,
         expected_pyramid_output_keys=None,
         expected_pyramid_image_sizes=None,
         variable_length_data=None,
@@ -537,10 +538,11 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
 
             self.assertIsInstance(output_data, dict)
             self.assertEqual(
-                list(output_data.keys()), list(backbone.pyramid_outputs.keys())
+                sorted(output_data.keys()),
+                sorted(backbone.pyramid_outputs.keys()),
             )
             self.assertEqual(
-                list(output_data.keys()), expected_pyramid_output_keys
+                sorted(output_data.keys()), sorted(expected_pyramid_output_keys)
             )
             # check height and width of each level.
             for i, (k, v) in enumerate(output_data.items()):
@@ -557,12 +559,47 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
                 input_data = ops.transpose(input_data, axes=(2, 0, 1))
             elif len(input_data_shape) == 4:
                 input_data = ops.transpose(input_data, axes=(0, 3, 1, 2))
-            if len(expected_output_shape) == 3:
+            if isinstance(expected_output_shape, dict):
+                # Handle dictionary of shapes.
+                transposed_shapes = {}
+                for key, shape in expected_output_shape.items():
+                    if spatial_output_keys and key not in spatial_output_keys:
+                        transposed_shapes[key] = shape
+                        continue
+                    if len(shape) == 3:
+                        transposed_shapes[key] = (shape[0], shape[2], shape[1])
+                    elif len(shape) == 4:
+                        transposed_shapes[key] = (
+                            shape[0],
+                            shape[3],
+                            shape[1],
+                            shape[2],
+                        )
+                    else:
+                        transposed_shapes[key] = shape
+                expected_output_shape = transposed_shapes
+            elif len(expected_output_shape) == 3:
                 x = expected_output_shape
                 expected_output_shape = (x[0], x[2], x[1])
             elif len(expected_output_shape) == 4:
                 x = expected_output_shape
                 expected_output_shape = (x[0], x[3], x[1], x[2])
+            original_init_kwargs = init_kwargs.copy()
+            init_kwargs = original_init_kwargs.copy()
+            # Handle nested `keras.Model` instances passed within `init_kwargs`.
+            for k, v in init_kwargs.items():
+                if isinstance(v, keras.Model) and hasattr(v, "data_format"):
+                    config = v.get_config()
+                    config["data_format"] = "channels_first"
+                    if (
+                        "image_shape" in config
+                        and config["image_shape"] is not None
+                        and len(config["image_shape"]) == 3
+                    ):
+                        config["image_shape"] = tuple(
+                            reversed(config["image_shape"])
+                        )
+                    init_kwargs[k] = v.__class__.from_config(config)
             if "image_shape" in init_kwargs:
                 init_kwargs = init_kwargs.copy()
                 init_kwargs["image_shape"] = tuple(
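
`run_vision_backbone_test` now accepts a dict for `expected_output_shape` (for backbones with several named outputs) and a `spatial_output_keys` allowlist controlling which entries get transposed for the channels-first check. A standalone sketch of the transposition rule, simplified from the diff (the key names are made up):

```python
# Standalone sketch of the shape-transposition rule the test helper
# applies when switching expected shapes from channels_last to
# channels_first. Keys not named in `spatial_output_keys` are treated as
# non-spatial and pass through unchanged.
def transpose_expected_shapes(expected, spatial_output_keys=None):
    transposed = {}
    for key, shape in expected.items():
        if spatial_output_keys and key not in spatial_output_keys:
            transposed[key] = shape
        elif len(shape) == 3:  # (batch, length, channels)
            transposed[key] = (shape[0], shape[2], shape[1])
        elif len(shape) == 4:  # (batch, height, width, channels)
            transposed[key] = (shape[0], shape[3], shape[1], shape[2])
        else:
            transposed[key] = shape
    return transposed

expected = {
    "pixels": (2, 224, 224, 3),
    "depth": (2, 224, 224),  # hypothetical non-spatial entry, skipped
}
print(transpose_expected_shapes(expected, spatial_output_keys=["pixels"]))
# {'pixels': (2, 3, 224, 224), 'depth': (2, 224, 224)}
```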

keras_hub/src/tokenizers/tokenizer.py
@@ -261,3 +261,18 @@ class Tokenizer(PreprocessingLayer):
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
         return loader.load_tokenizer(cls, config_file, **kwargs)
+
+    def export_to_transformers(self, path):
+        """Export the tokenizer to HuggingFace Transformers format.
+
+        This saves tokenizer assets in a format compatible with HuggingFace
+        Transformers.
+
+        Args:
+            path: str. Path to save the exported tokenizer.
+        """
+        from keras_hub.src.utils.transformers.export.hf_exporter import (
+            export_tokenizer,
+        )
+
+        export_tokenizer(self, path)
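
With this method on the base `Tokenizer`, any KerasHub tokenizer whose architecture the exporter supports can be saved for use with HuggingFace Transformers. A hedged usage sketch (the Gemma preset name is illustrative; exporter coverage varies by architecture):

```python
import keras_hub

# Preset name assumed for illustration.
tokenizer = keras_hub.models.GemmaTokenizer.from_preset("gemma_2b_en")
tokenizer.export_to_transformers("./gemma_tokenizer_hf")

# The exported assets can then be loaded on the HuggingFace side:
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("./gemma_tokenizer_hf")
```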

keras_hub/src/utils/openvino_utils.py
@@ -0,0 +1,141 @@
+from keras import tree
+
+from keras_hub.src.utils.keras_utils import print_msg
+
+try:
+    import openvino as ov
+    import openvino.opset14 as ov_opset
+    from openvino import Core
+except ImportError:
+    ov = None
+    ov_opset = None
+    Core = None
+
+
+_core = None
+
+
+def get_core():
+    """Get or create OpenVINO Core instance.
+
+    Returns:
+        openvino.Core: OpenVINO Core instance,
+        or None if OpenVINO not available.
+    """
+    global _core
+    if _core is None and Core is not None:
+        _core = Core()
+    return _core
+
+
+def get_device():
+    """Detect and return the best available OpenVINO device.
+
+    Returns:
+        str: "GPU" if available, otherwise "CPU".
+    """
+    core = get_core()
+    if core is None:
+        return "CPU"
+    return "GPU" if "GPU" in core.available_devices else "CPU"
+
+
+def compile_model(struct_params, struct_outputs, device, model_dtype):
+    """Compile OpenVINO model with dynamic shapes and precision hints.
+
+    Args:
+        struct_params: Model parameters structure.
+        struct_outputs: Model outputs structure.
+        device: Target device ("GPU" or "CPU").
+        model_dtype: Model precision ("f16" or "f32").
+
+    Returns:
+        Compiled OpenVINO model ready for inference.
+    """
+    flat_params = tree.flatten(struct_params)
+    flat_outputs = tree.flatten(struct_outputs)
+    parameters = [p.output.get_node() for p in flat_params]
+    results = [ov_opset.result(r.output) for r in flat_outputs]
+    ov_model = ov.Model(results=results, parameters=parameters)
+    for ov_input in ov_model.inputs:
+        rank = ov_input.get_partial_shape().rank.get_length()
+        ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
+    ov_model.validate_nodes_and_infer_types()
+    config = {"INFERENCE_PRECISION_HINT": model_dtype}
+    core = get_core()
+    if core is None:
+        raise RuntimeError("OpenVINO not available")
+    return core.compile_model(ov_model, device, config)
+
+
+def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
+    """Execute compiled OpenVINO model and return structured outputs.
+
+    Args:
+        inputs: Input tensors for inference.
+        struct_outputs: Expected output structure.
+        compiled_ov_model: Compiled OpenVINO model.
+        unpack_singleton: Function to unpack singleton outputs.
+
+    Returns:
+        Structured model outputs matching expected format.
+    """
+    flatten_inputs = tree.flatten(inputs)
+    raw = compiled_ov_model(flatten_inputs).to_tuple()
+    packed = tree.pack_sequence_as(struct_outputs, raw)
+    return unpack_singleton(packed)
+
+
+def ov_infer(model, inputs, stop_token_ids, fn):
+    """High-level OpenVINO inference with model reuse and compilation.
+
+    This function manages OpenVINO model compilation and caching. It reuses
+    existing compiled models when possible, or compiles new ones as needed.
+    Handles device detection and automatic precision selection.
+
+    Args:
+        model: Keras model with OpenVINO backend support.
+        inputs: Input tensors for inference.
+        stop_token_ids: Token IDs that should stop generation.
+        fn: Function to execute with the parameterized inputs.
+
+    Returns:
+        Model outputs from OpenVINO inference.
+    """
+    device = get_device()
+
+    # Try to use existing compiled model for the same device
+    if (
+        getattr(model, "ov_compiled_model", None) is not None
+        and getattr(model, "ov_device", None) is not None
+        and device == model.ov_device
+    ):
+        try:
+            return get_outputs(
+                inputs,
+                model.struct_outputs,
+                model.ov_compiled_model,
+                model._unpack_singleton,
+            )
+        except RuntimeError as e:
+            print_msg(
+                "WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
+                "recompiling model and trying again.\n" + str(e)
+            )
+            model.ov_compiled_model = None
+            model.struct_outputs = None
+
+    # Compile a new model
+    struct_params = model._parameterize_data(inputs)
+    model.struct_outputs = fn(struct_params, stop_token_ids)
+    model.ov_device = device
+    model_dtype = "f16" if model.dtype in ("float16", "bfloat16") else "f32"
+    model.ov_compiled_model = compile_model(
+        struct_params, model.struct_outputs, device, model_dtype
+    )
+    return get_outputs(
+        inputs,
+        model.struct_outputs,
+        model.ov_compiled_model,
+        model._unpack_singleton,
+    )
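
Note that the module is import-safe without OpenVINO: the `try`/`except` leaves `Core` as `None`, `get_core()` then returns `None`, and `get_device()` falls back to `"CPU"`. A quick check of that degradation path, using only the functions shown above:

```python
# With no `openvino` package installed, `get_core()` returns None and
# `get_device()` falls back to "CPU"; with OpenVINO present, "GPU" is
# preferred when the runtime reports one.
from keras_hub.src.utils.openvino_utils import get_core, get_device

print(get_core())    # openvino.Core instance, or None if not installed.
print(get_device())  # "GPU" or "CPU".
```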

keras_hub/src/utils/preset_utils.py
@@ -10,6 +10,7 @@ import keras
 from absl import logging
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.utils import tensor_utils
 from keras_hub.src.utils.keras_utils import print_msg
 from keras_hub.src.utils.keras_utils import sharded_weights_available
 from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
@@ -501,10 +502,17 @@ def jax_memory_cleanup(layer):
     # For jax, delete all previous allocated memory to avoid temporarily
     # duplicating variable allocations. torch and tensorflow have stateful
     # variable types and do not need this fix.
+    # Skip deletion for sharded arrays to avoid breaking references in
+    # distributed setups.
     if keras.config.backend() == "jax":
         for weight in layer.weights:
-            if getattr(weight, "_value", None) is not None:
-                weight._value.delete()
+            if weight._value is not None:
+                # Do not delete sharded arrays, as they may be referenced in
+                # JAX's distributed computation graph and deletion can cause
+                # errors.
+                sharding = getattr(weight._value, "sharding", None)
+                if sharding is None:
+                    weight._value.delete()
 
 
 def set_dtype_in_config(config, dtype=None):
@@ -687,6 +695,7 @@ class KerasPresetLoader(PresetLoader):
             )
             # We found a `task.json` with a complete config for our class.
             # Forward backbone args.
+            kwargs["dtype"] = self._resolve_dtype(self.config, kwargs)
             backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
             if "backbone" in task_config["config"]:
                 backbone_config = task_config["config"]["backbone"]["config"]
@@ -708,6 +717,53 @@ class KerasPresetLoader(PresetLoader):
         self._load_backbone_weights(task.backbone)
         return task
 
+    def _resolve_dtype(self, config, kwargs):
+        """Resolves the model's dtype based on the provided config and kwargs.
+
+        The data type is resolved based on the following priority:
+        1. If a user-specified dtype is passed, use that.
+        2. If no user-specified dtype is passed and the saved dtype is
+        castable to the current Keras default dtype, convert the weights on
+        load (float type to float type).
+        3. If no user-specified dtype is passed and the saved dtype is not
+        castable to the current default dtype (quantized dtypes), load the
+        saved dtype verbatim.
+
+        Args:
+            config: dict. The model configuration.
+            kwargs: dict. Additional keyword arguments, potentially including
+                `dtype`.
+
+        Returns:
+            str, dict, or DTypePolicy. The resolved dtype.
+        """
+        # 1. If a user-specified dtype is passed, use that.
+        if "dtype" in kwargs and kwargs["dtype"] is not None:
+            return kwargs["dtype"]
+
+        saved_dtype = config.get("config", {}).get("dtype")
+
+        # If there's no saved dtype, we don't need to do anything.
+        if saved_dtype is None:
+            return None
+
+        # 2. Check whether the saved dtype is a simple float type.
+        policy_name = saved_dtype.get("config", {}).get("name")
+        if policy_name and tensor_utils.is_float_dtype(policy_name):
+            # If the saved dtype is a float, we can safely cast to the
+            # default backend float type.
+            if policy_name != keras.config.dtype_policy().name:
+                logging.info(
+                    f"Converting weights saved as {policy_name} "
+                    "to the current Keras dtype policy "
+                    f"{keras.config.dtype_policy()}"
+                )
+            return keras.config.dtype_policy()
+        else:
+            # 3. Otherwise, the dtype is a complex object (e.g. a
+            # DTypePolicyMap for quantization), and should be used as is.
+            return saved_dtype
+
     def load_preprocessor(
         self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
     ):
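
For `from_preset()` callers, the resolution order means roughly: an explicit `dtype` argument wins; otherwise float checkpoints follow the active Keras dtype policy; quantized checkpoints load verbatim. A hedged sketch (the preset name is illustrative):

```python
import keras
import keras_hub

# 1. An explicit `dtype` always wins.
lm = keras_hub.models.CausalLM.from_preset("gemma_2b_en", dtype="bfloat16")

# 2. With no explicit dtype, a float checkpoint is converted to the
#    current Keras dtype policy on load.
keras.config.set_dtype_policy("float32")
lm = keras_hub.models.CausalLM.from_preset("gemma_2b_en")

# 3. A non-float saved dtype (e.g. a quantized DTypePolicyMap) is kept
#    verbatim, since casting it to a float policy would discard the
#    quantization.
```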

keras_hub/src/utils/tensor_utils.py
@@ -12,9 +12,11 @@ from packaging import version
 
 try:
     import tensorflow as tf
-    import tensorflow_text as tf_text
 except ImportError:
     tf = None
+try:
+    import tensorflow_text as tf_text
+except ImportError:
     tf_text = None
 
 
@@ -310,7 +312,29 @@ def is_tensor_type(x):
 
 
 def is_float_dtype(dtype):
-    return "float" in keras.backend.standardize_dtype(dtype)
+    """
+    Checks if a dtype is a float type by using a regex.
+
+    This function standardizes the input dtype and then uses a regular
+    expression to perform an exact match. It identifies standard floats,
+    bfloats, and mixed-precision float types.
+
+    For example:
+    - `is_float_dtype("float32")` returns `True`.
+    - `is_float_dtype("bfloat16")` returns `True`.
+    - `is_float_dtype("mixed_float16")` returns `True`.
+    - `is_float_dtype("int8")` returns `False`.
+    - `is_float_dtype("int8_from_float32")` returns `False`.
+
+    Args:
+        dtype: str, DTypePolicy. The data type to check.
+
+    Returns:
+        bool: `True` if the dtype is a floating-point type, `False` otherwise.
+    """
+    pattern = re.compile(r"^(mixed_)?(b)?float[0-9]*$")
+    standardized_dtype = keras.backend.standardize_dtype(dtype)
+    return pattern.match(standardized_dtype) is not None
 
 
 def is_int_dtype(dtype):
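
The behavioral fix here is subtle: the old substring test (`"float" in ...`) also matched quantized policy names such as `"int8_from_float32"`, while the anchored regex accepts only genuine float and mixed-float policies. The pattern can be demonstrated with `re` alone:

```python
import re

# The anchored pattern from the diff, shown without Keras: the old
# substring check wrongly accepted quantized policy names like
# "int8_from_float32"; the exact match rejects them.
pattern = re.compile(r"^(mixed_)?(b)?float[0-9]*$")

names = [
    "float32", "bfloat16", "mixed_float16", "mixed_bfloat16",
    "int8", "int8_from_float32",
]
for name in names:
    old = "float" in name
    new = pattern.match(name) is not None
    print(f"{name:>18}: old={old!s:>5} new={new}")
```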