keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
@@ -138,6 +138,27 @@ from keras_hub.src.models.sam.sam_mask_decoder import (
138
138
  from keras_hub.src.models.sam.sam_prompt_encoder import (
139
139
  SAMPromptEncoder as SAMPromptEncoder,
140
140
  )
141
+ from keras_hub.src.models.sam3.sam3_detr_decoder import (
142
+ SAM3DetrDecoder as SAM3DetrDecoder,
143
+ )
144
+ from keras_hub.src.models.sam3.sam3_detr_encoder import (
145
+ SAM3DetrEncoder as SAM3DetrEncoder,
146
+ )
147
+ from keras_hub.src.models.sam3.sam3_geometry_encoder import (
148
+ SAM3GeometryEncoder as SAM3GeometryEncoder,
149
+ )
150
+ from keras_hub.src.models.sam3.sam3_image_converter import (
151
+ SAM3ImageConverter as SAM3ImageConverter,
152
+ )
153
+ from keras_hub.src.models.sam3.sam3_mask_decoder import (
154
+ SAM3MaskDecoder as SAM3MaskDecoder,
155
+ )
156
+ from keras_hub.src.models.sam3.sam3_text_encoder import (
157
+ SAM3TextEncoder as SAM3TextEncoder,
158
+ )
159
+ from keras_hub.src.models.sam3.sam3_vision_encoder import (
160
+ SAM3VisionEncoder as SAM3VisionEncoder,
161
+ )
141
162
  from keras_hub.src.models.segformer.segformer_image_converter import (
142
163
  SegFormerImageConverter as SegFormerImageConverter,
143
164
  )
@@ -211,6 +211,12 @@ from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor i
211
211
  from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
212
212
  DistilBertTokenizer as DistilBertTokenizer,
213
213
  )
214
+ from keras_hub.src.models.edrec.edrec_backbone import (
215
+ EdRecBackbone as EdRecBackbone,
216
+ )
217
+ from keras_hub.src.models.edrec.edrec_seq2seq_lm import (
218
+ EdRecSeq2SeqLM as EdRecSeq2SeqLM,
219
+ )
214
220
  from keras_hub.src.models.efficientnet.efficientnet_backbone import (
215
221
  EfficientNetBackbone as EfficientNetBackbone,
216
222
  )
@@ -629,6 +635,15 @@ from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor i
629
635
  from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
630
636
  RoformerV2Tokenizer as RoformerV2Tokenizer,
631
637
  )
638
+ from keras_hub.src.models.rwkv7.rwkv7_backbone import (
639
+ RWKV7Backbone as RWKV7Backbone,
640
+ )
641
+ from keras_hub.src.models.rwkv7.rwkv7_causal_lm import (
642
+ RWKV7CausalLM as RWKV7CausalLM,
643
+ )
644
+ from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
645
+ RWKV7CausalLMPreprocessor as RWKV7CausalLMPreprocessor,
646
+ )
632
647
  from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone
633
648
  from keras_hub.src.models.sam.sam_image_segmenter import (
634
649
  SAMImageSegmenter as SAMImageSegmenter,
@@ -636,6 +651,18 @@ from keras_hub.src.models.sam.sam_image_segmenter import (
636
651
  from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
637
652
  SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor,
638
653
  )
654
+ from keras_hub.src.models.sam3.sam3_pc_backbone import (
655
+ SAM3PromptableConceptBackbone as SAM3PromptableConceptBackbone,
656
+ )
657
+ from keras_hub.src.models.sam3.sam3_pc_image_segmenter import (
658
+ SAM3PromptableConceptImageSegmenter as SAM3PromptableConceptImageSegmenter,
659
+ )
660
+ from keras_hub.src.models.sam3.sam3_pc_image_segmenter_preprocessor import (
661
+ SAM3PromptableConceptImageSegmenterPreprocessor as SAM3PromptableConceptImageSegmenterPreprocessor,
662
+ )
663
+ from keras_hub.src.models.sam3.sam3_tokenizer import (
664
+ SAM3Tokenizer as SAM3Tokenizer,
665
+ )
639
666
  from keras_hub.src.models.segformer.segformer_backbone import (
640
667
  SegFormerBackbone as SegFormerBackbone,
641
668
  )
@@ -290,16 +290,19 @@ def non_max_suppression(
290
290
  "int32",
291
291
  )
292
292
  idx = ops.minimum(idx, num_boxes - 1)
293
+ idx = ops.cast(idx, "int32")
293
294
 
294
295
  index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32")
295
296
  take_along_axis_idx = ops.reshape(
296
297
  idx + ops.expand_dims(index_offsets, 1), [-1]
297
298
  )
299
+ take_along_axis_idx = ops.cast(take_along_axis_idx, "int32")
298
300
 
299
301
  if keras.backend.backend() != "tensorflow":
300
- idx = ops.take_along_axis(
301
- ops.reshape(sorted_indices, [-1]), take_along_axis_idx
302
+ sorted_indices_int = ops.cast(
303
+ ops.reshape(sorted_indices, [-1]), "int32"
302
304
  )
305
+ idx = ops.take_along_axis(sorted_indices_int, take_along_axis_idx)
303
306
  else:
304
307
  import tensorflow as tf
305
308
 
@@ -1,281 +1,8 @@
1
- import inspect
2
-
3
1
  import keras
4
- from keras import ops
5
2
 
6
3
  from keras_hub.src.api_export import keras_hub_export
7
4
 
8
5
 
9
6
  @keras_hub_export("keras_hub.layers.ReversibleEmbedding")
10
- class ReversibleEmbedding(keras.layers.Embedding):
11
- """An embedding layer which can project backwards to the input dim.
12
-
13
- This layer is an extension of `keras.layers.Embedding` for language models.
14
- This layer can be called "in reverse" with `reverse=True`, in which case the
15
- layer will linearly project from `output_dim` back to `input_dim`.
16
-
17
- By default, the reverse projection will use the transpose of the
18
- `embeddings` weights to project to `input_dim` (weights are "tied"). If
19
- `tie_weights=False`, the model will use a separate, trainable variable for
20
- reverse projection.
21
-
22
- This layer has no bias terms.
23
-
24
- Args:
25
- input_dim: Integer. Size of the vocabulary,
26
- i.e. maximum integer index + 1.
27
- output_dim: Integer. Dimension of the dense embedding.
28
- tie_weights: Boolean, whether or not the matrix for embedding and
29
- the matrix for the `reverse` projection should share the same
30
- weights.
31
- embeddings_initializer: Initializer for the `embeddings`
32
- matrix (see `keras.initializers`).
33
- embeddings_regularizer: Regularizer function applied to
34
- the `embeddings` matrix (see `keras.regularizers`).
35
- embeddings_constraint: Constraint function applied to
36
- the `embeddings` matrix (see `keras.constraints`).
37
- mask_zero: Boolean, whether or not the input value 0 is a special
38
- "padding" value that should be masked out.
39
- reverse_dtype: The dtype for the reverse projection computation.
40
- Defaults to the `compute_dtype` of the layer.
41
- logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
42
- output logits will be scaled by
43
- `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
44
- range of output logits and can improve training.
45
- **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
46
- including `name`, `trainable`, `dtype` etc.
47
-
48
- Call arguments:
49
- inputs: The tensor inputs to the layer.
50
- reverse: Boolean. If `True` the layer will perform a linear projection
51
- from `output_dim` to `input_dim`, instead of a normal embedding
52
- call. Default to `False`.
53
-
54
- Example:
55
- ```python
56
- batch_size = 16
57
- vocab_size = 100
58
- hidden_dim = 32
59
- seq_length = 50
60
-
61
- # Generate random inputs.
62
- token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
63
-
64
- embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
65
- # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
66
- hidden_states = embedding(token_ids)
67
- # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
68
- logits = embedding(hidden_states, reverse=True)
69
- ```
70
-
71
- References:
72
- - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
73
- - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
74
- """
75
-
76
- def __init__(
77
- self,
78
- input_dim,
79
- output_dim,
80
- tie_weights=True,
81
- embeddings_initializer="uniform",
82
- embeddings_regularizer=None,
83
- embeddings_constraint=None,
84
- mask_zero=False,
85
- reverse_dtype=None,
86
- logit_soft_cap=None,
87
- **kwargs,
88
- ):
89
- super().__init__(
90
- input_dim,
91
- output_dim,
92
- embeddings_initializer=embeddings_initializer,
93
- embeddings_regularizer=embeddings_regularizer,
94
- embeddings_constraint=embeddings_constraint,
95
- mask_zero=mask_zero,
96
- **kwargs,
97
- )
98
- self.tie_weights = tie_weights
99
- self.reverse_dtype = reverse_dtype
100
- self.logit_soft_cap = logit_soft_cap
101
-
102
- def build(self, inputs_shape=None):
103
- super().build(inputs_shape)
104
- if (
105
- not self.tie_weights
106
- and getattr(self, "quantization_mode", None) != "int8"
107
- ):
108
- self.reverse_embeddings = self.add_weight(
109
- name="reverse_embeddings",
110
- shape=(self.output_dim, self.input_dim),
111
- initializer=self.embeddings_initializer,
112
- dtype=self.dtype,
113
- )
114
-
115
- def call(self, inputs, reverse=False):
116
- if reverse:
117
- if self.tie_weights:
118
- kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
119
- else:
120
- kernel = self.reverse_embeddings
121
- if self.reverse_dtype is not None:
122
- inputs = ops.cast(inputs, self.reverse_dtype)
123
- kernel = ops.cast(kernel, self.reverse_dtype)
124
- logits = ops.matmul(inputs, kernel)
125
- # Optionally soft-cap logits.
126
- if self.logit_soft_cap is not None:
127
- soft_cap = self.logit_soft_cap
128
- logits = ops.tanh(logits / soft_cap) * soft_cap
129
- return logits
130
-
131
- return super().call(inputs)
132
-
133
- def get_config(self):
134
- config = super().get_config()
135
- config.update(
136
- {
137
- "tie_weights": self.tie_weights,
138
- "reverse_dtype": self.reverse_dtype,
139
- "logit_soft_cap": self.logit_soft_cap,
140
- }
141
- )
142
- return config
143
-
144
- def save_own_variables(self, store):
145
- if not self.built:
146
- return
147
- super().save_own_variables(store)
148
- target_variables = []
149
- if not self.tie_weights:
150
- # Store the reverse embedding weights as the last weights.
151
- target_variables.append(self.reverse_embeddings)
152
- if getattr(self, "quantization_mode", None) == "int8":
153
- target_variables.append(self.reverse_embeddings_scale)
154
- for i, variable in enumerate(target_variables, start=len(store)):
155
- store[str(i)] = variable
156
-
157
- def load_own_variables(self, store):
158
- if not self.built:
159
- self.build()
160
- super().load_own_variables(store)
161
- if not self.tie_weights:
162
- # Last weights in the stores are the reverse embedding weights.
163
- target_variables = [self.reverse_embeddings]
164
- if getattr(self, "quantization_mode", None) == "int8":
165
- target_variables.append(self.reverse_embeddings_scale)
166
- for i, variable in enumerate(
167
- target_variables, start=len(store) - len(target_variables)
168
- ):
169
- variable.assign(store[str(i)])
170
-
171
- def compute_output_spec(self, inputs, reverse=False):
172
- output_shape = list(inputs.shape)
173
- if reverse:
174
- output_shape[-1] = self.input_dim
175
- else:
176
- output_shape += [self.output_dim]
177
- return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
178
-
179
- # Quantization-related (int8) methods
180
-
181
- def quantized_call(self, inputs, reverse=False):
182
- # TODO (hongyu): This function could be removed once we add `*args` and
183
- # `**kwargs` for `Embedding.quantized_call`
184
- if self.quantization_mode == "int8":
185
- return self._int8_call(inputs, reverse=reverse)
186
- else:
187
- self._quantization_mode_error(self.quantization_mode)
188
-
189
- def _int8_build(self, embeddings_shape=None):
190
- if (
191
- "embeddings_shape"
192
- in inspect.signature(super()._int8_build).parameters
193
- ):
194
- if embeddings_shape is None:
195
- embeddings_shape = (self.input_dim, self.output_dim)
196
- super()._int8_build(embeddings_shape=embeddings_shape)
197
- else:
198
- # Backward compatibility for older versions of Keras.
199
- super()._int8_build()
200
- self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
201
- if not self.tie_weights:
202
- self.reverse_embeddings = self.add_weight(
203
- name="reverse_embeddings",
204
- shape=(self.output_dim, self.input_dim),
205
- initializer="zeros",
206
- dtype="int8",
207
- trainable=False,
208
- )
209
- self.reverse_embeddings_scale = self.add_weight(
210
- name="reverse_embeddings_scale",
211
- shape=(self.input_dim,),
212
- initializer="ones",
213
- trainable=False,
214
- )
215
- self._is_quantized = True
216
-
217
- def _int8_call(self, inputs, reverse=False):
218
- if reverse:
219
- if self.tie_weights:
220
- kernel = ops.transpose(self._embeddings)
221
- scale = ops.transpose(self.embeddings_scale)
222
- else:
223
- kernel = self.reverse_embeddings
224
- scale = self.reverse_embeddings_scale
225
- inputs, inputs_scale = self.inputs_quantizer(inputs)
226
- logits = ops.matmul(inputs, kernel)
227
- # De-scale outputs
228
- logits = ops.cast(logits, self.compute_dtype)
229
- logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
230
- # Optionally soft-cap logits.
231
- if self.logit_soft_cap is not None:
232
- soft_cap = self.logit_soft_cap
233
- logits = ops.tanh(logits / soft_cap) * soft_cap
234
- return logits
235
-
236
- return super()._int8_call(inputs)
237
-
238
- def quantize(self, mode, type_check=True, config=None):
239
- del config
240
- if type_check and type(self) is not ReversibleEmbedding:
241
- raise self._not_implemented_error(self.quantize)
242
-
243
- def abs_max_quantize(inputs, axis):
244
- return keras.quantizers.abs_max_quantize(
245
- inputs, axis=axis, to_numpy=True
246
- )
247
-
248
- if mode != "int8":
249
- raise NotImplementedError(
250
- "Invalid quantization mode. Expected 'int8'. "
251
- f"Received: quantization_mode={mode}"
252
- )
253
-
254
- embeddings_shape = (self.input_dim, self.output_dim)
255
- if mode == "int8":
256
- embeddings, embeddings_scale = abs_max_quantize(
257
- self._embeddings, axis=-1
258
- )
259
- embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
260
- del self._embeddings
261
- if not self.tie_weights:
262
- reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
263
- self.reverse_embeddings, axis=0
264
- )
265
- reverse_embeddings_scale = ops.squeeze(
266
- reverse_embeddings_scale, axis=0
267
- )
268
- del self.reverse_embeddings
269
- self.quantized_build(embeddings_shape, mode)
270
- if mode == "int8":
271
- self._embeddings.assign(embeddings)
272
- self.embeddings_scale.assign(embeddings_scale)
273
- if not self.tie_weights:
274
- self.reverse_embeddings.assign(reverse_embeddings)
275
- self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
276
-
277
- if self.dtype_policy.quantization_mode is None:
278
- policy = keras.dtype_policies.get(
279
- f"{mode}_from_{self.dtype_policy.name}"
280
- )
281
- self.dtype_policy = policy
7
+ class ReversibleEmbedding(keras.layers.ReversibleEmbedding):
8
+ pass
@@ -1,10 +1,10 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
3
+ from keras.src.backend import get_keras_mask
4
+ from keras.src.backend import set_keras_mask
2
5
 
3
6
  from keras_hub.src.api_export import keras_hub_export
4
7
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
5
- from keras_hub.src.layers.modeling.reversible_embedding import (
6
- ReversibleEmbedding,
7
- )
8
8
  from keras_hub.src.utils.keras_utils import clone_initializer
9
9
 
10
10
 
@@ -128,10 +128,10 @@ class TokenAndPositionEmbedding(keras.layers.Layer):
128
128
  positions=positions,
129
129
  )
130
130
  outputs = embedded_tokens + embedded_positions
131
+ mask = get_keras_mask(embedded_tokens)
132
+ if mask is not None:
133
+ set_keras_mask(outputs, mask)
131
134
  return outputs
132
135
 
133
- def compute_mask(self, inputs, mask=None):
134
- return self.token_embedding.compute_mask(inputs, mask=mask)
135
-
136
136
  def compute_output_shape(self, input_shape):
137
137
  return tuple(input_shape) + (self.embedding_dim,)
@@ -1,11 +1,12 @@
1
1
  from absl import logging
2
2
  from keras import ops
3
+ from keras.src.backend import get_keras_mask
3
4
 
4
5
 
5
6
  def _check_masks_shapes(inputs, padding_mask, attention_mask):
6
7
  mask = padding_mask
7
- if hasattr(inputs, "_keras_mask") and mask is None:
8
- mask = inputs._keras_mask
8
+ if mask is None:
9
+ mask = get_keras_mask(inputs)
9
10
  if mask is not None:
10
11
  if len(mask.shape) != 2:
11
12
  raise ValueError(
@@ -68,17 +69,16 @@ def merge_padding_and_attention_mask(
68
69
  returned mask is padding_mask with one additional axis.
69
70
  """
70
71
  _check_masks_shapes(inputs, padding_mask, attention_mask)
71
- mask = padding_mask
72
- if hasattr(inputs, "_keras_mask"):
73
- if mask is None:
74
- # If no padding mask is explicitly provided, we look for padding
75
- # mask from the input data.
76
- mask = inputs._keras_mask
77
- else:
72
+ # We look for a padding mask from the input data.
73
+ mask = get_keras_mask(inputs)
74
+ # But if padding mask is explicitly provided, we use it.
75
+ if padding_mask is not None:
76
+ if mask is not None:
78
77
  logging.warning(
79
78
  "You are explicitly setting `padding_mask` while the `inputs` "
80
79
  "have built-in mask, so the built-in mask is ignored."
81
80
  )
81
+ mask = padding_mask
82
82
  if mask is not None:
83
83
  # Add an axis for broadcasting, the attention mask should be 2D
84
84
  # (not including the batch axis).
@@ -7,9 +7,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
7
7
 
8
8
  try:
9
9
  import tensorflow as tf
10
- import tensorflow_text as tf_text
11
10
  except ImportError:
12
11
  tf = None
12
+ try:
13
+ import tensorflow_text as tf_text
14
+ except ImportError:
13
15
  tf_text = None
14
16
 
15
17
 
@@ -8,9 +8,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
8
8
 
9
9
  try:
10
10
  import tensorflow as tf
11
- import tensorflow_text as tf_text
12
11
  except ImportError:
13
12
  tf = None
13
+ try:
14
+ import tensorflow_text as tf_text
15
+ except ImportError:
14
16
  tf_text = None
15
17
 
16
18
 
@@ -1,10 +1,8 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
5
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
5
- from keras_hub.src.layers.modeling.reversible_embedding import (
6
- ReversibleEmbedding,
7
- )
8
6
  from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
9
7
  from keras_hub.src.models.backbone import Backbone
10
8
  from keras_hub.src.utils.keras_utils import gelu_approximate
@@ -107,6 +107,9 @@ class Backbone(keras.Model):
107
107
  def from_config(cls, config):
108
108
  # The default `from_config()` for functional models will return a
109
109
  # vanilla `keras.Model`. We override it to get a subclass instance back.
110
+ config = config.copy()
111
+ if "dtype" in config and isinstance(config["dtype"], dict):
112
+ config["dtype"] = keras.dtype_policies.get(config["dtype"])
110
113
  return cls(**config)
111
114
 
112
115
  @classproperty
@@ -1,10 +1,8 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
5
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
5
- from keras_hub.src.layers.modeling.reversible_embedding import (
6
- ReversibleEmbedding,
7
- )
8
6
  from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
9
7
  from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
10
8
  from keras_hub.src.models.backbone import Backbone
@@ -1,10 +1,8 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
5
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
5
- from keras_hub.src.layers.modeling.reversible_embedding import (
6
- ReversibleEmbedding,
7
- )
8
6
  from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
9
7
  from keras_hub.src.models.backbone import Backbone
10
8
  from keras_hub.src.utils.keras_utils import gelu_approximate
@@ -35,7 +33,7 @@ class BertBackbone(Backbone):
35
33
  vocabulary_size: int. The size of the token vocabulary.
36
34
  num_layers: int. The number of transformer layers.
37
35
  num_heads: int. The number of attention heads for each transformer.
38
- The hidden size must be divisible by the number of attention heads.
36
+ The hidden_dim must be divisible by the number of attention heads.
39
37
  hidden_dim: int. The size of the transformer encoding and pooler layers.
40
38
  intermediate_dim: int. The output dimension of the first Dense layer in
41
39
  a two-layer feedforward network for each transformer.
@@ -1,9 +1,7 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
- from keras_hub.src.layers.modeling.reversible_embedding import (
5
- ReversibleEmbedding,
6
- )
7
5
  from keras_hub.src.models.backbone import Backbone
8
6
  from keras_hub.src.models.bloom.bloom_decoder import BloomDecoder
9
7
 
@@ -196,7 +196,7 @@ class CausalLM(Task):
196
196
 
197
197
  # Create an explicit tuple of all variable state.
198
198
  state = (
199
- self.sampler.variables,
199
+ [v.value for v in self.sampler.variables],
200
200
  # Use the explicit variable.value to preserve the
201
201
  # sharding spec of distribution.
202
202
  [v.value for v in self.trainable_variables],
@@ -431,7 +431,7 @@ class CausalLM(Task):
431
431
  self.generate_function = None
432
432
 
433
433
  def get_quantization_layer_structure(self, mode):
434
- if mode != "gptq":
434
+ if mode not in ["gptq", "awq"]:
435
435
  return None
436
436
 
437
437
  backbone = self.backbone
@@ -1,9 +1,7 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
- from keras_hub.src.layers.modeling.reversible_embedding import (
5
- ReversibleEmbedding,
6
- )
7
5
  from keras_hub.src.models.backbone import Backbone
8
6
  from keras_hub.src.models.deberta_v3.disentangled_attention_encoder import (
9
7
  DisentangledAttentionEncoder,