keras-hub-nightly 0.24.0.dev202511220420__py3-none-any.whl → 0.26.0.dev202601010440__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of keras-hub-nightly might be problematic.

Files changed (66)
  1. keras_hub/models/__init__.py +12 -0
  2. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  3. keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
  4. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -3
  5. keras_hub/src/models/albert/albert_backbone.py +1 -3
  6. keras_hub/src/models/bart/bart_backbone.py +1 -3
  7. keras_hub/src/models/bert/bert_backbone.py +1 -3
  8. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  9. keras_hub/src/models/causal_lm.py +23 -1
  10. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  11. keras_hub/src/models/dinov3/dinov3_presets.py +90 -1
  12. keras_hub/src/models/electra/electra_backbone.py +1 -3
  13. keras_hub/src/models/esm/esm_attention.py +11 -4
  14. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  16. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  17. keras_hub/src/models/gemma/gemma_causal_lm.py +16 -0
  18. keras_hub/src/models/gemma3/gemma3_backbone.py +1 -3
  19. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
  20. keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
  21. keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
  22. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  23. keras_hub/src/models/gpt2/gpt2_causal_lm.py +17 -0
  24. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  25. keras_hub/src/models/gpt_oss/__init__.py +5 -0
  26. keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
  27. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +219 -0
  28. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
  29. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
  30. keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
  31. keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
  32. keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
  33. keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
  34. keras_hub/src/models/llama/llama_backbone.py +1 -3
  35. keras_hub/src/models/llama3/llama3_presets.py +1 -1
  36. keras_hub/src/models/masked_lm.py +22 -0
  37. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  38. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  39. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  40. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  41. keras_hub/src/models/parseq/parseq_decoder.py +21 -9
  42. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  43. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  44. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  45. keras_hub/src/models/qwen3/qwen3_presets.py +36 -0
  46. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  47. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  48. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  49. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  50. keras_hub/src/models/smollm3/__init__.py +5 -0
  51. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  52. keras_hub/src/models/smollm3/smollm3_presets.py +16 -0
  53. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
  54. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  55. keras_hub/src/models/t5/t5_backbone.py +1 -3
  56. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  57. keras_hub/src/tests/test_case.py +1 -3
  58. keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
  59. keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
  60. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  61. keras_hub/src/version.py +1 -1
  62. keras_hub/tokenizers/__init__.py +3 -0
  63. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/METADATA +4 -5
  64. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/RECORD +66 -53
  65. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/WHEEL +0 -0
  66. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/top_level.txt +0 -0
keras_hub/models/__init__.py
@@ -340,6 +340,18 @@ from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import (
  from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
      GPTNeoXTokenizer as GPTNeoXTokenizer,
  )
+ from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
+     GptOssBackbone as GptOssBackbone,
+ )
+ from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
+     GptOssCausalLM as GptOssCausalLM,
+ )
+ from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
+     GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
+ )
+ from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
+     GptOssTokenizer as GptOssTokenizer,
+ )
  from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
      HGNetV2Backbone as HGNetV2Backbone,
  )
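The four new re-exports above make the GPT-OSS classes reachable from the public `keras_hub.models` namespace. A minimal import sketch, assuming a build of keras-hub that includes this change (no preset name is implied by the diff):

```python
# Sketch only: exercises the new public re-exports added above.
from keras_hub.models import (
    GptOssBackbone,
    GptOssCausalLM,
    GptOssCausalLMPreprocessor,
    GptOssTokenizer,
)

print(GptOssCausalLM.__name__)  # "GptOssCausalLM"
```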
keras_hub/src/layers/modeling/reversible_embedding.py
@@ -1,281 +1,8 @@
- import inspect
-
  import keras
- from keras import ops

  from keras_hub.src.api_export import keras_hub_export


  @keras_hub_export("keras_hub.layers.ReversibleEmbedding")
- class ReversibleEmbedding(keras.layers.Embedding):
-     """An embedding layer which can project backwards to the input dim.
-
-     This layer is an extension of `keras.layers.Embedding` for language models.
-     This layer can be called "in reverse" with `reverse=True`, in which case the
-     layer will linearly project from `output_dim` back to `input_dim`.
-
-     By default, the reverse projection will use the transpose of the
-     `embeddings` weights to project to `input_dim` (weights are "tied"). If
-     `tie_weights=False`, the model will use a separate, trainable variable for
-     reverse projection.
-
-     This layer has no bias terms.
-
-     Args:
-         input_dim: Integer. Size of the vocabulary,
-             i.e. maximum integer index + 1.
-         output_dim: Integer. Dimension of the dense embedding.
-         tie_weights: Boolean, whether or not the matrix for embedding and
-             the matrix for the `reverse` projection should share the same
-             weights.
-         embeddings_initializer: Initializer for the `embeddings`
-             matrix (see `keras.initializers`).
-         embeddings_regularizer: Regularizer function applied to
-             the `embeddings` matrix (see `keras.regularizers`).
-         embeddings_constraint: Constraint function applied to
-             the `embeddings` matrix (see `keras.constraints`).
-         mask_zero: Boolean, whether or not the input value 0 is a special
-             "padding" value that should be masked out.
-         reverse_dtype: The dtype for the reverse projection computation.
-             Defaults to the `compute_dtype` of the layer.
-         logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
-             output logits will be scaled by
-             `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
-             range of output logits and can improve training.
-         **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
-             including `name`, `trainable`, `dtype` etc.
-
-     Call arguments:
-         inputs: The tensor inputs to the layer.
-         reverse: Boolean. If `True` the layer will perform a linear projection
-             from `output_dim` to `input_dim`, instead of a normal embedding
-             call. Default to `False`.
-
-     Example:
-     ```python
-     batch_size = 16
-     vocab_size = 100
-     hidden_dim = 32
-     seq_length = 50
-
-     # Generate random inputs.
-     token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
-
-     embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
-     # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
-     hidden_states = embedding(token_ids)
-     # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
-     logits = embedding(hidden_states, reverse=True)
-     ```
-
-     References:
-     - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
-     - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
-     """
-
-     def __init__(
-         self,
-         input_dim,
-         output_dim,
-         tie_weights=True,
-         embeddings_initializer="uniform",
-         embeddings_regularizer=None,
-         embeddings_constraint=None,
-         mask_zero=False,
-         reverse_dtype=None,
-         logit_soft_cap=None,
-         **kwargs,
-     ):
-         super().__init__(
-             input_dim,
-             output_dim,
-             embeddings_initializer=embeddings_initializer,
-             embeddings_regularizer=embeddings_regularizer,
-             embeddings_constraint=embeddings_constraint,
-             mask_zero=mask_zero,
-             **kwargs,
-         )
-         self.tie_weights = tie_weights
-         self.reverse_dtype = reverse_dtype
-         self.logit_soft_cap = logit_soft_cap
-
-     def build(self, inputs_shape=None):
-         super().build(inputs_shape)
-         if (
-             not self.tie_weights
-             and getattr(self, "quantization_mode", None) != "int8"
-         ):
-             self.reverse_embeddings = self.add_weight(
-                 name="reverse_embeddings",
-                 shape=(self.output_dim, self.input_dim),
-                 initializer=self.embeddings_initializer,
-                 dtype=self.dtype,
-             )
-
-     def call(self, inputs, reverse=False):
-         if reverse:
-             if self.tie_weights:
-                 kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
-             else:
-                 kernel = self.reverse_embeddings
-             if self.reverse_dtype is not None:
-                 inputs = ops.cast(inputs, self.reverse_dtype)
-                 kernel = ops.cast(kernel, self.reverse_dtype)
-             logits = ops.matmul(inputs, kernel)
-             # Optionally soft-cap logits.
-             if self.logit_soft_cap is not None:
-                 soft_cap = self.logit_soft_cap
-                 logits = ops.tanh(logits / soft_cap) * soft_cap
-             return logits
-
-         return super().call(inputs)
-
-     def get_config(self):
-         config = super().get_config()
-         config.update(
-             {
-                 "tie_weights": self.tie_weights,
-                 "reverse_dtype": self.reverse_dtype,
-                 "logit_soft_cap": self.logit_soft_cap,
-             }
-         )
-         return config
-
-     def save_own_variables(self, store):
-         if not self.built:
-             return
-         super().save_own_variables(store)
-         target_variables = []
-         if not self.tie_weights:
-             # Store the reverse embedding weights as the last weights.
-             target_variables.append(self.reverse_embeddings)
-             if getattr(self, "quantization_mode", None) == "int8":
-                 target_variables.append(self.reverse_embeddings_scale)
-             for i, variable in enumerate(target_variables, start=len(store)):
-                 store[str(i)] = variable
-
-     def load_own_variables(self, store):
-         if not self.built:
-             self.build()
-         super().load_own_variables(store)
-         if not self.tie_weights:
-             # Last weights in the stores are the reverse embedding weights.
-             target_variables = [self.reverse_embeddings]
-             if getattr(self, "quantization_mode", None) == "int8":
-                 target_variables.append(self.reverse_embeddings_scale)
-             for i, variable in enumerate(
-                 target_variables, start=len(store) - len(target_variables)
-             ):
-                 variable.assign(store[str(i)])
-
-     def compute_output_spec(self, inputs, reverse=False):
-         output_shape = list(inputs.shape)
-         if reverse:
-             output_shape[-1] = self.input_dim
-         else:
-             output_shape += [self.output_dim]
-         return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
-
-     # Quantization-related (int8) methods
-
-     def quantized_call(self, inputs, reverse=False):
-         # TODO (hongyu): This function could be removed once we add `*args` and
-         # `**kwargs` for `Embedding.quantized_call`
-         if self.quantization_mode == "int8":
-             return self._int8_call(inputs, reverse=reverse)
-         else:
-             self._quantization_mode_error(self.quantization_mode)
-
-     def _int8_build(self, embeddings_shape=None):
-         if (
-             "embeddings_shape"
-             in inspect.signature(super()._int8_build).parameters
-         ):
-             if embeddings_shape is None:
-                 embeddings_shape = (self.input_dim, self.output_dim)
-             super()._int8_build(embeddings_shape=embeddings_shape)
-         else:
-             # Backward compatibility for older versions of Keras.
-             super()._int8_build()
-         self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
-         if not self.tie_weights:
-             self.reverse_embeddings = self.add_weight(
-                 name="reverse_embeddings",
-                 shape=(self.output_dim, self.input_dim),
-                 initializer="zeros",
-                 dtype="int8",
-                 trainable=False,
-             )
-             self.reverse_embeddings_scale = self.add_weight(
-                 name="reverse_embeddings_scale",
-                 shape=(self.input_dim,),
-                 initializer="ones",
-                 trainable=False,
-             )
-         self._is_quantized = True
-
-     def _int8_call(self, inputs, reverse=False):
-         if reverse:
-             if self.tie_weights:
-                 kernel = ops.transpose(self._embeddings)
-                 scale = ops.transpose(self.embeddings_scale)
-             else:
-                 kernel = self.reverse_embeddings
-                 scale = self.reverse_embeddings_scale
-             inputs, inputs_scale = self.inputs_quantizer(inputs)
-             logits = ops.matmul(inputs, kernel)
-             # De-scale outputs
-             logits = ops.cast(logits, self.compute_dtype)
-             logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
-             # Optionally soft-cap logits.
-             if self.logit_soft_cap is not None:
-                 soft_cap = self.logit_soft_cap
-                 logits = ops.tanh(logits / soft_cap) * soft_cap
-             return logits
-
-         return super()._int8_call(inputs)
-
-     def quantize(self, mode, type_check=True, config=None):
-         del config
-         if type_check and type(self) is not ReversibleEmbedding:
-             raise self._not_implemented_error(self.quantize)
-
-         def abs_max_quantize(inputs, axis):
-             return keras.quantizers.abs_max_quantize(
-                 inputs, axis=axis, to_numpy=True
-             )
-
-         if mode != "int8":
-             raise NotImplementedError(
-                 "Invalid quantization mode. Expected 'int8'. "
-                 f"Received: quantization_mode={mode}"
-             )
-
-         embeddings_shape = (self.input_dim, self.output_dim)
-         if mode == "int8":
-             embeddings, embeddings_scale = abs_max_quantize(
-                 self._embeddings, axis=-1
-             )
-             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-             del self._embeddings
-             if not self.tie_weights:
-                 reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
-                     self.reverse_embeddings, axis=0
-                 )
-                 reverse_embeddings_scale = ops.squeeze(
-                     reverse_embeddings_scale, axis=0
-                 )
-                 del self.reverse_embeddings
-         self.quantized_build(embeddings_shape, mode)
-         if mode == "int8":
-             self._embeddings.assign(embeddings)
-             self.embeddings_scale.assign(embeddings_scale)
-             if not self.tie_weights:
-                 self.reverse_embeddings.assign(reverse_embeddings)
-                 self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
-
-         if self.dtype_policy.quantization_mode is None:
-             policy = keras.dtype_policies.get(
-                 f"{mode}_from_{self.dtype_policy.name}"
-             )
-             self.dtype_policy = policy
+ class ReversibleEmbedding(keras.layers.ReversibleEmbedding):
+     pass
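The keras-hub layer now simply aliases `keras.layers.ReversibleEmbedding`, so the public behavior described by the removed docstring should be unchanged. A small usage sketch based on that docstring's example, assuming a Keras version that ships `keras.layers.ReversibleEmbedding`:

```python
import numpy as np
import keras_hub

batch_size, vocab_size, hidden_dim, seq_length = 16, 100, 32, 50

# Embed token ids to shape (batch_size, seq_length, hidden_dim).
token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
hidden_states = embedding(token_ids)

# Project back to logits of shape (batch_size, seq_length, vocab_size).
logits = embedding(hidden_states, reverse=True)
```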
keras_hub/src/layers/modeling/rotary_embedding.py
@@ -1,4 +1,5 @@
  import keras
+ import numpy as np
  from keras import ops

  from keras_hub.src.api_export import keras_hub_export
@@ -25,6 +26,17 @@ class RotaryEmbedding(keras.layers.Layer):
              curves.
          scaling_factor: float. The scaling factor used to scale positions of
              the tokens.
+         rope_type: str. The type of RoPE scaling to apply. Supported types:
+             "linear", "dynamic", "yarn". Defaults to "linear".
+         beta_fast: float. Beta fast parameter for YaRN scaling. Only used
+             when rope_type="yarn". Defaults to 32.0.
+         beta_slow: float. Beta slow parameter for YaRN scaling. Only used
+             when rope_type="yarn". Defaults to 1.0.
+         original_max_position_embeddings: int. Original maximum position
+             embeddings for YaRN scaling. Only used when rope_type="yarn".
+             Defaults to 4096.
+         truncate: bool. Whether to apply truncation for YaRN scaling. Only used
+             when rope_type="yarn". Defaults to False.
          sequence_axis: int. Sequence axis in the input tensor.
          feature_axis: int. Feature axis in the input tensor.
          **kwargs: other keyword arguments passed to `keras.layers.Layer`,
@@ -69,6 +81,11 @@ class RotaryEmbedding(keras.layers.Layer):
          self,
          max_wavelength=10000,
          scaling_factor=1.0,
+         rope_type="linear",
+         beta_fast=32.0,
+         beta_slow=1.0,
+         original_max_position_embeddings=4096,
+         truncate=False,
          sequence_axis=1,
          feature_axis=-1,
          **kwargs,
@@ -78,24 +95,70 @@ class RotaryEmbedding(keras.layers.Layer):
          self.sequence_axis = sequence_axis
          self.feature_axis = feature_axis
          self.scaling_factor = scaling_factor
+         self.rope_type = rope_type
+
+         # YaRN-specific parameters (only used when rope_type="yarn")
+         self.beta_fast = beta_fast
+         self.beta_slow = beta_slow
+         self.original_max_position_embeddings = original_max_position_embeddings
+         self.truncate = truncate
          self.built = True

+     def _normalize_axes(self, input_shape):
+         """Normalize and validate axis indices for the given input shape."""
+         rank = len(input_shape)
+
+         # Normalize negative indices
+         sequence_axis = self.sequence_axis
+         feature_axis = self.feature_axis
+
+         if sequence_axis < 0:
+             sequence_axis += rank
+         if feature_axis < 0:
+             feature_axis += rank
+
+         if sequence_axis < 0 or sequence_axis >= rank:
+             raise ValueError(
+                 f"sequence_axis {self.sequence_axis} "
+                 f"is out of range for input with rank {rank}"
+             )
+         if feature_axis < 0 or feature_axis >= rank:
+             raise ValueError(
+                 f"feature_axis {self.feature_axis} "
+                 f"is out of range for input with rank {rank}"
+             )
+         if sequence_axis == feature_axis:
+             raise ValueError("sequence_axis and feature_axis must be different")
+
+         return sequence_axis, feature_axis
+
+     def _validate_rotary_dimension(self, rotary_dim):
+         if rotary_dim % 2 != 0:
+             raise ValueError(
+                 f"Rotary dimension must be even, got {rotary_dim}."
+                 "The rotary embedding splits the feature dimension "
+                 "into two halves. Consider using a different feature "
+                 "dimension or padding."
+             )
+
      def call(self, inputs, start_index=0, positions=None):
+         input_shape = ops.shape(inputs)
+         sequence_axis, feature_axis = self._normalize_axes(input_shape)
+
+         rotary_dim = input_shape[feature_axis]
+         self._validate_rotary_dimension(rotary_dim)
+
          # Take care of unbatched `positions`.
          if positions is not None:
              if len(ops.shape(positions)) == 1:
                  positions = ops.expand_dims(positions, axis=0)

-         inputs = ops.moveaxis(
-             inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
-         )
+         inputs = ops.moveaxis(inputs, (feature_axis, sequence_axis), (-1, 1))
          cos_emb, sin_emb = self._compute_cos_sin_embedding(
              inputs, start_index, positions
          )
          output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
-         return ops.moveaxis(
-             output, (-1, 1), (self.feature_axis, self.sequence_axis)
-         )
+         return ops.moveaxis(output, (-1, 1), (feature_axis, sequence_axis))

      def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
          x1, x2 = ops.split(tensor, 2, axis=-1)
@@ -113,19 +176,35 @@ class RotaryEmbedding(keras.layers.Layer):
          return positions + ops.cast(start_index, dtype="float32")

      def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
+         """Compute cos & sin RoPE embeddings with optional YaRN scaling.
+         Uses tensor ops only to remain JIT/backends friendly.
+         """
          batch_axis = 0
-         feature_axis = len(inputs.shape) - 1
          sequence_axis = 1
+         feature_axis = len(inputs.shape) - 1

          rotary_dim = ops.shape(inputs)[feature_axis]
          inverse_freq = self._get_inverse_freq(rotary_dim)

          if positions is None:
              positions = self._compute_positions(inputs, start_index)
-             positions = ops.expand_dims(positions, axis=batch_axis)
+             positions = ops.expand_dims(
+                 positions, axis=batch_axis
+             )  # shape (1, seq_len)
          else:
              positions = ops.cast(positions, "float32")
-         positions = positions / ops.cast(self.scaling_factor, "float32")
+             if len(ops.shape(positions)) == 1:
+                 positions = ops.expand_dims(positions, axis=batch_axis)
+
+         if (
+             self.rope_type == "yarn"
+             and self.truncate
+             and self.original_max_position_embeddings is not None
+         ):
+             positions = ops.minimum(
+                 positions,
+                 ops.cast(self.original_max_position_embeddings, "float32"),
+             )

          freq = ops.einsum("bi,j->bij", positions, inverse_freq)

@@ -140,15 +219,103 @@ class RotaryEmbedding(keras.layers.Layer):

          cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
          sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
+
+         if self.rope_type == "yarn":
+             # YaRN temperature scaling
+             factor = ops.add(
+                 ops.multiply(
+                     ops.cast(0.1, self.compute_dtype),
+                     ops.log(ops.cast(self.scaling_factor, self.compute_dtype)),
+                 ),
+                 ops.cast(1.0, self.compute_dtype),
+             )
+             cos_emb = cos_emb * factor
+             sin_emb = sin_emb * factor
          return cos_emb, sin_emb

      def _get_inverse_freq(self, rotary_dim):
-         freq_range = ops.divide(
-             ops.arange(0, rotary_dim, 2, dtype="float32"),
-             ops.cast(rotary_dim, "float32"),
+         """Return inverse frequencies."""
+         idx = ops.arange(0, rotary_dim, 2, dtype="float32")
+         denom = ops.cast(rotary_dim, "float32")
+         freq_range = idx / denom
+         inv = ops.power(ops.cast(self.max_wavelength, "float32"), -freq_range)
+
+         if self.rope_type == "linear":
+             return inv / ops.cast(self.scaling_factor, "float32")
+         elif self.rope_type == "dynamic":
+             exponent = ops.cast(rotary_dim, "float32") / ops.cast(
+                 max(1, rotary_dim - 2), "float32"
+             )
+             return inv / ops.power(
+                 ops.cast(self.scaling_factor, "float32"), exponent
+             )
+         elif self.rope_type == "yarn":
+             return self._get_yarn_inverse_freq(rotary_dim)
+         else:
+             return inv
+
+     def _get_yarn_inverse_freq(self, rotary_dim):
+         # Get the base (rope_theta equivalent) from max_wavelength
+         base = ops.cast(self.max_wavelength, "float32")
+
+         # Compute base frequencies: base ** (idx / dim)
+         idx = ops.arange(0, rotary_dim, 2, dtype="float32")
+         pos_freqs = ops.power(base, idx / ops.cast(rotary_dim, "float32"))
+
+         # Compute interpolation and extrapolation frequencies
+         inv_freq_extrapolation = 1.0 / pos_freqs
+         inv_freq_interpolation = 1.0 / (
+             ops.cast(self.scaling_factor, "float32") * pos_freqs
          )
-         inverse_freq = 1.0 / (self.max_wavelength**freq_range)
-         return inverse_freq
+
+         # Find correction range
+         beta_fast = ops.cast(self.beta_fast, "float32")
+         beta_slow = ops.cast(self.beta_slow, "float32")
+
+         # Find correction dimensions for beta_fast and beta_slow
+         def find_correction_dim_tensor(num_rotations, dim):
+             max_pos = ops.cast(self.original_max_position_embeddings, "float32")
+             return (dim * ops.log(max_pos / (num_rotations * 2 * np.pi))) / (
+                 2 * ops.log(base)
+             )
+
+         low = find_correction_dim_tensor(
+             beta_fast, ops.cast(rotary_dim, "float32")
+         )
+         high = find_correction_dim_tensor(
+             beta_slow, ops.cast(rotary_dim, "float32")
+         )
+
+         # Apply truncation if specified
+         if self.truncate:
+             low = ops.floor(low)
+             high = ops.ceil(high)
+
+         # Clamp to valid range
+         low = ops.maximum(low, ops.cast(0, "float32"))
+         high = ops.minimum(high, ops.cast(rotary_dim // 2 - 1, "float32"))
+
+         # Linear ramp function
+         dim_half = rotary_dim // 2
+         idx_half = ops.arange(0, dim_half, dtype="float32")
+
+         # Prevent singularity
+         diff = high - low
+         diff = ops.maximum(diff, ops.cast(0.001, "float32"))
+
+         linear_func = (idx_half - low) / diff
+         ramp_func = ops.clip(linear_func, 0, 1)
+
+         # Apply the ramp to get extrapolation factor
+         inv_freq_extrapolation_factor = 1 - ramp_func
+
+         # Combine interpolation and extrapolation
+         scaled_inverse_freq = (
+             inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+             + inv_freq_extrapolation * inv_freq_extrapolation_factor
+         )
+
+         return scaled_inverse_freq

      def get_config(self):
          config = super().get_config()
@@ -156,6 +323,13 @@ class RotaryEmbedding(keras.layers.Layer):
              {
                  "max_wavelength": self.max_wavelength,
                  "scaling_factor": self.scaling_factor,
+                 "rope_type": self.rope_type,
+                 "beta_fast": self.beta_fast,
+                 "beta_slow": self.beta_slow,
+                 "original_max_position_embeddings": (
+                     self.original_max_position_embeddings
+                 ),
+                 "truncate": self.truncate,
                  "sequence_axis": self.sequence_axis,
                  "feature_axis": self.feature_axis,
              }
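Taken together, the rotary-embedding hunks above add a `rope_type` switch ("linear", "dynamic", "yarn") plus YaRN-specific knobs, all serialized via `get_config()`. A hedged construction sketch; the parameter values are illustrative and not taken from any preset in this release:

```python
import numpy as np
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding

# YaRN-style RoPE: interpolated/extrapolated inverse frequencies plus
# temperature scaling of the cos/sin terms, per the diff above.
rope = RotaryEmbedding(
    max_wavelength=10000,
    scaling_factor=8.0,  # context-extension factor (illustrative)
    rope_type="yarn",
    beta_fast=32.0,
    beta_slow=1.0,
    original_max_position_embeddings=4096,
    truncate=True,
)

# (batch, seq_len, num_heads, head_dim); head_dim must be even.
x = np.random.rand(1, 16, 8, 64).astype("float32")
y = rope(x)  # same shape as x
```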
keras_hub/src/layers/modeling/token_and_position_embedding.py
@@ -1,10 +1,8 @@
  import keras
+ from keras.layers import ReversibleEmbedding

  from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
- from keras_hub.src.layers.modeling.reversible_embedding import (
-     ReversibleEmbedding,
- )
  from keras_hub.src.utils.keras_utils import clone_initializer


keras_hub/src/models/albert/albert_backbone.py
@@ -1,10 +1,8 @@
  import keras
+ from keras.layers import ReversibleEmbedding

  from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
- from keras_hub.src.layers.modeling.reversible_embedding import (
-     ReversibleEmbedding,
- )
  from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
  from keras_hub.src.models.backbone import Backbone
  from keras_hub.src.utils.keras_utils import gelu_approximate
keras_hub/src/models/bart/bart_backbone.py
@@ -1,10 +1,8 @@
  import keras
+ from keras.layers import ReversibleEmbedding

  from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
- from keras_hub.src.layers.modeling.reversible_embedding import (
-     ReversibleEmbedding,
- )
  from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
  from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
  from keras_hub.src.models.backbone import Backbone
keras_hub/src/models/bert/bert_backbone.py
@@ -1,10 +1,8 @@
  import keras
+ from keras.layers import ReversibleEmbedding

  from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
- from keras_hub.src.layers.modeling.reversible_embedding import (
-     ReversibleEmbedding,
- )
  from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
  from keras_hub.src.models.backbone import Backbone
  from keras_hub.src.utils.keras_utils import gelu_approximate
keras_hub/src/models/bloom/bloom_backbone.py
@@ -1,9 +1,7 @@
  import keras
+ from keras.layers import ReversibleEmbedding

  from keras_hub.src.api_export import keras_hub_export
- from keras_hub.src.layers.modeling.reversible_embedding import (
-     ReversibleEmbedding,
- )
  from keras_hub.src.models.backbone import Backbone
  from keras_hub.src.models.bloom.bloom_decoder import BloomDecoder

keras_hub/src/models/causal_lm.py
@@ -196,7 +196,7 @@ class CausalLM(Task):

                  # Create an explicit tuple of all variable state.
                  state = (
-                     self.sampler.variables,
+                     [v.value for v in self.sampler.variables],
                      # Use the explicit variable.value to preserve the
                      # sharding spec of distribution.
                      [v.value for v in self.trainable_variables],
@@ -429,3 +429,25 @@ class CausalLM(Task):
          super()._post_quantize(mode, **kwargs)
          # Reset the compiled generate function.
          self.generate_function = None
+
+     def get_quantization_layer_structure(self, mode):
+         if mode != "gptq":
+             return None
+
+         backbone = self.backbone
+         # Check for standard backbone structure.
+         if not hasattr(backbone, "transformer_layers"):
+             return None
+
+         # Check for embedding.
+         embedding = getattr(backbone, "token_embedding", None)
+         if embedding is None:
+             embedding = getattr(backbone, "embedding", None)
+
+         if embedding is None:
+             return None
+
+         return {
+             "pre_block_layers": [embedding],
+             "sequential_blocks": backbone.transformer_layers,
+         }
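For backbones that follow the common `token_embedding` + `transformer_layers` layout, the new hook reports which layers feed the sequential GPTQ blocks. A hedged sketch of the expected return value, assuming `causal_lm` is any `keras_hub.models.CausalLM` instance whose backbone exposes those attributes:

```python
# Sketch: what the new hook is expected to return for a standard backbone.
structure = causal_lm.get_quantization_layer_structure("gptq")
# {
#     "pre_block_layers": [causal_lm.backbone.token_embedding],
#     "sequential_blocks": causal_lm.backbone.transformer_layers,
# }

# Any mode other than "gptq" short-circuits to None.
assert causal_lm.get_quantization_layer_structure("int8") is None
```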
keras_hub/src/models/deberta_v3/deberta_v3_backbone.py
@@ -1,9 +1,7 @@
  import keras
+ from keras.layers import ReversibleEmbedding

  from keras_hub.src.api_export import keras_hub_export
- from keras_hub.src.layers.modeling.reversible_embedding import (
-     ReversibleEmbedding,
- )
  from keras_hub.src.models.backbone import Backbone
  from keras_hub.src.models.deberta_v3.disentangled_attention_encoder import (
      DisentangledAttentionEncoder,