keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-hub-nightly might be problematic. Click here for more details.
- keras_hub/layers/__init__.py +15 -0
- keras_hub/models/__init__.py +93 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +28 -16
- keras_hub/src/models/causal_lm.py +37 -0
- keras_hub/src/models/causal_lm_preprocessor.py +14 -0
- keras_hub/src/models/clip/clip_presets.py +8 -8
- keras_hub/src/models/d_fine/__init__.py +5 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_backbone.py +0 -1
- keras_hub/src/models/gemma/gemma_presets.py +30 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/samplers/beam_sampler.py +6 -6
- keras_hub/src/samplers/sampler.py +8 -6
- keras_hub/src/tests/test_case.py +40 -3
- keras_hub/src/tokenizers/tokenizer.py +15 -0
- keras_hub/src/utils/openvino_utils.py +141 -0
- keras_hub/src/utils/preset_utils.py +58 -2
- keras_hub/src/utils/tensor_utils.py +26 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/utils/transformers/export/gemma.py +49 -4
- keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +15 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
2
|
+
from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
|
|
3
|
+
from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
|
|
4
|
+
SentencePieceTokenizer,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@keras_hub_export(
|
|
9
|
+
[
|
|
10
|
+
"keras_hub.tokenizers.T5GemmaTokenizer",
|
|
11
|
+
"keras_hub.models.T5GemmaTokenizer",
|
|
12
|
+
]
|
|
13
|
+
)
|
|
14
|
+
class T5GemmaTokenizer(SentencePieceTokenizer):
|
|
15
|
+
"""T5Gemma tokenizer layer based on SentencePiece.
|
|
16
|
+
|
|
17
|
+
This tokenizer class will tokenize raw strings into integer sequences and
|
|
18
|
+
is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the
|
|
19
|
+
underlying tokenizer, it will check for all special tokens needed by
|
|
20
|
+
T5Gemma models and provides a `from_preset()` method to automatically
|
|
21
|
+
download a matching vocabulary for a T5Gemma preset.
|
|
22
|
+
|
|
23
|
+
If input is a batch of strings (rank > 0), the layer will output a
|
|
24
|
+
`tf.RaggedTensor` where the last dimension of the output is ragged.
|
|
25
|
+
|
|
26
|
+
If input is a scalar string (rank == 0), the layer will output a dense
|
|
27
|
+
`tf.Tensor` with static shape `[None]`.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
proto: Either a `string` path to a SentencePiece proto file, or a
|
|
31
|
+
`bytes` object with a serialized SentencePiece proto. See the
|
|
32
|
+
[SentencePiece repository](https://github.com/google/sentencepiece)
|
|
33
|
+
for more details on the format.
|
|
34
|
+
|
|
35
|
+
Examples:
|
|
36
|
+
|
|
37
|
+
```python
|
|
38
|
+
import io
|
|
39
|
+
import tensorflow as tf
|
|
40
|
+
import sentencepiece
|
|
41
|
+
|
|
42
|
+
# Unbatched input.
|
|
43
|
+
tokenizer = keras_hub.models.T5GemmaTokenizer.from_preset(
|
|
44
|
+
"t5gemma_b_b_prefixlm_it"
|
|
45
|
+
)
|
|
46
|
+
tokenizer("The quick brown fox jumped.")
|
|
47
|
+
|
|
48
|
+
# Batched input.
|
|
49
|
+
tokenizer(["The quick brown fox jumped.", "The fox slept."])
|
|
50
|
+
|
|
51
|
+
# Detokenization.
|
|
52
|
+
tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
|
|
53
|
+
|
|
54
|
+
# Custom vocabulary.
|
|
55
|
+
bytes_io = io.BytesIO()
|
|
56
|
+
ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
|
|
57
|
+
sentencepiece.SentencePieceTrainer.train(
|
|
58
|
+
sentence_iterator=ds.as_numpy_iterator(),
|
|
59
|
+
model_writer=bytes_io,
|
|
60
|
+
vocab_size=8,
|
|
61
|
+
model_type="WORD",
|
|
62
|
+
pad_id=0,
|
|
63
|
+
bos_id=1,
|
|
64
|
+
eos_id=2,
|
|
65
|
+
unk_id=3,
|
|
66
|
+
pad_piece="<pad>",
|
|
67
|
+
bos_piece="<bos>",
|
|
68
|
+
eos_piece="<eos>",
|
|
69
|
+
unk_piece="<unk>",
|
|
70
|
+
)
|
|
71
|
+
tokenizer = keras_hub.models.T5GemmaTokenizer(
|
|
72
|
+
proto=bytes_io.getvalue(),
|
|
73
|
+
)
|
|
74
|
+
tokenizer("The quick brown fox jumped.")
|
|
75
|
+
```
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
backbone_cls = T5GemmaBackbone
|
|
79
|
+
|
|
80
|
+
def __init__(self, proto, **kwargs):
|
|
81
|
+
self._add_special_token("<bos>", "start_token")
|
|
82
|
+
self._add_special_token("<eos>", "end_token")
|
|
83
|
+
self._add_special_token("<pad>", "pad_token")
|
|
84
|
+
super().__init__(proto=proto, **kwargs)
|
|
@@ -345,3 +345,8 @@ class TextToImage(Task):
|
|
|
345
345
|
# Text-to-image.
|
|
346
346
|
outputs = [generate(x) for x in inputs]
|
|
347
347
|
return self._normalize_generate_outputs(outputs, input_is_scalar)
|
|
348
|
+
|
|
349
|
+
def _post_quantize(self, mode, **kwargs):
|
|
350
|
+
super()._post_quantize(mode, **kwargs)
|
|
351
|
+
# Reset the compiled generate function.
|
|
352
|
+
self.generate_function = None
|
|
@@ -95,15 +95,15 @@ class BeamSampler(Sampler):
|
|
|
95
95
|
)
|
|
96
96
|
log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0))
|
|
97
97
|
|
|
98
|
-
def cond(prompt, cache, index, log_probs):
|
|
98
|
+
def cond(prompt, cache, index, mask, log_probs):
|
|
99
99
|
if stop_token_ids is None:
|
|
100
|
-
return True
|
|
100
|
+
return ops.convert_to_tensor(True, dtype="bool")
|
|
101
101
|
# Stop if all sequences have produced a *new* stop token.
|
|
102
102
|
end_tokens = any_equal(prompt, stop_token_ids, ~mask)
|
|
103
103
|
prompt_done = ops.any(end_tokens, axis=-1)
|
|
104
104
|
return ops.logical_not(ops.all(prompt_done))
|
|
105
105
|
|
|
106
|
-
def body(prompt, cache, index, log_probs):
|
|
106
|
+
def body(prompt, cache, index, mask, log_probs):
|
|
107
107
|
# Compute the softmax distribution for the next token.
|
|
108
108
|
logits, _, cache = next(prompt, cache, index)
|
|
109
109
|
vocab_size = ops.shape(logits)[-1]
|
|
@@ -150,12 +150,12 @@ class BeamSampler(Sampler):
|
|
|
150
150
|
next_token = next_token[:, None]
|
|
151
151
|
prompt = ops.slice_update(prompt, [0, index], next_token)
|
|
152
152
|
# Return the iteration of the loop state.
|
|
153
|
-
return (prompt, cache, index + 1, log_probs)
|
|
153
|
+
return (prompt, cache, index + 1, mask, log_probs)
|
|
154
154
|
|
|
155
|
-
prompt, _, _, log_probs = self.run_loop(
|
|
155
|
+
prompt, _, _, _, log_probs = self.run_loop(
|
|
156
156
|
cond=cond,
|
|
157
157
|
body=body,
|
|
158
|
-
loop_vars=(prompt, cache, index, log_probs),
|
|
158
|
+
loop_vars=(prompt, cache, index, mask, log_probs),
|
|
159
159
|
maximum_iterations=(max_length - index),
|
|
160
160
|
model=model,
|
|
161
161
|
)
|
|
@@ -92,16 +92,18 @@ class Sampler:
|
|
|
92
92
|
# `ops.while_loop` will not accept `None` as a value for `loop_vars`.
|
|
93
93
|
cache = () if cache is None else cache
|
|
94
94
|
|
|
95
|
-
|
|
95
|
+
# OpenVINO requires all parameters to be passed in the body.
|
|
96
|
+
# So we pass `mask` as well.
|
|
97
|
+
def cond(prompt, cache, index, mask):
|
|
96
98
|
if stop_token_ids is None:
|
|
97
|
-
return True
|
|
99
|
+
return ops.convert_to_tensor(True, dtype="bool")
|
|
98
100
|
# Stop if all sequences have produced a *new* id from
|
|
99
101
|
# stop_token_ids.
|
|
100
102
|
end_tokens = any_equal(prompt, stop_token_ids, ~mask)
|
|
101
103
|
prompt_done = ops.any(end_tokens, axis=-1)
|
|
102
104
|
return ops.logical_not(ops.all(prompt_done))
|
|
103
105
|
|
|
104
|
-
def body(prompt, cache, index):
|
|
106
|
+
def body(prompt, cache, index, mask):
|
|
105
107
|
# Compute the softmax distribution for the next token.
|
|
106
108
|
logits, _, cache = next(prompt, cache, index)
|
|
107
109
|
probabilities = self.compute_probabilities(logits)
|
|
@@ -115,12 +117,12 @@ class Sampler:
|
|
|
115
117
|
prompt = ops.slice_update(prompt, [0, index], next_token)
|
|
116
118
|
|
|
117
119
|
# Return the next prompt, cache and incremented index.
|
|
118
|
-
return (prompt, cache, index + 1)
|
|
120
|
+
return (prompt, cache, index + 1, mask)
|
|
119
121
|
|
|
120
|
-
prompt, _, _ = self.run_loop(
|
|
122
|
+
prompt, _, _, _ = self.run_loop(
|
|
121
123
|
cond,
|
|
122
124
|
body,
|
|
123
|
-
loop_vars=(prompt, cache, index),
|
|
125
|
+
loop_vars=(prompt, cache, index, mask),
|
|
124
126
|
maximum_iterations=(max_length - index),
|
|
125
127
|
model=model,
|
|
126
128
|
)
|
keras_hub/src/tests/test_case.py
CHANGED
|
@@ -499,6 +499,7 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
|
499
499
|
init_kwargs,
|
|
500
500
|
input_data,
|
|
501
501
|
expected_output_shape,
|
|
502
|
+
spatial_output_keys=None,
|
|
502
503
|
expected_pyramid_output_keys=None,
|
|
503
504
|
expected_pyramid_image_sizes=None,
|
|
504
505
|
variable_length_data=None,
|
|
@@ -537,10 +538,11 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
|
537
538
|
|
|
538
539
|
self.assertIsInstance(output_data, dict)
|
|
539
540
|
self.assertEqual(
|
|
540
|
-
|
|
541
|
+
sorted(output_data.keys()),
|
|
542
|
+
sorted(backbone.pyramid_outputs.keys()),
|
|
541
543
|
)
|
|
542
544
|
self.assertEqual(
|
|
543
|
-
|
|
545
|
+
sorted(output_data.keys()), sorted(expected_pyramid_output_keys)
|
|
544
546
|
)
|
|
545
547
|
# check height and width of each level.
|
|
546
548
|
for i, (k, v) in enumerate(output_data.items()):
|
|
@@ -557,12 +559,47 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
|
557
559
|
input_data = ops.transpose(input_data, axes=(2, 0, 1))
|
|
558
560
|
elif len(input_data_shape) == 4:
|
|
559
561
|
input_data = ops.transpose(input_data, axes=(0, 3, 1, 2))
|
|
560
|
-
if
|
|
562
|
+
if isinstance(expected_output_shape, dict):
|
|
563
|
+
# Handle dictionary of shapes.
|
|
564
|
+
transposed_shapes = {}
|
|
565
|
+
for key, shape in expected_output_shape.items():
|
|
566
|
+
if spatial_output_keys and key not in spatial_output_keys:
|
|
567
|
+
transposed_shapes[key] = shape
|
|
568
|
+
continue
|
|
569
|
+
if len(shape) == 3:
|
|
570
|
+
transposed_shapes[key] = (shape[0], shape[2], shape[1])
|
|
571
|
+
elif len(shape) == 4:
|
|
572
|
+
transposed_shapes[key] = (
|
|
573
|
+
shape[0],
|
|
574
|
+
shape[3],
|
|
575
|
+
shape[1],
|
|
576
|
+
shape[2],
|
|
577
|
+
)
|
|
578
|
+
else:
|
|
579
|
+
transposed_shapes[key] = shape
|
|
580
|
+
expected_output_shape = transposed_shapes
|
|
581
|
+
elif len(expected_output_shape) == 3:
|
|
561
582
|
x = expected_output_shape
|
|
562
583
|
expected_output_shape = (x[0], x[2], x[1])
|
|
563
584
|
elif len(expected_output_shape) == 4:
|
|
564
585
|
x = expected_output_shape
|
|
565
586
|
expected_output_shape = (x[0], x[3], x[1], x[2])
|
|
587
|
+
original_init_kwargs = init_kwargs.copy()
|
|
588
|
+
init_kwargs = original_init_kwargs.copy()
|
|
589
|
+
# Handle nested `keras.Model` instances passed within `init_kwargs`.
|
|
590
|
+
for k, v in init_kwargs.items():
|
|
591
|
+
if isinstance(v, keras.Model) and hasattr(v, "data_format"):
|
|
592
|
+
config = v.get_config()
|
|
593
|
+
config["data_format"] = "channels_first"
|
|
594
|
+
if (
|
|
595
|
+
"image_shape" in config
|
|
596
|
+
and config["image_shape"] is not None
|
|
597
|
+
and len(config["image_shape"]) == 3
|
|
598
|
+
):
|
|
599
|
+
config["image_shape"] = tuple(
|
|
600
|
+
reversed(config["image_shape"])
|
|
601
|
+
)
|
|
602
|
+
init_kwargs[k] = v.__class__.from_config(config)
|
|
566
603
|
if "image_shape" in init_kwargs:
|
|
567
604
|
init_kwargs = init_kwargs.copy()
|
|
568
605
|
init_kwargs["image_shape"] = tuple(
|
|
@@ -261,3 +261,18 @@ class Tokenizer(PreprocessingLayer):
|
|
|
261
261
|
if cls.backbone_cls != backbone_cls:
|
|
262
262
|
cls = find_subclass(preset, cls, backbone_cls)
|
|
263
263
|
return loader.load_tokenizer(cls, config_file, **kwargs)
|
|
264
|
+
|
|
265
|
+
def export_to_transformers(self, path):
|
|
266
|
+
"""Export the tokenizer to HuggingFace Transformers format.
|
|
267
|
+
|
|
268
|
+
This saves tokenizer assets in a format compatible with HuggingFace
|
|
269
|
+
Transformers.
|
|
270
|
+
|
|
271
|
+
Args:
|
|
272
|
+
path: str. Path to save the exported tokenizer.
|
|
273
|
+
"""
|
|
274
|
+
from keras_hub.src.utils.transformers.export.hf_exporter import (
|
|
275
|
+
export_tokenizer,
|
|
276
|
+
)
|
|
277
|
+
|
|
278
|
+
export_tokenizer(self, path)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from keras import tree
|
|
2
|
+
|
|
3
|
+
from keras_hub.src.utils.keras_utils import print_msg
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
import openvino as ov
|
|
7
|
+
import openvino.opset14 as ov_opset
|
|
8
|
+
from openvino import Core
|
|
9
|
+
except ImportError:
|
|
10
|
+
ov = None
|
|
11
|
+
ov_opset = None
|
|
12
|
+
Core = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
_core = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_core():
|
|
19
|
+
"""Get or create OpenVINO Core instance.
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
openvino.Core: OpenVINO Core instance,
|
|
23
|
+
or None if OpenVINO not available.
|
|
24
|
+
"""
|
|
25
|
+
global _core
|
|
26
|
+
if _core is None and Core is not None:
|
|
27
|
+
_core = Core()
|
|
28
|
+
return _core
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_device():
|
|
32
|
+
"""Detect and return the best available OpenVINO device.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
str: "GPU" if available, otherwise "CPU".
|
|
36
|
+
"""
|
|
37
|
+
core = get_core()
|
|
38
|
+
if core is None:
|
|
39
|
+
return "CPU"
|
|
40
|
+
return "GPU" if "GPU" in core.available_devices else "CPU"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def compile_model(struct_params, struct_outputs, device, model_dtype):
|
|
44
|
+
"""Compile OpenVINO model with dynamic shapes and precision hints.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
struct_params: Model parameters structure.
|
|
48
|
+
struct_outputs: Model outputs structure.
|
|
49
|
+
device: Target device ("GPU" or "CPU").
|
|
50
|
+
model_dtype: Model precision ("f16" or "f32").
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
Compiled OpenVINO model ready for inference.
|
|
54
|
+
"""
|
|
55
|
+
flat_params = tree.flatten(struct_params)
|
|
56
|
+
flat_outputs = tree.flatten(struct_outputs)
|
|
57
|
+
parameters = [p.output.get_node() for p in flat_params]
|
|
58
|
+
results = [ov_opset.result(r.output) for r in flat_outputs]
|
|
59
|
+
ov_model = ov.Model(results=results, parameters=parameters)
|
|
60
|
+
for ov_input in ov_model.inputs:
|
|
61
|
+
rank = ov_input.get_partial_shape().rank.get_length()
|
|
62
|
+
ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
|
|
63
|
+
ov_model.validate_nodes_and_infer_types()
|
|
64
|
+
config = {"INFERENCE_PRECISION_HINT": model_dtype}
|
|
65
|
+
core = get_core()
|
|
66
|
+
if core is None:
|
|
67
|
+
raise RuntimeError("OpenVINO not available")
|
|
68
|
+
return core.compile_model(ov_model, device, config)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
|
|
72
|
+
"""Execute compiled OpenVINO model and return structured outputs.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
inputs: Input tensors for inference.
|
|
76
|
+
struct_outputs: Expected output structure.
|
|
77
|
+
compiled_ov_model: Compiled OpenVINO model.
|
|
78
|
+
unpack_singleton: Function to unpack singleton outputs.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
Structured model outputs matching expected format.
|
|
82
|
+
"""
|
|
83
|
+
flatten_inputs = tree.flatten(inputs)
|
|
84
|
+
raw = compiled_ov_model(flatten_inputs).to_tuple()
|
|
85
|
+
packed = tree.pack_sequence_as(struct_outputs, raw)
|
|
86
|
+
return unpack_singleton(packed)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def ov_infer(model, inputs, stop_token_ids, fn):
|
|
90
|
+
"""High-level OpenVINO inference with model reuse and compilation.
|
|
91
|
+
|
|
92
|
+
This function manages OpenVINO model compilation and caching. It reuses
|
|
93
|
+
existing compiled models when possible, or compiles new ones as needed.
|
|
94
|
+
Handles device detection and automatic precision selection.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
model: Keras model with OpenVINO backend support.
|
|
98
|
+
inputs: Input tensors for inference.
|
|
99
|
+
stop_token_ids: Token IDs that should stop generation.
|
|
100
|
+
fn: Function to execute with the parameterized inputs.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
Model outputs from OpenVINO inference.
|
|
104
|
+
"""
|
|
105
|
+
device = get_device()
|
|
106
|
+
|
|
107
|
+
# Try to use existing compiled model for the same device
|
|
108
|
+
if (
|
|
109
|
+
getattr(model, "ov_compiled_model", None) is not None
|
|
110
|
+
and getattr(model, "ov_device", None) is not None
|
|
111
|
+
and device == model.ov_device
|
|
112
|
+
):
|
|
113
|
+
try:
|
|
114
|
+
return get_outputs(
|
|
115
|
+
inputs,
|
|
116
|
+
model.struct_outputs,
|
|
117
|
+
model.ov_compiled_model,
|
|
118
|
+
model._unpack_singleton,
|
|
119
|
+
)
|
|
120
|
+
except RuntimeError as e:
|
|
121
|
+
print_msg(
|
|
122
|
+
"WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
|
|
123
|
+
"recompiling model and trying again.\n" + str(e)
|
|
124
|
+
)
|
|
125
|
+
model.ov_compiled_model = None
|
|
126
|
+
model.struct_outputs = None
|
|
127
|
+
|
|
128
|
+
# Compile a new model
|
|
129
|
+
struct_params = model._parameterize_data(inputs)
|
|
130
|
+
model.struct_outputs = fn(struct_params, stop_token_ids)
|
|
131
|
+
model.ov_device = device
|
|
132
|
+
model_dtype = "f16" if model.dtype in ("float16", "bfloat16") else "f32"
|
|
133
|
+
model.ov_compiled_model = compile_model(
|
|
134
|
+
struct_params, model.struct_outputs, device, model_dtype
|
|
135
|
+
)
|
|
136
|
+
return get_outputs(
|
|
137
|
+
inputs,
|
|
138
|
+
model.struct_outputs,
|
|
139
|
+
model.ov_compiled_model,
|
|
140
|
+
model._unpack_singleton,
|
|
141
|
+
)
|
|
@@ -10,6 +10,7 @@ import keras
|
|
|
10
10
|
from absl import logging
|
|
11
11
|
|
|
12
12
|
from keras_hub.src.api_export import keras_hub_export
|
|
13
|
+
from keras_hub.src.utils import tensor_utils
|
|
13
14
|
from keras_hub.src.utils.keras_utils import print_msg
|
|
14
15
|
from keras_hub.src.utils.keras_utils import sharded_weights_available
|
|
15
16
|
from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
|
|
@@ -501,10 +502,17 @@ def jax_memory_cleanup(layer):
|
|
|
501
502
|
# For jax, delete all previous allocated memory to avoid temporarily
|
|
502
503
|
# duplicating variable allocations. torch and tensorflow have stateful
|
|
503
504
|
# variable types and do not need this fix.
|
|
505
|
+
# Skip deletion for sharded arrays to avoid breaking references in
|
|
506
|
+
# distributed setups.
|
|
504
507
|
if keras.config.backend() == "jax":
|
|
505
508
|
for weight in layer.weights:
|
|
506
|
-
if
|
|
507
|
-
|
|
509
|
+
if weight._value is not None:
|
|
510
|
+
# Do not delete sharded arrays, as they may be referenced in
|
|
511
|
+
# JAX's distributed computation graph and deletion can cause
|
|
512
|
+
# errors.
|
|
513
|
+
sharding = getattr(weight._value, "sharding", None)
|
|
514
|
+
if sharding is None:
|
|
515
|
+
weight._value.delete()
|
|
508
516
|
|
|
509
517
|
|
|
510
518
|
def set_dtype_in_config(config, dtype=None):
|
|
@@ -687,6 +695,7 @@ class KerasPresetLoader(PresetLoader):
|
|
|
687
695
|
)
|
|
688
696
|
# We found a `task.json` with a complete config for our class.
|
|
689
697
|
# Forward backbone args.
|
|
698
|
+
kwargs["dtype"] = self._resolve_dtype(self.config, kwargs)
|
|
690
699
|
backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
|
|
691
700
|
if "backbone" in task_config["config"]:
|
|
692
701
|
backbone_config = task_config["config"]["backbone"]["config"]
|
|
@@ -708,6 +717,53 @@ class KerasPresetLoader(PresetLoader):
|
|
|
708
717
|
self._load_backbone_weights(task.backbone)
|
|
709
718
|
return task
|
|
710
719
|
|
|
720
|
+
def _resolve_dtype(self, config, kwargs):
|
|
721
|
+
"""Resolves the Model's dtype based on the provided config and kwargs.
|
|
722
|
+
|
|
723
|
+
The data type is resolved based on the following priority:
|
|
724
|
+
1. If a user specified dtype is passed, use that.
|
|
725
|
+
2. If no user specified dtype is passed, and the saved dtype is castable
|
|
726
|
+
to the current keras default dtype convert weights on load (float type
|
|
727
|
+
to float type).
|
|
728
|
+
3. If no user specified dtype is passed, and the saved dtype is not
|
|
729
|
+
castable to the current default dtype (quantized dtypes). Load the
|
|
730
|
+
saved types verbatim.
|
|
731
|
+
|
|
732
|
+
Args:
|
|
733
|
+
config: dict. The model configuration.
|
|
734
|
+
kwargs: dict. Additional keyword arguments, potentially including
|
|
735
|
+
`dtype`.
|
|
736
|
+
|
|
737
|
+
Returns:
|
|
738
|
+
str, dict, or DTypePolicy. The resolved dtype.
|
|
739
|
+
"""
|
|
740
|
+
# 1. If a user specified dtype is passed, use that.
|
|
741
|
+
if "dtype" in kwargs and kwargs["dtype"] is not None:
|
|
742
|
+
return kwargs["dtype"]
|
|
743
|
+
|
|
744
|
+
saved_dtype = config.get("config", {}).get("dtype")
|
|
745
|
+
|
|
746
|
+
# If there's no saved dtype, we don't need to do anything.
|
|
747
|
+
if saved_dtype is None:
|
|
748
|
+
return None
|
|
749
|
+
|
|
750
|
+
# 2. Check whether the saved dtype is a simple float type.
|
|
751
|
+
policy_name = saved_dtype.get("config", {}).get("name")
|
|
752
|
+
if policy_name and tensor_utils.is_float_dtype(policy_name):
|
|
753
|
+
# If the saved dtype is a float, we can safely cast to the default
|
|
754
|
+
# backend float type.
|
|
755
|
+
if policy_name != keras.config.dtype_policy().name:
|
|
756
|
+
logging.info(
|
|
757
|
+
f"Converting weights saved as {policy_name} "
|
|
758
|
+
"to the current Keras dtype policy "
|
|
759
|
+
f"{keras.config.dtype_policy()}"
|
|
760
|
+
)
|
|
761
|
+
return keras.config.dtype_policy()
|
|
762
|
+
else:
|
|
763
|
+
# 3. Otherwise, the dtype is a complex object (e.g. a
|
|
764
|
+
# DTypePolicyMap for quantization), and should be used as is.
|
|
765
|
+
return saved_dtype
|
|
766
|
+
|
|
711
767
|
def load_preprocessor(
|
|
712
768
|
self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
|
|
713
769
|
):
|
|
@@ -12,9 +12,11 @@ from packaging import version
|
|
|
12
12
|
|
|
13
13
|
try:
|
|
14
14
|
import tensorflow as tf
|
|
15
|
-
import tensorflow_text as tf_text
|
|
16
15
|
except ImportError:
|
|
17
16
|
tf = None
|
|
17
|
+
try:
|
|
18
|
+
import tensorflow_text as tf_text
|
|
19
|
+
except ImportError:
|
|
18
20
|
tf_text = None
|
|
19
21
|
|
|
20
22
|
|
|
@@ -310,7 +312,29 @@ def is_tensor_type(x):
|
|
|
310
312
|
|
|
311
313
|
|
|
312
314
|
def is_float_dtype(dtype):
|
|
313
|
-
|
|
315
|
+
"""
|
|
316
|
+
Checks if a dtype is a float type by using a regex.
|
|
317
|
+
|
|
318
|
+
This function standardizes the input dtype and then uses a regular
|
|
319
|
+
expression to perform an exact match. It identifies standard floats,
|
|
320
|
+
bfloats, and mixed-precision float types.
|
|
321
|
+
|
|
322
|
+
For example:
|
|
323
|
+
- `is_float_dtype("float32")` returns `True`.
|
|
324
|
+
- `is_float_dtype("bfloat16")` returns `True`.
|
|
325
|
+
- `is_float_dtype("mixed_float16")` returns `True`.
|
|
326
|
+
- `is_float_dtype("int8")` returns `False`.
|
|
327
|
+
- `is_float_dtype("int8_from_float32")` returns `False`.
|
|
328
|
+
|
|
329
|
+
Args:
|
|
330
|
+
dtype: str, DTypePolicy. The data type to check.
|
|
331
|
+
|
|
332
|
+
Returns:
|
|
333
|
+
bool: `True` if the dtype is a floating-point type, `False` otherwise.
|
|
334
|
+
"""
|
|
335
|
+
pattern = re.compile(r"^(mixed_)?(b)?float[0-9]*$")
|
|
336
|
+
standardized_dtype = keras.backend.standardize_dtype(dtype)
|
|
337
|
+
return pattern.match(standardized_dtype) is not None
|
|
314
338
|
|
|
315
339
|
|
|
316
340
|
def is_int_dtype(dtype):
|