keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as published to a supported public registry. It is provided for informational purposes only.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0

keras_hub/src/tokenizers/word_piece_tokenizer.py
@@ -3,6 +3,7 @@ import re
 from typing import Iterable
 
 import keras
+from keras.src.saving import serialization_lib
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.tokenizers import tokenizer
@@ -13,9 +14,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 try:
     import tensorflow as tf
-    import tensorflow_text as tf_text
 except ImportError:
     tf = None
+try:
+    import tensorflow_text as tf_text
+except ImportError:
     tf_text = None
 
 VOCAB_FILENAME = "vocabulary.txt"
@@ -374,6 +377,17 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
             return
 
         if isinstance(vocabulary, str):
+            if serialization_lib.in_safe_mode():
+                raise ValueError(
+                    "Requested the loading of a vocabulary file outside of the "
+                    "model archive. This carries a potential risk of loading "
+                    "arbitrary and sensitive files and thus it is disallowed "
+                    "by default. If you trust the source of the artifact, you "
+                    "can override this error by passing `safe_mode=False` to "
+                    "the loading function, or calling "
+                    "`keras.config.enable_unsafe_deserialization()`. "
+                    f"Vocabulary file: '{vocabulary}'"
+                )
             with open(vocabulary, "r", encoding="utf-8") as file:
                 self.vocabulary = [line.rstrip() for line in file]
         elif isinstance(vocabulary, Iterable):
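
Note: the new check ties vocabulary-file loading to Keras safe mode. A minimal sketch of the user-facing behavior follows; it assumes safe mode is only active inside Keras's deserialization scope, and the file path is hypothetical. The two opt-outs are the standard Keras ones named in the error message.

    import keras
    import keras_hub

    # Direct construction from a local file still works; the safe-mode check
    # applies when the tokenizer is rebuilt from a saved config during loading.
    tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
        vocabulary="vocabulary.txt"  # hypothetical local path
    )

    # If a trusted artifact references an external vocabulary file, either
    # opt-out named in the error message lifts the restriction:
    keras.config.enable_unsafe_deserialization()
    # ...or pass `safe_mode=False` to the loading call, e.g.:
    # model = keras.models.load_model("model.keras", safe_mode=False)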

keras_hub/src/utils/preset_utils.py
@@ -285,7 +285,7 @@ def tf_copy_gfile_to_cache(preset, path):
            # Work around this bug.
            os.remove(local_path)
            if isinstance(
-                e, tf.errors.PermissionDeniedError, tf.errors.NotFoundError
+                e, (tf.errors.PermissionDeniedError, tf.errors.NotFoundError)
            ):
                raise FileNotFoundError(
                    f"`{path}` doesn't exist in preset directory `{preset}`.",
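
Note: this is a genuine bug fix, not a style change. `isinstance` takes exactly two arguments, so the old three-argument call raised `TypeError` instead of classifying the exception; multiple types must be passed as a tuple. A self-contained illustration (plain exception classes stand in for the `tf.errors` types):

    class PermissionDenied(Exception):
        pass

    class NotFound(Exception):
        pass

    e = PermissionDenied()

    try:
        isinstance(e, PermissionDenied, NotFound)  # old form
    except TypeError as err:
        print(err)  # isinstance expected 2 arguments, got 3

    print(isinstance(e, (PermissionDenied, NotFound)))  # fixed form -> True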

keras_hub/src/utils/tensor_utils.py
@@ -231,6 +231,7 @@ def tensor_to_list(inputs):
     Args:
         inputs: Input tensor, or dict/list/tuple of input tensors.
     """
+    assert_tf_installed("tensor_to_list")
     if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
         inputs = tf.convert_to_tensor(inputs)
     if isinstance(inputs, tf.RaggedTensor):
@@ -246,6 +247,7 @@ def tensor_to_list(inputs):
 
 def convert_to_ragged_batch(inputs):
     """Ensure a tf.Tensor is a ragged rank 2 tensor."""
+    assert_tf_installed("convert_to_ragged_batch")
     if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
         inputs = tf.convert_to_tensor(inputs)
     unbatched = inputs.shape.rank == 1
@@ -259,6 +261,7 @@ def convert_to_ragged_batch(inputs):
 
 def truncate_at_token(inputs, token, mask):
     """Truncate at first instance of `token`, ignoring `mask`."""
+    assert_tf_installed("truncate_at_token")
     matches = (inputs == token) & (~mask)
     end_indices = tf.cast(tf.math.argmax(matches, -1), "int32")
     end_indices = tf.where(end_indices == 0, tf.shape(inputs)[-1], end_indices)
@@ -267,12 +270,21 @@ def truncate_at_token(inputs, token, mask):
 
 def strip_to_ragged(token_ids, mask, ids_to_strip):
     """Remove masked and special tokens from a sequence before detokenizing."""
+    assert_tf_installed("strip_to_ragged")
     mask = tf.cast(mask, "bool")
     for id in ids_to_strip:
         mask = mask & (token_ids != id)
     return tf.ragged.boolean_mask(token_ids, mask)
 
 
+def assert_tf_installed(symbol_name):
+    if tf is None:
+        raise ImportError(
+            f"{symbol_name} requires `tensorflow`. "
+            "Run `pip install tensorflow` to install it."
+        )
+
+
 def assert_tf_libs_installed(symbol_name):
     if tf_text is None or tf is None:
         raise ImportError(
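
Note: the new `assert_tf_installed` guard only requires `tensorflow`, unlike the existing `assert_tf_libs_installed`, which also requires `tensorflow_text`. Together with the split imports in `word_piece_tokenizer.py`, this lets the TF-only helpers fail with an actionable `ImportError` instead of an `AttributeError` on the `tf = None` fallback. A standalone sketch of the pattern (names mirror the diff; the simulated environment is hypothetical):

    tf = None  # simulate an environment without TensorFlow installed


    def assert_tf_installed(symbol_name):
        if tf is None:
            raise ImportError(
                f"{symbol_name} requires `tensorflow`. "
                "Run `pip install tensorflow` to install it."
            )


    def tensor_to_list(inputs):
        assert_tf_installed("tensor_to_list")
        ...  # TF-dependent body, never reached without TensorFlow


    try:
        tensor_to_list([1, 2, 3])
    except ImportError as err:
        print(err)  # tensor_to_list requires `tensorflow`. Run `pip install ...`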

keras_hub/src/utils/transformers/convert_gemma3.py
@@ -37,6 +37,7 @@ def convert_backbone_config(transformers_config):
     else:
         vision_config = transformers_config["vision_config"]
         image_size = vision_config["image_size"]
+        transformer_config = transformers_config["text_config"]
         vision_encoder_config = {
             "image_size": image_size,
             "patch_size": vision_config["patch_size"],
@@ -44,21 +45,44 @@ def convert_backbone_config(transformers_config):
             "hidden_dim": vision_config["hidden_size"],
             "num_layers": vision_config["num_hidden_layers"],
             "intermediate_dim": vision_config["intermediate_size"],
-            "output_dim": 2560,
+            "output_dim": transformer_config["hidden_size"],
             "pool_size": 4,
             "layer_norm_epsilon": vision_config.get("layer_norm_eps", 1e-6),
         }
         vision_encoder = Gemma3VisionEncoder(**vision_encoder_config)
-        transformer_config = transformers_config["text_config"]
 
-    if "rope_parameters" in transformer_config:
-        rope_global_config = transformer_config.get("rope_parameters", {}).get(
-            "full_attention"
-        )
-    elif "rope_scaling" in transformer_config:
-        rope_global_config = transformer_config["rope_scaling"]
+    # Extract rope parameters. HuggingFace uses `rope_scaling` for the
+    # global rotary embedding. `rope_parameters` is optional and not used
+    # by HF for global scaling when `rope_scaling` is None.
+    rope_scaling = transformer_config.get("rope_scaling", None)
+    rope_params = transformer_config.get("rope_parameters") or {}
+
+    if rope_scaling is not None:
+        rope_global_config = rope_scaling or {}
     else:
-        rope_global_config = {}
+        rope_global_config = rope_params.get("full_attention", {})
+
+    rope_local_config = rope_params.get("sliding_attention", {})
+
+    # Determine sliding window attention usage from layer_types or config
+    sliding_window = transformer_config.get("sliding_window", None)
+    layer_types = transformer_config.get("layer_types", [])
+
+    use_sliding_window_attention = sliding_window not in (None, 0) or any(
+        lt == "sliding_attention" for lt in layer_types
+    )
+
+    # Determine query_head_dim_normalize
+    # If query_pre_attn_scalar equals head_dim, then normalize by head_dim
+    query_pre_attn_scalar = transformer_config.get(
+        "query_pre_attn_scalar", None
+    )
+    head_dim = transformer_config.get("head_dim")
+    if query_pre_attn_scalar is not None and head_dim is not None:
+        query_head_dim_normalize = query_pre_attn_scalar == head_dim
+    else:
+        query_head_dim_normalize = True
+
     return {
         "vocabulary_size": transformer_config.get(
             "vocab_size", 262144 if vision_encoder is None else 262208
@@ -70,25 +94,35 @@ def convert_backbone_config(transformers_config):
         "hidden_dim": transformer_config["hidden_size"],
         "intermediate_dim": transformer_config["intermediate_size"],
         "head_dim": transformer_config["head_dim"],
-        "use_post_ffw_norm": True,
-        "use_post_attention_norm": True,
-        "attention_logit_softcap": transformer_config.get(
-            "attn_logit_softcap", None
+        # Gemma3 models use post-norm and post-attention norm by default
+        "use_post_ffw_norm": transformer_config.get("use_post_ffw_norm", True),
+        "use_post_attention_norm": transformer_config.get(
+            "use_post_attention_norm", True
         ),
-        "final_logit_softcap": transformer_config.get(
-            "final_logit_softcap", None
+        # Handle soft-capping parameters (may be null)
+        "attention_logit_soft_cap": transformer_config.get(
+            "attn_logit_softcapping", None
         ),
-        "use_sliding_window_attention": True,
-        "query_head_dim_normalize": True,
-        "sliding_window_size": transformer_config["sliding_window"],
-        "local_rope_scaling_factor": 1.0,
-        "global_rope_scaling_factor": (
-            rope_global_config.get("factor", 1.0) if rope_global_config else 1.0
+        "final_logit_soft_cap": transformer_config.get(
+            "final_logit_softcapping", None
         ),
+        # Use sliding window attention if configured
+        "use_sliding_window_attention": use_sliding_window_attention,
+        # Normalize query by head_dim if query_pre_attn_scalar == head_dim
+        "query_head_dim_normalize": query_head_dim_normalize,
+        # Sliding window size (default to 1024 for full attention layers)
+        "sliding_window_size": sliding_window or 4096,
+        # Rope scaling factors for local (sliding) and global (full) attention
+        "local_rope_scaling_factor": rope_local_config.get("factor", 1.0),
+        "global_rope_scaling_factor": rope_global_config.get("factor", 1.0),
         "layer_norm_epsilon": transformer_config.get("rms_norm_eps", 1e-6),
         "use_bidirectional_attention": transformer_config.get(
             "use_bidirectional_attention", False
         ),
+        # Gemma3 uses query/key normalization by default
+        "use_query_key_norm": transformer_config.get(
+            "use_query_key_norm", True
+        ),
         "vision_encoder": vision_encoder,
     }
 
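
Note: the net effect of the hunks above is that the converter now derives these backbone arguments from the checkpoint's `text_config` instead of hard-coding values that only matched some checkpoints (e.g. `output_dim=2560`). A worked example of the derivation, using an illustrative (made-up) config fragment and the same logic shown above:

    # Illustrative HF text_config fragment (values are made up).
    transformer_config = {
        "sliding_window": 512,
        "layer_types": ["sliding_attention", "full_attention"],
        "rope_scaling": None,
        "rope_parameters": {
            "full_attention": {"factor": 8.0},
            "sliding_attention": {"factor": 1.0},
        },
        "query_pre_attn_scalar": 256,
        "head_dim": 256,
    }

    rope_scaling = transformer_config.get("rope_scaling", None)
    rope_params = transformer_config.get("rope_parameters") or {}
    if rope_scaling is not None:
        rope_global_config = rope_scaling or {}
    else:
        rope_global_config = rope_params.get("full_attention", {})
    rope_local_config = rope_params.get("sliding_attention", {})

    sliding_window = transformer_config.get("sliding_window", None)
    layer_types = transformer_config.get("layer_types", [])
    use_swa = sliding_window not in (None, 0) or any(
        lt == "sliding_attention" for lt in layer_types
    )

    q_scalar = transformer_config.get("query_pre_attn_scalar", None)
    head_dim = transformer_config.get("head_dim")
    if q_scalar is not None and head_dim is not None:
        query_head_dim_normalize = q_scalar == head_dim
    else:
        query_head_dim_normalize = True

    print(use_swa)                               # True
    print(query_head_dim_normalize)              # True (256 == 256)
    print(rope_global_config.get("factor", 1.0)) # 8.0
    print(rope_local_config.get("factor", 1.0))  # 1.0
    print(sliding_window or 4096)                # sliding_window_size -> 512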
 
@@ -97,7 +131,7 @@ def convert_weights(backbone, loader, transformers_config):
     if transformers_config["model_type"] == "gemma3_text":
         prefix = "model"
     else:
-        prefix = "language_model.model"
+        prefix = _resolve_multimodal_prefix(loader)
 
     loader.port_weight(
         keras_variable=backbone.get_layer("token_embedding").embeddings,
@@ -336,6 +370,18 @@
     return backbone
 
 
+def _resolve_multimodal_prefix(loader):
+    candidates = ["model.language_model", "language_model.model"]
+    for candidate in candidates:
+        key = f"{candidate}.embed_tokens.weight"
+        try:
+            loader.get_tensor(key)
+            return candidate
+        except Exception:
+            continue
+    return candidates[0]
+
+
 def convert_tokenizer(cls, preset, **kwargs):
     proto = get_file(preset, "tokenizer.model")
     sp = SentencePieceProcessor()
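
Note: `_resolve_multimodal_prefix` replaces the hard-coded `"language_model.model"` prefix for multimodal Gemma3 checkpoints. It probes for the token-embedding tensor under each candidate prefix, presumably because the nesting of the language model differs between checkpoint layouts. A minimal sketch of the probing logic with a stand-in loader (`FakeLoader` is hypothetical; `get_tensor` mirrors the call used above):

    class FakeLoader:
        """Stand-in for the checkpoint loader used by the converter."""

        def __init__(self, keys):
            self.keys = keys

        def get_tensor(self, key):
            if key not in self.keys:
                raise KeyError(key)
            return object()  # placeholder for a real tensor


    def _resolve_multimodal_prefix(loader):
        candidates = ["model.language_model", "language_model.model"]
        for candidate in candidates:
            key = f"{candidate}.embed_tokens.weight"
            try:
                loader.get_tensor(key)
                return candidate
            except Exception:
                continue
        return candidates[0]


    new_layout = FakeLoader({"model.language_model.embed_tokens.weight"})
    old_layout = FakeLoader({"language_model.model.embed_tokens.weight"})
    print(_resolve_multimodal_prefix(new_layout))  # model.language_model
    print(_resolve_multimodal_prefix(old_layout))  # language_model.model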

keras_hub/src/utils/transformers/convert_qwen3_moe.py
@@ -198,7 +198,10 @@ def convert_tokenizer(cls, preset, **kwargs):
     tokenizer_config = load_json(preset, "tokenizer.json")
     vocab = tokenizer_config["model"]["vocab"]
     merges = tokenizer_config["model"]["merges"]
-    merges = [" ".join(item) for item in merges]
+    # Check if merges are already strings or lists
+    # If they are lists, join them into strings.
+    if merges and isinstance(merges[0], list):
+        merges = [" ".join(item) for item in merges]
 
     # Load all special tokens with the exception of "reserved" ones.
     special_tokens = set()
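
Note: the guard handles the two serializations of BPE merges found in `tokenizer.json` files: each merge as a single space-joined string, or (in newer tokenizer dumps) as a two-element list. A small illustration (the sample merges are made up):

    def normalize_merges(merges):
        # Same normalization as the diff: only join when entries are lists.
        if merges and isinstance(merges[0], list):
            merges = [" ".join(item) for item in merges]
        return merges

    print(normalize_merges(["Ġ t", "h e"]))            # ['Ġ t', 'h e']
    print(normalize_merges([["Ġ", "t"], ["h", "e"]]))  # ['Ġ t', 'h e']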