keras-hub-nightly 0.24.0.dev202511220420-py3-none-any.whl → 0.26.0.dev202601010440-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of keras-hub-nightly has been flagged as potentially problematic.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -3
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +1 -3
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +23 -1
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/dinov3/dinov3_presets.py +90 -1
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/esm/esm_attention.py +11 -4
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +16 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +1 -3
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
- keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +17 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/__init__.py +5 -0
- keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +219 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
- keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
- keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
- keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
- keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/llama3/llama3_presets.py +1 -1
- keras_hub/src/models/masked_lm.py +22 -0
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_decoder.py +21 -9
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3/qwen3_presets.py +36 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/__init__.py +5 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/smollm3/smollm3_presets.py +16 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/tests/test_case.py +1 -3
- keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
- keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/METADATA +4 -5
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/RECORD +66 -53
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/top_level.txt +0 -0
keras_hub/models/__init__.py
CHANGED
@@ -340,6 +340,18 @@ from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import (
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
     GPTNeoXTokenizer as GPTNeoXTokenizer,
 )
+from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
+    GptOssBackbone as GptOssBackbone,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
+    GptOssCausalLM as GptOssCausalLM,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
+    GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
+    GptOssTokenizer as GptOssTokenizer,
+)
 from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
     HGNetV2Backbone as HGNetV2Backbone,
 )
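The hunk above registers the new GPT-OSS classes on the public `keras_hub.models` API. A minimal sketch of what becomes importable after this change (the `from_preset` call is commented out because the concrete preset names live in `gpt_oss_presets.py` and are not shown in this diff):

```python
import keras_hub

# These symbols are exported by the added imports in keras_hub/models/__init__.py.
backbone_cls = keras_hub.models.GptOssBackbone
tokenizer_cls = keras_hub.models.GptOssTokenizer
preprocessor_cls = keras_hub.models.GptOssCausalLMPreprocessor
causal_lm_cls = keras_hub.models.GptOssCausalLM

# Hypothetical usage; substitute a real preset name from gpt_oss_presets.py.
# causal_lm = keras_hub.models.GptOssCausalLM.from_preset("<gpt_oss_preset>")
# causal_lm.generate("Hello", max_length=32)
```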
keras_hub/src/layers/modeling/reversible_embedding.py
CHANGED

@@ -1,281 +1,8 @@
-import inspect
-
 import keras
-from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 
 
 @keras_hub_export("keras_hub.layers.ReversibleEmbedding")
-class ReversibleEmbedding(keras.layers.Embedding):
-    """An embedding layer which can project backwards to the input dim.
-
-    This layer is an extension of `keras.layers.Embedding` for language models.
-    This layer can be called "in reverse" with `reverse=True`, in which case the
-    layer will linearly project from `output_dim` back to `input_dim`.
-
-    By default, the reverse projection will use the transpose of the
-    `embeddings` weights to project to `input_dim` (weights are "tied"). If
-    `tie_weights=False`, the model will use a separate, trainable variable for
-    reverse projection.
-
-    This layer has no bias terms.
-
-    Args:
-        input_dim: Integer. Size of the vocabulary,
-            i.e. maximum integer index + 1.
-        output_dim: Integer. Dimension of the dense embedding.
-        tie_weights: Boolean, whether or not the matrix for embedding and
-            the matrix for the `reverse` projection should share the same
-            weights.
-        embeddings_initializer: Initializer for the `embeddings`
-            matrix (see `keras.initializers`).
-        embeddings_regularizer: Regularizer function applied to
-            the `embeddings` matrix (see `keras.regularizers`).
-        embeddings_constraint: Constraint function applied to
-            the `embeddings` matrix (see `keras.constraints`).
-        mask_zero: Boolean, whether or not the input value 0 is a special
-            "padding" value that should be masked out.
-        reverse_dtype: The dtype for the reverse projection computation.
-            Defaults to the `compute_dtype` of the layer.
-        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
-            output logits will be scaled by
-            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
-            range of output logits and can improve training.
-        **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
-            including `name`, `trainable`, `dtype` etc.
-
-    Call arguments:
-        inputs: The tensor inputs to the layer.
-        reverse: Boolean. If `True` the layer will perform a linear projection
-            from `output_dim` to `input_dim`, instead of a normal embedding
-            call. Default to `False`.
-
-    Example:
-    ```python
-    batch_size = 16
-    vocab_size = 100
-    hidden_dim = 32
-    seq_length = 50
-
-    # Generate random inputs.
-    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
-
-    embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
-    # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
-    hidden_states = embedding(token_ids)
-    # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
-    logits = embedding(hidden_states, reverse=True)
-    ```
-
-    References:
-    - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
-    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
-    """
-
-    def __init__(
-        self,
-        input_dim,
-        output_dim,
-        tie_weights=True,
-        embeddings_initializer="uniform",
-        embeddings_regularizer=None,
-        embeddings_constraint=None,
-        mask_zero=False,
-        reverse_dtype=None,
-        logit_soft_cap=None,
-        **kwargs,
-    ):
-        super().__init__(
-            input_dim,
-            output_dim,
-            embeddings_initializer=embeddings_initializer,
-            embeddings_regularizer=embeddings_regularizer,
-            embeddings_constraint=embeddings_constraint,
-            mask_zero=mask_zero,
-            **kwargs,
-        )
-        self.tie_weights = tie_weights
-        self.reverse_dtype = reverse_dtype
-        self.logit_soft_cap = logit_soft_cap
-
-    def build(self, inputs_shape=None):
-        super().build(inputs_shape)
-        if (
-            not self.tie_weights
-            and getattr(self, "quantization_mode", None) != "int8"
-        ):
-            self.reverse_embeddings = self.add_weight(
-                name="reverse_embeddings",
-                shape=(self.output_dim, self.input_dim),
-                initializer=self.embeddings_initializer,
-                dtype=self.dtype,
-            )
-
-    def call(self, inputs, reverse=False):
-        if reverse:
-            if self.tie_weights:
-                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
-            else:
-                kernel = self.reverse_embeddings
-            if self.reverse_dtype is not None:
-                inputs = ops.cast(inputs, self.reverse_dtype)
-                kernel = ops.cast(kernel, self.reverse_dtype)
-            logits = ops.matmul(inputs, kernel)
-            # Optionally soft-cap logits.
-            if self.logit_soft_cap is not None:
-                soft_cap = self.logit_soft_cap
-                logits = ops.tanh(logits / soft_cap) * soft_cap
-            return logits
-
-        return super().call(inputs)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "tie_weights": self.tie_weights,
-                "reverse_dtype": self.reverse_dtype,
-                "logit_soft_cap": self.logit_soft_cap,
-            }
-        )
-        return config
-
-    def save_own_variables(self, store):
-        if not self.built:
-            return
-        super().save_own_variables(store)
-        target_variables = []
-        if not self.tie_weights:
-            # Store the reverse embedding weights as the last weights.
-            target_variables.append(self.reverse_embeddings)
-            if getattr(self, "quantization_mode", None) == "int8":
-                target_variables.append(self.reverse_embeddings_scale)
-            for i, variable in enumerate(target_variables, start=len(store)):
-                store[str(i)] = variable
-
-    def load_own_variables(self, store):
-        if not self.built:
-            self.build()
-        super().load_own_variables(store)
-        if not self.tie_weights:
-            # Last weights in the stores are the reverse embedding weights.
-            target_variables = [self.reverse_embeddings]
-            if getattr(self, "quantization_mode", None) == "int8":
-                target_variables.append(self.reverse_embeddings_scale)
-            for i, variable in enumerate(
-                target_variables, start=len(store) - len(target_variables)
-            ):
-                variable.assign(store[str(i)])
-
-    def compute_output_spec(self, inputs, reverse=False):
-        output_shape = list(inputs.shape)
-        if reverse:
-            output_shape[-1] = self.input_dim
-        else:
-            output_shape += [self.output_dim]
-        return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
-
-    # Quantization-related (int8) methods
-
-    def quantized_call(self, inputs, reverse=False):
-        # TODO (hongyu): This function could be removed once we add `*args` and
-        # `**kwargs` for `Embedding.quantized_call`
-        if self.quantization_mode == "int8":
-            return self._int8_call(inputs, reverse=reverse)
-        else:
-            self._quantization_mode_error(self.quantization_mode)
-
-    def _int8_build(self, embeddings_shape=None):
-        if (
-            "embeddings_shape"
-            in inspect.signature(super()._int8_build).parameters
-        ):
-            if embeddings_shape is None:
-                embeddings_shape = (self.input_dim, self.output_dim)
-            super()._int8_build(embeddings_shape=embeddings_shape)
-        else:
-            # Backward compatibility for older versions of Keras.
-            super()._int8_build()
-        self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
-        if not self.tie_weights:
-            self.reverse_embeddings = self.add_weight(
-                name="reverse_embeddings",
-                shape=(self.output_dim, self.input_dim),
-                initializer="zeros",
-                dtype="int8",
-                trainable=False,
-            )
-            self.reverse_embeddings_scale = self.add_weight(
-                name="reverse_embeddings_scale",
-                shape=(self.input_dim,),
-                initializer="ones",
-                trainable=False,
-            )
-        self._is_quantized = True
-
-    def _int8_call(self, inputs, reverse=False):
-        if reverse:
-            if self.tie_weights:
-                kernel = ops.transpose(self._embeddings)
-                scale = ops.transpose(self.embeddings_scale)
-            else:
-                kernel = self.reverse_embeddings
-                scale = self.reverse_embeddings_scale
-            inputs, inputs_scale = self.inputs_quantizer(inputs)
-            logits = ops.matmul(inputs, kernel)
-            # De-scale outputs
-            logits = ops.cast(logits, self.compute_dtype)
-            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
-            # Optionally soft-cap logits.
-            if self.logit_soft_cap is not None:
-                soft_cap = self.logit_soft_cap
-                logits = ops.tanh(logits / soft_cap) * soft_cap
-            return logits
-
-        return super()._int8_call(inputs)
-
-    def quantize(self, mode, type_check=True, config=None):
-        del config
-        if type_check and type(self) is not ReversibleEmbedding:
-            raise self._not_implemented_error(self.quantize)
-
-        def abs_max_quantize(inputs, axis):
-            return keras.quantizers.abs_max_quantize(
-                inputs, axis=axis, to_numpy=True
-            )
-
-        if mode != "int8":
-            raise NotImplementedError(
-                "Invalid quantization mode. Expected 'int8'. "
-                f"Received: quantization_mode={mode}"
-            )
-
-        embeddings_shape = (self.input_dim, self.output_dim)
-        if mode == "int8":
-            embeddings, embeddings_scale = abs_max_quantize(
-                self._embeddings, axis=-1
-            )
-            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-            del self._embeddings
-            if not self.tie_weights:
-                reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
-                    self.reverse_embeddings, axis=0
-                )
-                reverse_embeddings_scale = ops.squeeze(
-                    reverse_embeddings_scale, axis=0
-                )
-                del self.reverse_embeddings
-        self.quantized_build(embeddings_shape, mode)
-        if mode == "int8":
-            self._embeddings.assign(embeddings)
-            self.embeddings_scale.assign(embeddings_scale)
-            if not self.tie_weights:
-                self.reverse_embeddings.assign(reverse_embeddings)
-                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
-
-        if self.dtype_policy.quantization_mode is None:
-            policy = keras.dtype_policies.get(
-                f"{mode}_from_{self.dtype_policy.name}"
-            )
-            self.dtype_policy = policy
+class ReversibleEmbedding(keras.layers.ReversibleEmbedding):
+    pass
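After this change `keras_hub.layers.ReversibleEmbedding` is a thin subclass of `keras.layers.ReversibleEmbedding`, which now provides the tied/untied reverse projection directly, so the call pattern from the removed docstring is unchanged. A minimal sketch of that (unchanged) usage, assuming a Keras version that ships `keras.layers.ReversibleEmbedding`:

```python
import numpy as np
import keras_hub

batch_size, vocab_size, hidden_dim, seq_length = 16, 100, 32, 50
token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))

# Still constructed through the keras_hub export; it now inherits from
# keras.layers.ReversibleEmbedding instead of reimplementing it.
embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
hidden_states = embedding(token_ids)             # (batch, seq, hidden_dim)
logits = embedding(hidden_states, reverse=True)  # (batch, seq, vocab_size)
```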
keras_hub/src/layers/modeling/rotary_embedding.py
CHANGED

@@ -1,4 +1,5 @@
 import keras
+import numpy as np
 from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
@@ -25,6 +26,17 @@ class RotaryEmbedding(keras.layers.Layer):
             curves.
         scaling_factor: float. The scaling factor used to scale positions of
             the tokens.
+        rope_type: str. The type of RoPE scaling to apply. Supported types:
+            "linear", "dynamic", "yarn". Defaults to "linear".
+        beta_fast: float. Beta fast parameter for YaRN scaling. Only used
+            when rope_type="yarn". Defaults to 32.0.
+        beta_slow: float. Beta slow parameter for YaRN scaling. Only used
+            when rope_type="yarn". Defaults to 1.0.
+        original_max_position_embeddings: int. Original maximum position
+            embeddings for YaRN scaling. Only used when rope_type="yarn".
+            Defaults to 4096.
+        truncate: bool. Whether to apply truncation for YaRN scaling. Only used
+            when rope_type="yarn". Defaults to False.
         sequence_axis: int. Sequence axis in the input tensor.
         feature_axis: int. Feature axis in the input tensor.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
@@ -69,6 +81,11 @@ class RotaryEmbedding(keras.layers.Layer):
         self,
         max_wavelength=10000,
         scaling_factor=1.0,
+        rope_type="linear",
+        beta_fast=32.0,
+        beta_slow=1.0,
+        original_max_position_embeddings=4096,
+        truncate=False,
         sequence_axis=1,
         feature_axis=-1,
         **kwargs,
@@ -78,24 +95,70 @@ class RotaryEmbedding(keras.layers.Layer):
         self.sequence_axis = sequence_axis
         self.feature_axis = feature_axis
         self.scaling_factor = scaling_factor
+        self.rope_type = rope_type
+
+        # YaRN-specific parameters (only used when rope_type="yarn")
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.truncate = truncate
         self.built = True
 
+    def _normalize_axes(self, input_shape):
+        """Normalize and validate axis indices for the given input shape."""
+        rank = len(input_shape)
+
+        # Normalize negative indices
+        sequence_axis = self.sequence_axis
+        feature_axis = self.feature_axis
+
+        if sequence_axis < 0:
+            sequence_axis += rank
+        if feature_axis < 0:
+            feature_axis += rank
+
+        if sequence_axis < 0 or sequence_axis >= rank:
+            raise ValueError(
+                f"sequence_axis {self.sequence_axis} "
+                f"is out of range for input with rank {rank}"
+            )
+        if feature_axis < 0 or feature_axis >= rank:
+            raise ValueError(
+                f"feature_axis {self.feature_axis} "
+                f"is out of range for input with rank {rank}"
+            )
+        if sequence_axis == feature_axis:
+            raise ValueError("sequence_axis and feature_axis must be different")
+
+        return sequence_axis, feature_axis
+
+    def _validate_rotary_dimension(self, rotary_dim):
+        if rotary_dim % 2 != 0:
+            raise ValueError(
+                f"Rotary dimension must be even, got {rotary_dim}."
+                "The rotary embedding splits the feature dimension "
+                "into two halves. Consider using a different feature "
+                "dimension or padding."
+            )
+
     def call(self, inputs, start_index=0, positions=None):
+        input_shape = ops.shape(inputs)
+        sequence_axis, feature_axis = self._normalize_axes(input_shape)
+
+        rotary_dim = input_shape[feature_axis]
+        self._validate_rotary_dimension(rotary_dim)
+
         # Take care of unbatched `positions`.
         if positions is not None:
             if len(ops.shape(positions)) == 1:
                 positions = ops.expand_dims(positions, axis=0)
 
-        inputs = ops.moveaxis(
-            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
-        )
+        inputs = ops.moveaxis(inputs, (feature_axis, sequence_axis), (-1, 1))
         cos_emb, sin_emb = self._compute_cos_sin_embedding(
             inputs, start_index, positions
         )
         output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
-        return ops.moveaxis(
-            output, (-1, 1), (self.feature_axis, self.sequence_axis)
-        )
+        return ops.moveaxis(output, (-1, 1), (feature_axis, sequence_axis))
 
     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
         x1, x2 = ops.split(tensor, 2, axis=-1)
@@ -113,19 +176,35 @@ class RotaryEmbedding(keras.layers.Layer):
         return positions + ops.cast(start_index, dtype="float32")
 
     def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
+        """Compute cos & sin RoPE embeddings with optional YaRN scaling.
+        Uses tensor ops only to remain JIT/backends friendly.
+        """
         batch_axis = 0
-        feature_axis = len(inputs.shape) - 1
         sequence_axis = 1
+        feature_axis = len(inputs.shape) - 1
 
         rotary_dim = ops.shape(inputs)[feature_axis]
         inverse_freq = self._get_inverse_freq(rotary_dim)
 
         if positions is None:
             positions = self._compute_positions(inputs, start_index)
-            positions = ops.expand_dims(positions, axis=batch_axis)
+            positions = ops.expand_dims(
+                positions, axis=batch_axis
+            )  # shape (1, seq_len)
         else:
             positions = ops.cast(positions, "float32")
-        positions = positions / ops.cast(self.scaling_factor, "float32")
+            if len(ops.shape(positions)) == 1:
+                positions = ops.expand_dims(positions, axis=batch_axis)
+
+        if (
+            self.rope_type == "yarn"
+            and self.truncate
+            and self.original_max_position_embeddings is not None
+        ):
+            positions = ops.minimum(
+                positions,
+                ops.cast(self.original_max_position_embeddings, "float32"),
+            )
 
         freq = ops.einsum("bi,j->bij", positions, inverse_freq)
@@ -140,15 +219,103 @@ class RotaryEmbedding(keras.layers.Layer):
 
         cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
         sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
+
+        if self.rope_type == "yarn":
+            # YaRN temperature scaling
+            factor = ops.add(
+                ops.multiply(
+                    ops.cast(0.1, self.compute_dtype),
+                    ops.log(ops.cast(self.scaling_factor, self.compute_dtype)),
+                ),
+                ops.cast(1.0, self.compute_dtype),
+            )
+            cos_emb = cos_emb * factor
+            sin_emb = sin_emb * factor
         return cos_emb, sin_emb
 
     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.divide(
-            ops.arange(0, rotary_dim, 2, dtype="float32"),
-            ops.cast(rotary_dim, "float32"),
+        """Return inverse frequencies."""
+        idx = ops.arange(0, rotary_dim, 2, dtype="float32")
+        denom = ops.cast(rotary_dim, "float32")
+        freq_range = idx / denom
+        inv = ops.power(ops.cast(self.max_wavelength, "float32"), -freq_range)
+
+        if self.rope_type == "linear":
+            return inv / ops.cast(self.scaling_factor, "float32")
+        elif self.rope_type == "dynamic":
+            exponent = ops.cast(rotary_dim, "float32") / ops.cast(
+                max(1, rotary_dim - 2), "float32"
+            )
+            return inv / ops.power(
+                ops.cast(self.scaling_factor, "float32"), exponent
+            )
+        elif self.rope_type == "yarn":
+            return self._get_yarn_inverse_freq(rotary_dim)
+        else:
+            return inv
+
+    def _get_yarn_inverse_freq(self, rotary_dim):
+        # Get the base (rope_theta equivalent) from max_wavelength
+        base = ops.cast(self.max_wavelength, "float32")
+
+        # Compute base frequencies: base ** (idx / dim)
+        idx = ops.arange(0, rotary_dim, 2, dtype="float32")
+        pos_freqs = ops.power(base, idx / ops.cast(rotary_dim, "float32"))
+
+        # Compute interpolation and extrapolation frequencies
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (
+            ops.cast(self.scaling_factor, "float32") * pos_freqs
         )
-        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
-        return inverse_freq
+
+        # Find correction range
+        beta_fast = ops.cast(self.beta_fast, "float32")
+        beta_slow = ops.cast(self.beta_slow, "float32")
+
+        # Find correction dimensions for beta_fast and beta_slow
+        def find_correction_dim_tensor(num_rotations, dim):
+            max_pos = ops.cast(self.original_max_position_embeddings, "float32")
+            return (dim * ops.log(max_pos / (num_rotations * 2 * np.pi))) / (
+                2 * ops.log(base)
+            )
+
+        low = find_correction_dim_tensor(
+            beta_fast, ops.cast(rotary_dim, "float32")
+        )
+        high = find_correction_dim_tensor(
+            beta_slow, ops.cast(rotary_dim, "float32")
+        )
+
+        # Apply truncation if specified
+        if self.truncate:
+            low = ops.floor(low)
+            high = ops.ceil(high)
+
+        # Clamp to valid range
+        low = ops.maximum(low, ops.cast(0, "float32"))
+        high = ops.minimum(high, ops.cast(rotary_dim // 2 - 1, "float32"))
+
+        # Linear ramp function
+        dim_half = rotary_dim // 2
+        idx_half = ops.arange(0, dim_half, dtype="float32")
+
+        # Prevent singularity
+        diff = high - low
+        diff = ops.maximum(diff, ops.cast(0.001, "float32"))
+
+        linear_func = (idx_half - low) / diff
+        ramp_func = ops.clip(linear_func, 0, 1)
+
+        # Apply the ramp to get extrapolation factor
+        inv_freq_extrapolation_factor = 1 - ramp_func
+
+        # Combine interpolation and extrapolation
+        scaled_inverse_freq = (
+            inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+            + inv_freq_extrapolation * inv_freq_extrapolation_factor
+        )
+
+        return scaled_inverse_freq
 
     def get_config(self):
         config = super().get_config()
@@ -156,6 +323,13 @@ class RotaryEmbedding(keras.layers.Layer):
             {
                 "max_wavelength": self.max_wavelength,
                 "scaling_factor": self.scaling_factor,
+                "rope_type": self.rope_type,
+                "beta_fast": self.beta_fast,
+                "beta_slow": self.beta_slow,
+                "original_max_position_embeddings": (
+                    self.original_max_position_embeddings
+                ),
+                "truncate": self.truncate,
                 "sequence_axis": self.sequence_axis,
                 "feature_axis": self.feature_axis,
             }
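The hunks above add selectable RoPE scaling ("linear", "dynamic", "yarn") plus the YaRN parameters, and serialize them in `get_config`. A minimal sketch using only the constructor arguments introduced above (the concrete values are illustrative, not taken from the diff):

```python
import numpy as np
import keras_hub

# rope_type selects the scaling path in _get_inverse_freq; the YaRN-specific
# arguments are only consulted when rope_type="yarn".
rope = keras_hub.layers.RotaryEmbedding(
    max_wavelength=10000,
    scaling_factor=8.0,  # illustrative context-extension factor
    rope_type="yarn",
    beta_fast=32.0,
    beta_slow=1.0,
    original_max_position_embeddings=4096,
    truncate=True,
)

# (batch, seq_len, num_heads, head_dim); the rotated feature dim must be even.
x = np.random.rand(1, 16, 4, 64).astype("float32")
rotated = rope(x)  # same shape as x
```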
keras_hub/src/layers/modeling/token_and_position_embedding.py
CHANGED

@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.utils.keras_utils import clone_initializer
 
 
keras_hub/src/models/albert/albert_backbone.py
CHANGED

@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
keras_hub/src/models/bart/bart_backbone.py
CHANGED

@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
keras_hub/src/models/bert/bert_backbone.py
CHANGED

@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
keras_hub/src/models/bloom/bloom_backbone.py
CHANGED

@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.bloom.bloom_decoder import BloomDecoder
 
|
@@ -196,7 +196,7 @@ class CausalLM(Task):
|
|
|
196
196
|
|
|
197
197
|
# Create an explicit tuple of all variable state.
|
|
198
198
|
state = (
|
|
199
|
-
self.sampler.variables,
|
|
199
|
+
[v.value for v in self.sampler.variables],
|
|
200
200
|
# Use the explicit variable.value to preserve the
|
|
201
201
|
# sharding spec of distribution.
|
|
202
202
|
[v.value for v in self.trainable_variables],
|
|
@@ -429,3 +429,25 @@
         super()._post_quantize(mode, **kwargs)
         # Reset the compiled generate function.
         self.generate_function = None
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        backbone = self.backbone
+        # Check for standard backbone structure.
+        if not hasattr(backbone, "transformer_layers"):
+            return None
+
+        # Check for embedding.
+        embedding = getattr(backbone, "token_embedding", None)
+        if embedding is None:
+            embedding = getattr(backbone, "embedding", None)
+
+        if embedding is None:
+            return None
+
+        return {
+            "pre_block_layers": [embedding],
+            "sequential_blocks": backbone.transformer_layers,
+        }
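The new `get_quantization_layer_structure` hook only answers for `mode="gptq"` and describes the backbone as an embedding (`token_embedding`, falling back to `embedding`) followed by the sequential `transformer_layers`. A minimal sketch of how a caller might consume the returned mapping (the GPT-2 preset is illustrative; any CausalLM with this standard backbone layout would work):

```python
import keras_hub

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

structure = causal_lm.get_quantization_layer_structure("gptq")
if structure is not None:
    pre_block_layers = structure["pre_block_layers"]  # [token embedding]
    blocks = structure["sequential_blocks"]           # transformer layers
    print(len(pre_block_layers), len(blocks))

# Any mode other than "gptq" returns None.
assert causal_lm.get_quantization_layer_structure("int8") is None
```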
keras_hub/src/models/deberta_v3/deberta_v3_backbone.py
CHANGED

@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.deberta_v3.disentangled_attention_encoder import (
     DisentangledAttentionEncoder,