keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0

Selected file diffs:

keras_hub/src/tokenizers/word_piece_tokenizer.py:

```diff
@@ -3,6 +3,7 @@ import re
 from typing import Iterable

 import keras
+from keras.src.saving import serialization_lib

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.tokenizers import tokenizer
@@ -13,9 +14,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function

 try:
     import tensorflow as tf
-    import tensorflow_text as tf_text
 except ImportError:
     tf = None
+try:
+    import tensorflow_text as tf_text
+except ImportError:
     tf_text = None

 VOCAB_FILENAME = "vocabulary.txt"
@@ -374,6 +377,17 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
             return

         if isinstance(vocabulary, str):
+            if serialization_lib.in_safe_mode():
+                raise ValueError(
+                    "Requested the loading of a vocabulary file outside of the "
+                    "model archive. This carries a potential risk of loading "
+                    "arbitrary and sensitive files and thus it is disallowed "
+                    "by default. If you trust the source of the artifact, you "
+                    "can override this error by passing `safe_mode=False` to "
+                    "the loading function, or calling "
+                    "`keras.config.enable_unsafe_deserialization()`. "
+                    f"Vocabulary file: '{vocabulary}'"
+                )
             with open(vocabulary, "r", encoding="utf-8") as file:
                 self.vocabulary = [line.rstrip() for line in file]
         elif isinstance(vocabulary, Iterable):
```
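
The import split in the second hunk guards `tensorflow` and `tensorflow_text` independently, so the module keeps working when only `tensorflow` is installed. The third hunk blocks reading a vocabulary file from outside the model archive while deserialization runs in safe mode. A rough sketch of how this is expected to surface (the archive path below is hypothetical; the two opt-outs are the ones named in the error message itself):

```python
import keras

# Hypothetical archive whose tokenizer config points at an external
# vocabulary file; loading it under the default safe mode should now hit
# the ValueError added above instead of silently reading the file.
try:
    model = keras.saving.load_model("model_with_external_vocab.keras")
except ValueError as e:
    print(e)

# If the artifact is trusted, either opt-out works:
model = keras.saving.load_model(
    "model_with_external_vocab.keras", safe_mode=False
)
# ...or globally, before loading:
keras.config.enable_unsafe_deserialization()
```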

keras_hub/src/utils/preset_utils.py:

```diff
@@ -285,7 +285,7 @@ def tf_copy_gfile_to_cache(preset, path):
         # Work around this bug.
         os.remove(local_path)
         if isinstance(
-            e, tf.errors.PermissionDeniedError, tf.errors.NotFoundError
+            e, (tf.errors.PermissionDeniedError, tf.errors.NotFoundError)
         ):
             raise FileNotFoundError(
                 f"`{path}` doesn't exist in preset directory `{preset}`.",
```
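
The one-line fix above corrects an `isinstance` call: checking against several exception types requires passing them as a tuple, otherwise Python raises `TypeError: isinstance expected 2 arguments, got 3` before the intended `FileNotFoundError` can be raised. A standalone illustration with toy exception classes (not from keras-hub):

```python
class PermissionDenied(Exception):
    pass


class NotFound(Exception):
    pass


err = NotFound()

# isinstance(err, PermissionDenied, NotFound)  # TypeError: expected 2 arguments, got 3
print(isinstance(err, (PermissionDenied, NotFound)))  # True: classinfo may be a tuple of types
```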

keras_hub/src/utils/tensor_utils.py:

```diff
@@ -231,6 +231,7 @@ def tensor_to_list(inputs):
     Args:
         inputs: Input tensor, or dict/list/tuple of input tensors.
     """
+    assert_tf_installed("tensor_to_list")
     if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
         inputs = tf.convert_to_tensor(inputs)
     if isinstance(inputs, tf.RaggedTensor):
@@ -246,6 +247,7 @@ def tensor_to_list(inputs):

 def convert_to_ragged_batch(inputs):
     """Ensure a tf.Tensor is a ragged rank 2 tensor."""
+    assert_tf_installed("convert_to_ragged_batch")
     if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
         inputs = tf.convert_to_tensor(inputs)
     unbatched = inputs.shape.rank == 1
@@ -259,6 +261,7 @@ def convert_to_ragged_batch(inputs):

 def truncate_at_token(inputs, token, mask):
     """Truncate at first instance of `token`, ignoring `mask`."""
+    assert_tf_installed("truncate_at_token")
     matches = (inputs == token) & (~mask)
     end_indices = tf.cast(tf.math.argmax(matches, -1), "int32")
     end_indices = tf.where(end_indices == 0, tf.shape(inputs)[-1], end_indices)
@@ -267,12 +270,21 @@ def truncate_at_token(inputs, token, mask):

 def strip_to_ragged(token_ids, mask, ids_to_strip):
     """Remove masked and special tokens from a sequence before detokenizing."""
+    assert_tf_installed("strip_to_ragged")
     mask = tf.cast(mask, "bool")
     for id in ids_to_strip:
         mask = mask & (token_ids != id)
     return tf.ragged.boolean_mask(token_ids, mask)


+def assert_tf_installed(symbol_name):
+    if tf is None:
+        raise ImportError(
+            f"{symbol_name} requires `tensorflow`. "
+            "Run `pip install tensorflow` to install it."
+        )
+
+
 def assert_tf_libs_installed(symbol_name):
     if tf_text is None or tf is None:
         raise ImportError(
```
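
Previously these helpers assumed `tf` was importable; with TensorFlow missing, the `tf = None` placeholder meant the first attribute access failed with an opaque `AttributeError`. The new `assert_tf_installed` guard fails fast with an actionable message instead. A condensed, standalone sketch of the pattern (function body elided):

```python
# Optional-dependency guard: the module imports cleanly without TensorFlow,
# but any helper that needs it raises a clear ImportError up front.
try:
    import tensorflow as tf
except ImportError:
    tf = None


def assert_tf_installed(symbol_name):
    if tf is None:
        raise ImportError(
            f"{symbol_name} requires `tensorflow`. "
            "Run `pip install tensorflow` to install it."
        )


def tensor_to_list(inputs):
    """Convert a tensor (or ragged tensor) to nested Python lists."""
    assert_tf_installed("tensor_to_list")  # fail fast before touching `tf`
    if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
        inputs = tf.convert_to_tensor(inputs)
    ...
```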

keras_hub/src/utils/transformers/convert_gemma3.py (several removed lines were only partially captured in the extracted diff and are left truncated):

```diff
@@ -37,6 +37,7 @@ def convert_backbone_config(transformers_config):
     else:
         vision_config = transformers_config["vision_config"]
         image_size = vision_config["image_size"]
+        transformer_config = transformers_config["text_config"]
         vision_encoder_config = {
             "image_size": image_size,
             "patch_size": vision_config["patch_size"],
@@ -44,21 +45,44 @@ def convert_backbone_config(transformers_config):
             "hidden_dim": vision_config["hidden_size"],
             "num_layers": vision_config["num_hidden_layers"],
             "intermediate_dim": vision_config["intermediate_size"],
-            "output_dim":
+            "output_dim": transformer_config["hidden_size"],
             "pool_size": 4,
             "layer_norm_epsilon": vision_config.get("layer_norm_eps", 1e-6),
         }
         vision_encoder = Gemma3VisionEncoder(**vision_encoder_config)
-        transformer_config = transformers_config["text_config"]

-
-
-
-
-
-
+    # Extract rope parameters. HuggingFace uses `rope_scaling` for the
+    # global rotary embedding. `rope_parameters` is optional and not used
+    # by HF for global scaling when `rope_scaling` is None.
+    rope_scaling = transformer_config.get("rope_scaling", None)
+    rope_params = transformer_config.get("rope_parameters") or {}
+
+    if rope_scaling is not None:
+        rope_global_config = rope_scaling or {}
     else:
-        rope_global_config = {}
+        rope_global_config = rope_params.get("full_attention", {})
+
+    rope_local_config = rope_params.get("sliding_attention", {})
+
+    # Determine sliding window attention usage from layer_types or config
+    sliding_window = transformer_config.get("sliding_window", None)
+    layer_types = transformer_config.get("layer_types", [])
+
+    use_sliding_window_attention = sliding_window not in (None, 0) or any(
+        lt == "sliding_attention" for lt in layer_types
+    )
+
+    # Determine query_head_dim_normalize
+    # If query_pre_attn_scalar equals head_dim, then normalize by head_dim
+    query_pre_attn_scalar = transformer_config.get(
+        "query_pre_attn_scalar", None
+    )
+    head_dim = transformer_config.get("head_dim")
+    if query_pre_attn_scalar is not None and head_dim is not None:
+        query_head_dim_normalize = query_pre_attn_scalar == head_dim
+    else:
+        query_head_dim_normalize = True
+
     return {
         "vocabulary_size": transformer_config.get(
             "vocab_size", 262144 if vision_encoder is None else 262208
@@ -70,25 +94,35 @@ def convert_backbone_config(transformers_config):
         "hidden_dim": transformer_config["hidden_size"],
         "intermediate_dim": transformer_config["intermediate_size"],
         "head_dim": transformer_config["head_dim"],
-
-        "
-        "
-        "
+        # Gemma3 models use post-norm and post-attention norm by default
+        "use_post_ffw_norm": transformer_config.get("use_post_ffw_norm", True),
+        "use_post_attention_norm": transformer_config.get(
+            "use_post_attention_norm", True
         ),
-
-
+        # Handle soft-capping parameters (may be null)
+        "attention_logit_soft_cap": transformer_config.get(
+            "attn_logit_softcapping", None
         ),
-        "
-
-        "sliding_window_size": transformer_config["sliding_window"],
-        "local_rope_scaling_factor": 1.0,
-        "global_rope_scaling_factor": (
-            rope_global_config.get("factor", 1.0) if rope_global_config else 1.0
+        "final_logit_soft_cap": transformer_config.get(
+            "final_logit_softcapping", None
         ),
+        # Use sliding window attention if configured
+        "use_sliding_window_attention": use_sliding_window_attention,
+        # Normalize query by head_dim if query_pre_attn_scalar == head_dim
+        "query_head_dim_normalize": query_head_dim_normalize,
+        # Sliding window size (default to 1024 for full attention layers)
+        "sliding_window_size": sliding_window or 4096,
+        # Rope scaling factors for local (sliding) and global (full) attention
+        "local_rope_scaling_factor": rope_local_config.get("factor", 1.0),
+        "global_rope_scaling_factor": rope_global_config.get("factor", 1.0),
         "layer_norm_epsilon": transformer_config.get("rms_norm_eps", 1e-6),
         "use_bidirectional_attention": transformer_config.get(
             "use_bidirectional_attention", False
         ),
+        # Gemma3 uses query/key normalization by default
+        "use_query_key_norm": transformer_config.get(
+            "use_query_key_norm", True
+        ),
         "vision_encoder": vision_encoder,
     }

@@ -97,7 +131,7 @@ def convert_weights(backbone, loader, transformers_config):
     if transformers_config["model_type"] == "gemma3_text":
         prefix = "model"
     else:
-        prefix =
+        prefix = _resolve_multimodal_prefix(loader)

     loader.port_weight(
         keras_variable=backbone.get_layer("token_embedding").embeddings,
@@ -336,6 +370,18 @@ def convert_weights(backbone, loader, transformers_config):
     return backbone


+def _resolve_multimodal_prefix(loader):
+    candidates = ["model.language_model", "language_model.model"]
+    for candidate in candidates:
+        key = f"{candidate}.embed_tokens.weight"
+        try:
+            loader.get_tensor(key)
+            return candidate
+        except Exception:
+            continue
+    return candidates[0]
+
+
 def convert_tokenizer(cls, preset, **kwargs):
     proto = get_file(preset, "tokenizer.model")
     sp = SentencePieceProcessor()
```
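
To make the new mapping concrete, here is a toy walk-through that mirrors the rope / sliding-window logic added above, run on an illustrative `text_config` (the values are made up, not taken from any real checkpoint):

```python
text_config = {
    "head_dim": 256,
    "query_pre_attn_scalar": 256,
    "sliding_window": 512,
    "layer_types": ["sliding_attention", "full_attention"],
    "rope_scaling": None,
    "rope_parameters": {
        "full_attention": {"factor": 8.0},
        "sliding_attention": {"factor": 1.0},
    },
}

# Global rope config: prefer HF `rope_scaling`, else per-layer-type params.
rope_scaling = text_config.get("rope_scaling", None)
rope_params = text_config.get("rope_parameters") or {}
if rope_scaling is not None:
    rope_global_config = rope_scaling or {}
else:
    rope_global_config = rope_params.get("full_attention", {})
rope_local_config = rope_params.get("sliding_attention", {})

# Sliding-window attention is enabled by a window size or by layer_types.
sliding_window = text_config.get("sliding_window", None)
layer_types = text_config.get("layer_types", [])
use_sliding_window_attention = sliding_window not in (None, 0) or any(
    lt == "sliding_attention" for lt in layer_types
)

# Normalize queries by head_dim only if query_pre_attn_scalar == head_dim.
query_pre_attn_scalar = text_config.get("query_pre_attn_scalar", None)
head_dim = text_config.get("head_dim")
if query_pre_attn_scalar is not None and head_dim is not None:
    query_head_dim_normalize = query_pre_attn_scalar == head_dim
else:
    query_head_dim_normalize = True

print(use_sliding_window_attention)           # True
print(query_head_dim_normalize)               # True (256 == 256)
print(rope_global_config.get("factor", 1.0))  # 8.0
print(rope_local_config.get("factor", 1.0))   # 1.0
print(sliding_window or 4096)                 # 512
```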

keras_hub/src/utils/transformers/convert_qwen3_moe.py:

```diff
@@ -198,7 +198,10 @@ def convert_tokenizer(cls, preset, **kwargs):
     tokenizer_config = load_json(preset, "tokenizer.json")
     vocab = tokenizer_config["model"]["vocab"]
     merges = tokenizer_config["model"]["merges"]
-
+    # Check if merges are already strings or lists
+    # If they are lists, join them into strings.
+    if merges and isinstance(merges[0], list):
+        merges = [" ".join(item) for item in merges]

     # Load all special tokens with the exception of "reserved" ones.
     special_tokens = set()
```