keras-hub-nightly 0.24.0.dev202511220420__py3-none-any.whl → 0.26.0.dev202601010440__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of keras-hub-nightly might be problematic.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -3
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +1 -3
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +23 -1
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/dinov3/dinov3_presets.py +90 -1
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/esm/esm_attention.py +11 -4
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +16 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +1 -3
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
- keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +17 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/__init__.py +5 -0
- keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +219 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
- keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
- keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
- keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
- keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/llama3/llama3_presets.py +1 -1
- keras_hub/src/models/masked_lm.py +22 -0
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_decoder.py +21 -9
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3/qwen3_presets.py +36 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/__init__.py +5 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/smollm3/smollm3_presets.py +16 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/tests/test_case.py +1 -3
- keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
- keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/METADATA +4 -5
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/RECORD +66 -53
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gpt_oss/gpt_oss_decoder.py
@@ -0,0 +1,444 @@
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_attention import GptOssAttention
+from keras_hub.src.models.gpt_oss.gpt_oss_layer_norm import (
+    GptOssLayerNormalization,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class GptOssExperts(keras.layers.Layer):
+    """A layer containing the feed-forward expert networks for GPT-OSS.
+
+    This layer implements the expert networks as described in the GPT-OSS
+    paper. It uses a custom GLU activation.
+
+    Args:
+        num_experts: int. The total number of experts.
+        hidden_dim: int. The hidden size of the model.
+        intermediate_dim: int. The intermediate size of the feed-forward
+            network.
+        kernel_initializer: string. The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+        alpha: float. The alpha parameter for the custom GLU
+            activation. Defaults to `1.702`.
+        limit: float. The clamping limit for gate and up
+            projections. Defaults to `7.0`.
+    """
+
+    def __init__(
+        self,
+        num_experts,
+        hidden_dim,
+        intermediate_dim,
+        kernel_initializer="glorot_uniform",
+        alpha=1.702,
+        limit=7.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_experts = num_experts
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.alpha = alpha
+        self.limit = limit
+
+    def build(self, _):
+        self.gate_up_proj = self.add_weight(
+            shape=(
+                self.num_experts,
+                self.hidden_dim,
+                2 * self.intermediate_dim,
+            ),
+            initializer=self.kernel_initializer,
+            name="gate_up_proj",
+        )
+        self.gate_up_proj_bias = self.add_weight(
+            shape=(self.num_experts, 2 * self.intermediate_dim),
+            initializer="zeros",
+            name="gate_up_proj_bias",
+        )
+        self.down_proj = self.add_weight(
+            shape=(self.num_experts, self.intermediate_dim, self.hidden_dim),
+            initializer=self.kernel_initializer,
+            name="down_proj",
+        )
+        self.down_proj_bias = self.add_weight(
+            shape=(self.num_experts, self.hidden_dim),
+            initializer="zeros",
+            name="down_proj_bias",
+        )
+        self.built = True
+
+    def call(self, hidden_states):
+        # hidden_states shape: (num_tokens, hidden_dim)
+        # Einsum for batched matrix multiplication across experts.
+        # [num_experts, num_tokens, 2 * intermediate_dim]
+        gate_up = ops.einsum("th,ehm->etm", hidden_states, self.gate_up_proj)
+        gate_up = gate_up + self.gate_up_proj_bias[:, None, :]
+
+        gate = gate_up[..., ::2]
+        up = gate_up[..., 1::2]
+
+        gate = ops.clip(gate, -1e9, self.limit)
+        up = ops.clip(up, -self.limit, self.limit)
+
+        glu = gate * ops.sigmoid(gate * self.alpha)
+        gated_output = (up + 1) * glu
+
+        # [num_experts, num_tokens, hidden_dim]
+        out = ops.einsum("etm,emh->eth", gated_output, self.down_proj)
+        out = out + self.down_proj_bias[:, None, :]
+        return out
+
+
+class GptOssTopKRouter(keras.layers.Layer):
+    """A layer for routing tokens to the top-k experts.
+
+    Args:
+        num_experts: int. The total number of experts.
+        top_k: int. The number of experts to route each token to.
+        kernel_initializer: string. The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+    """
+
+    def __init__(
+        self,
+        num_experts,
+        top_k,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+    def build(self, hidden_states_shape):
+        self.router_dense = keras.layers.Dense(
+            self.num_experts,
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="router_dense",
+        )
+        self.router_dense.build(hidden_states_shape)
+        self.built = True
+
+    def call(self, hidden_states):
+        # hidden_states shape: (num_tokens, hidden_dim)
+        router_logits = self.router_dense(hidden_states)
+
+        routing_weights, selected_experts = ops.top_k(
+            router_logits, k=self.top_k
+        )
+        routing_weights = ops.softmax(routing_weights, axis=-1)
+
+        expert_mask = ops.one_hot(selected_experts, self.num_experts)
+        expert_mask = ops.cast(expert_mask, dtype=routing_weights.dtype)
+
+        # Shape: (num_tokens, top_k, num_experts)
+        weighted_mask = expert_mask * ops.expand_dims(routing_weights, axis=-1)
+
+        # Shape: (num_tokens, num_experts)
+        router_scores = ops.sum(weighted_mask, axis=1)
+
+        return router_scores
+
+
+class GptOssSparseMoeBlock(keras.layers.Layer):
+    """GPT-OSS sparse Mixture of Experts (MoE) block.
+
+    This block combines a router and a set of expert networks to implement
+    the MoE layer.
+
+    Args:
+        hidden_dim: int. The hidden size of the model.
+        intermediate_dim: int. The intermediate size of the feed-forward
+            network.
+        num_experts: int. The total number of experts.
+        top_k: int. The number of experts to route each token to.
+            Defaults to 2.
+        kernel_initializer: string. The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        num_experts,
+        top_k=2,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.kernel_initializer = kernel_initializer
+
+    def build(self, decoder_sequence_shape):
+        self.router = GptOssTopKRouter(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            dtype=self.dtype_policy,
+            name="router",
+        )
+        self.router.build(decoder_sequence_shape)
+
+        self.experts = GptOssExperts(
+            num_experts=self.num_experts,
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            dtype=self.dtype_policy,
+            name="experts",
+        )
+        self.experts.build(decoder_sequence_shape)
+        self.built = True
+
+    def call(self, hidden_states):
+        batch_size, seq_len, _ = ops.shape(hidden_states)
+        hidden_states_flattened = ops.reshape(
+            hidden_states, (-1, self.hidden_dim)
+        )
+
+        router_scores = self.router(hidden_states_flattened)
+
+        expert_outputs = self.experts(hidden_states_flattened)
+
+        # Weight expert outputs by router scores and sum
+        # router_scores shape: (num_tokens, num_experts)
+        # expert_outputs shape: (num_experts, num_tokens, hidden_dim)
+        # Transpose scores for broadcasting: (num_experts, num_tokens)
+        router_scores_t = ops.transpose(router_scores)
+        # Expand for broadcasting: (num_experts, num_tokens, 1)
+        router_scores_expanded = ops.expand_dims(router_scores_t, axis=-1)
+
+        weighted_outputs = expert_outputs * router_scores_expanded
+        final_output = ops.sum(weighted_outputs, axis=0)
+
+        final_output = ops.reshape(
+            final_output, (batch_size, seq_len, self.hidden_dim)
+        )
+        return final_output, router_scores
+
+
+class GptOssTransformerDecoder(keras.layers.Layer):
+    """A GPT-OSS transformer decoder layer.
+
+    This layer implements the transformer decoder block from the GPT-OSS
+    model, which includes self-attention and a sparse MoE block.
+
+    Args:
+        intermediate_dim: int. The intermediate size of the feed-forward
+            network.
+        num_query_heads: int. The number of query attention heads.
+        num_key_value_heads: int. The number of key and value attention
+            heads.
+        num_experts: int. The total number of experts in the MoE layer.
+        top_k: int. The number of experts to route each token to.
+            Defaults to 2.
+        output_router_logits: bool. If True, the router logits will
+            be returned by the layer. Defaults to False.
+        rope_max_wavelength: int. The maximum wavelength for the
+            rotary position embedding. Defaults to 10000.
+        rope_scaling_factor: float. The scaling factor for the
+            rotary position embedding. Defaults to 1.0.
+        layer_norm_epsilon: float. The epsilon for layer
+            normalization. Defaults to 1e-6.
+        kernel_initializer: string. The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+        sliding_window: int. The size of the sliding window for
+            attention. Defaults to 4096.
+        dropout: float. The dropout rate. Defaults to 0.
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        num_query_heads,
+        num_key_value_heads,
+        num_experts,
+        top_k=2,
+        output_router_logits=False,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        layer_norm_epsilon=1e-6,
+        kernel_initializer="glorot_uniform",
+        sliding_window=4096,
+        dropout=0,
+        head_dim=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.output_router_logits = output_router_logits
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.sliding_window = sliding_window
+        self.dropout = dropout
+        self.head_dim = head_dim
+        self.supports_masking = True
+
+    def build(self, decoder_sequence_shape):
+        self.hidden_dim = decoder_sequence_shape[-1]
+
+        self.self_attention_layer = GptOssAttention(
+            num_query_heads=self.num_query_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            rope_max_wavelength=self.rope_max_wavelength,
+            rope_scaling_factor=self.rope_scaling_factor,
+            sliding_window=self.sliding_window,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            dropout=self.dropout,
+            head_dim=self.head_dim,
+            dtype=self.dtype_policy,
+            name="self_attention",
+        )
+        self.self_attention_layer.build(decoder_sequence_shape)
+
+        self.input_layernorm = GptOssLayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="input_layernorm",
+        )
+        self.input_layernorm.build(decoder_sequence_shape)
+
+        self.post_attention_layernorm = GptOssLayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="post_attention_layernorm",
+        )
+        self.post_attention_layernorm.build(decoder_sequence_shape)
+
+        self.sparse_moe_block = GptOssSparseMoeBlock(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            dtype=self.dtype_policy,
+            name="sparse_moe_block",
+        )
+        self.sparse_moe_block.build(decoder_sequence_shape)
+
+        self.built = True
+
+    def call(
+        self,
+        decoder_sequence,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        training=None,
+    ):
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=decoder_sequence,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )
+
+        residual = decoder_sequence
+        x = self.input_layernorm(decoder_sequence)
+
+        x = self.self_attention_layer(
+            hidden_states=x,
+            attention_mask=self_attention_mask,
+            cache=self_attention_cache,
+            cache_update_index=self_attention_cache_update_index,
+        )
+
+        if self_attention_cache is not None:
+            x, self_attention_cache = x
+
+        x = x + residual
+        residual = x
+
+        x = self.post_attention_layernorm(x)
+        x, router_logits = self.sparse_moe_block(x)
+
+        decoder_output = x + residual
+
+        output = (decoder_output,)
+        if self_attention_cache is not None:
+            output += (self_attention_cache,)
+        if self.output_router_logits:
+            output += (router_logits,)
+
+        return output[0] if len(output) == 1 else output
+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        self_attention_cache,
+        self_attention_cache_update_index,
+    ):
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        batch_size = ops.shape(decoder_sequence)[0]
+        input_length = output_length = ops.shape(decoder_sequence)[1]
+
+        if self_attention_cache is not None:
+            input_length = ops.shape(self_attention_cache)[2]
+
+        cache_update_index = (
+            0
+            if self_attention_cache_update_index is None
+            else self_attention_cache_update_index
+        )
+
+        causal_mask = compute_causal_mask(
+            batch_size, input_length, output_length, cache_update_index
+        )
+
+        return (
+            ops.minimum(decoder_mask, causal_mask)
+            if decoder_mask is not None
+            else causal_mask
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "num_experts": self.num_experts,
+                "top_k": self.top_k,
+                "output_router_logits": self.output_router_logits,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
+                "head_dim": self.head_dim,
+            }
+        )
+        return config
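For reference, the gating math inside `GptOssExperts.call` (clamp, SiLU-style gate scaled by `alpha`, then `(up + 1) * glu`) can be exercised in isolation. A minimal sketch using `keras.ops`; the `demo_gate` / `demo_up` tensors are illustrative stand-ins for the interleaved gate/up projections, not values from the package:

import numpy as np
from keras import ops

# Illustrative stand-ins for one token's gate/up projections (not package data).
demo_gate = ops.convert_to_tensor(np.array([[0.5, -2.0, 9.0]], "float32"))
demo_up = ops.convert_to_tensor(np.array([[1.0, 0.2, -8.0]], "float32"))
alpha, limit = 1.702, 7.0

gate = ops.clip(demo_gate, -1e9, limit)  # gate is only clamped from above
up = ops.clip(demo_up, -limit, limit)  # up is clamped on both sides
glu = gate * ops.sigmoid(gate * alpha)  # SiLU-like gate, scaled by alpha
print(ops.convert_to_numpy((up + 1) * glu))  # value fed into down_proj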
keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py
@@ -0,0 +1,34 @@
+import keras
+from keras import ops
+
+
+# NOTE: `keras.layers.LayerNormalization(rms_scaling=True)`
+# does not produce the same results.
+class GptOssLayerNormalization(keras.layers.Layer):
+    """A normalization layer for Gpt-Oss that implements RMS normalization."""
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        dim = input_shape[-1]
+        self.scale = self.add_weight(
+            name="scale",
+            trainable=True,
+            shape=(dim,),
+            initializer="ones",
+            dtype=self.variable_dtype,
+        )
+        self.built = True
+
+    def call(self, x):
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x * self.scale, self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
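As a quick numerical cross-check of the RMS formulation above, the same computation in plain NumPy (a sketch; the random input is illustrative, and the `scale` weight is taken at its "ones" initialization):

import numpy as np

x = np.random.rand(2, 4).astype("float32")
eps = 1e-6
# Root-mean-square normalization, matching GptOssLayerNormalization.call.
rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)
print(x / rms)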
keras_hub/src/models/gpt_oss/gpt_oss_presets.py
@@ -0,0 +1,51 @@
+backbone_presets = {
+    "gpt_oss_20b_en": {
+        "metadata": {
+            "description": (
+                "This preset has 21 billion total parameters, "
+                "with 3.6 billion active parameters, a 128k context "
+                "length, and is de-quantized from MXFP4."
+            ),
+            "params": 20_914_757_184,
+            "path": "gpt_oss",
+        },
+        "kaggle_handle": "kaggle://keras/gpt-oss/keras/gpt_oss_20b_en/1",
+    },
+    "gpt_oss_120b_en": {
+        "metadata": {
+            "description": (
+                "This preset has 117 billion total parameters, "
+                "with 5.1 billion active parameters, a 128k context "
+                "length, and is de-quantized from MXFP4."
+            ),
+            "params": 116_829_156_672,
+            "path": "gpt_oss",
+        },
+        "kaggle_handle": "kaggle://keras/gpt-oss/keras/gpt_oss_120b_en/1",
+    },
+    "gpt_oss_safeguard_20b_en": {
+        "metadata": {
+            "description": (
+                "Open-weight safety reasoning model with 21 billion "
+                "total parameters,with 3.6 billion active "
+                "parameters, a context length of over 128k, "
+                "and is de-quantized from MXFP4."
+            ),
+            "params": 20_914_757_184,
+            "path": "gpt_oss",
+        },
+        "kaggle_handle": "kaggle://keras/gpt-oss-safeguard/keras/gpt_oss_safeguard_20b_en/1",
+    },
+    "gpt_oss_safeguard_120b_en": {
+        "metadata": {
+            "description": (
+                "Open-weight safety reasoning model with 117 billion "
+                "total parameters,with 5.1 billion active parameters, "
+                "a 128k context length, and is de-quantized from MXFP4."
+            ),
+            "params": 116_829_156_672,
+            "path": "gpt_oss",
+        },
+        "kaggle_handle": "kaggle://keras/gpt-oss-safeguard/keras/gpt_oss_safeguard_120b_en/1",
+    },
+}
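The four presets above should be reachable through the standard `from_preset` constructors once this nightly build is installed; a hedged sketch (the export path `keras_hub.models.GptOssCausalLM` is assumed from the new `gpt_oss_causal_lm.py` module, and downloading the weights may require Kaggle credentials):

import keras_hub

# Assumed export, mirroring the other model families under keras_hub.models.
causal_lm = keras_hub.models.GptOssCausalLM.from_preset("gpt_oss_20b_en")
causal_lm.generate("Why is the sky blue?", max_length=64)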
keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py
@@ -0,0 +1,39 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.GptOssTokenizer",
+        "keras_hub.models.GptOssTokenizer",
+    ]
+)
+class GptOssTokenizer(BytePairTokenizer):
+    """A GptOss tokenizer using BytePair encoding.
+
+    Tokenizer is a subclass of `keras_hub.tokenizers.BytePairTokenizer`.
+    It uses a BytePair encoding model to tokenize strings. It also adds special
+    tokens for the start and end of a sequence.
+
+    Args:
+        vocabulary: string or dict, maps token to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rule. If it is a string,
+            it should be the file path to merge rules.
+    """
+
+    backbone_cls = GptOssBackbone
+
+    def __init__(self, vocabulary=None, merges=None, **kwargs):
+        """Initializes the GptOssTokenizer.
+
+        Args:
+            vocabulary: string or dict, maps token to integer ids.
+            merges: string or list, contains the merge rule.
+            **kwargs: Additional keyword arguments.
+        """
+        self.start_token_id = None
+        self._add_special_token("<|endoftext|>", "end_token")
+        self.pad_token_id = 0
+        super().__init__(vocabulary=vocabulary, merges=merges, **kwargs)
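A hedged usage sketch for the tokenizer; it assumes the vocabulary and merges assets ship with the `gpt_oss_20b_en` preset and that `_add_special_token` exposes an `end_token_id` attribute, as it does for the other BPE tokenizers in the library:

from keras_hub.tokenizers import GptOssTokenizer

tokenizer = GptOssTokenizer.from_preset("gpt_oss_20b_en")
token_ids = tokenizer("The quick brown fox.")
# "<|endoftext|>" is registered as the end token; no start token is defined.
print(tokenizer.end_token_id, tokenizer.start_token_id)
print(tokenizer.detokenize(token_ids))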
keras_hub/src/models/llama/llama_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.llama.llama_decoder import LlamaTransformerDecoder
 from keras_hub.src.models.llama.llama_layernorm import LlamaLayerNorm
keras_hub/src/models/llama3/llama3_presets.py
@@ -63,7 +63,7 @@ backbone_presets = {
             "params": 8030261248,
             "path": "llama3",
         },
-        "kaggle_handle": ("kaggle://keras/llama3/keras/
+        "kaggle_handle": ("kaggle://keras/llama3/keras/llama3.1_instruct_8b/2"),
     },
     "llama3.1_guard_8b": {
         "metadata": {
keras_hub/src/models/masked_lm.py
@@ -84,3 +84,25 @@ class MaskedLM(Task):
             weighted_metrics=weighted_metrics,
             **kwargs,
         )
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        backbone = self.backbone
+        # Check for standard backbone structure.
+        if not hasattr(backbone, "transformer_layers"):
+            return None
+
+        # Check for embedding.
+        embedding = getattr(backbone, "token_embedding", None)
+        if embedding is None:
+            embedding = getattr(backbone, "embedding", None)
+
+        if embedding is None:
+            return None
+
+        return {
+            "pre_block_layers": [embedding],
+            "sequential_blocks": backbone.transformer_layers,
+        }
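The new hook only answers for the "gptq" mode and describes the backbone as an embedding followed by the transformer stack. A hedged sketch of inspecting it on an existing masked LM task, assuming this nightly build is installed (the BERT preset name is only an example of a backbone with `token_embedding` and `transformer_layers` attributes):

import keras_hub

masked_lm = keras_hub.models.BertMaskedLM.from_preset("bert_base_en_uncased")
structure = masked_lm.get_quantization_layer_structure("gptq")
print(structure["pre_block_layers"])  # [the token embedding layer]
print(len(structure["sequential_blocks"]))  # one entry per transformer layer
# Any mode other than "gptq" falls back to None.
assert masked_lm.get_quantization_layer_structure("int8") is None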
keras_hub/src/models/mistral/mistral_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.mistral.mistral_layer_norm import (
     MistralLayerNormalization,
keras_hub/src/models/mixtral/mixtral_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.mixtral.mixtral_decoder import (
     MixtralTransformerDecoder,
keras_hub/src/models/moonshine/moonshine_backbone.py
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.moonshine.moonshine_decoder import (
     MoonshineDecoderBlock,
keras_hub/src/models/pali_gemma/pali_gemma_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.models.pali_gemma.pali_gemma_decoder_block import (