keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/roi_align.py (new file)
@@ -0,0 +1,222 @@
+ from keras import backend
+ from keras import ops
+
+
+ def _bilinear_interpolate(
+     feature_maps, roi_batch_ind, y, x, ymask, xmask, height, width, hidden_dim
+ ):
+     feature_maps_dtype = backend.standardize_dtype(feature_maps.dtype)
+     y = ops.maximum(y, 0.0)
+     x = ops.maximum(x, 0.0)
+     y_low = ops.cast(y, "int32")
+     x_low = ops.cast(x, "int32")
+     y_high = ops.where(
+         ops.greater_equal(y_low, height - 1), height - 1, y_low + 1
+     )
+     y_low = ops.where(ops.greater_equal(y_low, height - 1), height - 1, y_low)
+     y = ops.where(
+         ops.greater_equal(y_low, height - 1),
+         ops.cast(y, dtype=feature_maps_dtype),
+         y,
+     )
+
+     x_high = ops.where(
+         ops.greater_equal(x_low, width - 1), width - 1, x_low + 1
+     )
+     x_low = ops.where(ops.greater_equal(x_low, width - 1), width - 1, x_low)
+     x = ops.where(
+         ops.greater_equal(x_low, width - 1),
+         ops.cast(x, dtype=feature_maps_dtype),
+         x,
+     )
+
+     ly = ops.subtract(y, y_low)
+     lx = ops.subtract(x, x_low)
+     hy = ops.subtract(1.0, ly)
+     hx = ops.subtract(1.0, lx)
+
+     def masked_index(y, x):
+         y = ops.where(ymask[:, None, :], y, 0)
+         x = ops.where(xmask[:, None, :], x, 0)
+         batch_idx = roi_batch_ind[:, None, None, None, None, None]
+         channel_idx = ops.arange(hidden_dim)[None, None, None, None, None, :]
+         y_idx = y[:, :, None, :, None, None]
+         x_idx = x[:, None, :, None, :, None]
+
+         if backend.backend() == "tensorflow":
+             import tensorflow as tf
+
+             # Explicitly broadcast indices to the same shape for XLA
+             # compatibility
+             common_zero = ops.zeros_like(
+                 batch_idx + y_idx + x_idx + channel_idx
+             )
+             batch_idx = batch_idx + common_zero
+             y_idx = y_idx + common_zero
+             x_idx = ops.transpose(
+                 ops.transpose(x_idx, (0, 2, 1, 4, 3, 5)) + common_zero,
+                 (0, 2, 1, 4, 3, 5),
+             )
+             channel_idx = channel_idx + common_zero
+             indices = ops.stack([batch_idx, y_idx, x_idx, channel_idx], axis=-1)
+             indices = ops.cast(indices, "int32")
+             return tf.gather_nd(feature_maps, indices)
+         else:
+             return feature_maps[
+                 batch_idx,
+                 y_idx,
+                 x_idx,
+                 channel_idx,
+             ]
+
+     v1 = masked_index(y_low, x_low)
+     v2 = masked_index(y_low, x_high)
+     v3 = masked_index(y_high, x_low)
+     v4 = masked_index(y_high, x_high)
+
+     def outer_prod(y, x):
+         return ops.multiply(
+             y[:, :, None, :, None, None], x[:, None, :, None, :, None]
+         )
+
+     w1 = outer_prod(hy, hx)
+     w2 = outer_prod(hy, lx)
+     w3 = outer_prod(ly, hx)
+     w4 = outer_prod(ly, lx)
+
+     val = ops.add(
+         ops.add(ops.multiply(w1, v1), ops.multiply(w2, v2)),
+         ops.add(ops.multiply(w3, v3), ops.multiply(w4, v4)),
+     )
+     return val
+
+
+ def roi_align_torch(
+     feature_maps,
+     rois,
+     output_size,
+     spatial_scale=1.0,
+     aligned=False,
+ ):
+     import torchvision
+
+     dtype = backend.standardize_dtype(feature_maps.dtype)
+     need_cast = False
+     if dtype == "bfloat16":
+         # torchvision.ops.roi_align does not support bfloat16.
+         feature_maps = ops.cast(feature_maps, "float32")
+         rois = ops.cast(rois, "float32")
+         need_cast = True
+
+     output = ops.transpose(
+         torchvision.ops.roi_align(
+             ops.transpose(feature_maps, (0, 3, 1, 2)),
+             rois,
+             output_size,
+             spatial_scale=spatial_scale,
+             aligned=aligned,
+         ),
+         (0, 2, 3, 1),
+     )
+     if need_cast:
+         output = ops.cast(output, dtype)
+     return output
+
+
+ def roi_align(
+     feature_maps,
+     rois,
+     output_size,
+     height,
+     width,
+     hidden_dim,
+     spatial_scale=1.0,
+     aligned=False,
+ ):
+     # Use torchvision's optimized roi_align implementation.
+     if backend.backend() == "torch":
+         return roi_align_torch(
+             feature_maps,
+             rois,
+             output_size,
+             spatial_scale=spatial_scale,
+             aligned=aligned,
+         )
+
+     original_dtype = backend.standardize_dtype(feature_maps.dtype)
+     out_h, out_w = output_size[0], output_size[1]
+
+     feature_maps = ops.cast(feature_maps, "float32")
+     rois = ops.cast(rois, "float32")
+
+     ph = ops.arange(out_h, dtype="float32")
+     pw = ops.arange(out_w, dtype="float32")
+
+     # input: [N, C, H, W]
+     # rois: [K, 5]
+
+     roi_batch_ind = ops.cast(rois[:, 0], "int32")
+     offset = 0.5 if aligned else 0.0
+     roi_start_w = ops.subtract(ops.multiply(rois[:, 1], spatial_scale), offset)
+     roi_start_h = ops.subtract(ops.multiply(rois[:, 2], spatial_scale), offset)
+     roi_end_w = ops.subtract(ops.multiply(rois[:, 3], spatial_scale), offset)
+     roi_end_h = ops.subtract(ops.multiply(rois[:, 4], spatial_scale), offset)
+
+     roi_width = ops.subtract(roi_end_w, roi_start_w)
+     roi_height = ops.subtract(roi_end_h, roi_start_h)
+     if not aligned:
+         roi_width = ops.maximum(roi_width, 1.0)
+         roi_height = ops.maximum(roi_height, 1.0)
+
+     bin_size_h = ops.divide(roi_height, out_h)
+     bin_size_w = ops.divide(roi_width, out_w)
+
+     roi_bin_grid_h = ops.ceil(ops.divide(roi_height, out_h))
+     roi_bin_grid_w = ops.ceil(ops.divide(roi_width, out_w))
+
+     count = ops.maximum(ops.multiply(roi_bin_grid_h, roi_bin_grid_w), 1.0)
+     iy = ops.arange(height, dtype="float32")
+     ix = ops.arange(width, dtype="float32")
+     ymask = ops.less(iy[None, :], roi_bin_grid_h[:, None])
+     xmask = ops.less(ix[None, :], roi_bin_grid_w[:, None])
+
+     def from_k(t):
+         return t[:, None, None]
+
+     y = ops.add(
+         ops.add(
+             from_k(roi_start_h),
+             ops.multiply(ph[None, :, None], from_k(bin_size_h)),
+         ),
+         ops.multiply(
+             ops.cast(ops.add(iy[None, None, :], 0.5), dtype="float32"),
+             from_k(ops.divide(bin_size_h, roi_bin_grid_h)),
+         ),
+     )
+     x = ops.add(
+         ops.add(
+             from_k(roi_start_w),
+             ops.multiply(pw[None, :, None], from_k(bin_size_w)),
+         ),
+         ops.multiply(
+             ops.cast(ops.add(ix[None, None, :], 0.5), dtype="float32"),
+             from_k(ops.divide(bin_size_w, roi_bin_grid_w)),
+         ),
+     )
+     val = _bilinear_interpolate(
+         feature_maps,
+         roi_batch_ind,
+         y,
+         x,
+         ymask,
+         xmask,
+         height,
+         width,
+         hidden_dim,
+     )
+     val = ops.where(ymask[:, None, None, :, None, None], val, 0.0)
+     val = ops.where(xmask[:, None, None, None, :, None], val, 0.0)
+
+     output = ops.sum(val, axis=(3, 4))
+     output = ops.divide(output, count[:, None, None, None])
+     return ops.cast(output, original_dtype)
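
For orientation, below is a minimal usage sketch of the new roi_align helper. The NHWC feature-map layout, the [batch_index, x1, y1, x2, y2] ROI row format, and the expected output shape are inferred from the code in the diff above rather than from documented API; the shapes and values are illustrative only.

    import numpy as np

    from keras_hub.src.models.sam3.roi_align import roi_align

    batch, height, width, channels = 2, 32, 32, 64
    # Feature maps are channels-last: [N, H, W, C].
    feature_maps = np.random.rand(batch, height, width, channels).astype("float32")

    # Each ROI row is [batch_index, x1, y1, x2, y2] in feature-map
    # coordinates (spatial_scale=1.0), following the torchvision
    # roi_align convention used by the torch fast path.
    rois = np.array(
        [
            [0.0, 4.0, 4.0, 20.0, 20.0],
            [1.0, 0.0, 0.0, 16.0, 8.0],
        ],
        dtype="float32",
    )

    pooled = roi_align(
        feature_maps,
        rois,
        output_size=(7, 7),
        height=height,
        width=width,
        hidden_dim=channels,
        aligned=True,
    )
    # Expected shape: [num_rois, out_h, out_w, channels] -> (2, 7, 7, 64).
    print(pooled.shape)

On the torch backend this dispatches to torchvision.ops.roi_align; on other backends it uses the generic bilinear-interpolation path shown in the diff.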