keras-hub 0.24.0.dev0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (27)
  1. keras_hub/models/__init__.py +12 -0
  2. keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
  3. keras_hub/src/models/esm/esm_attention.py +11 -4
  4. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
  5. keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
  6. keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
  7. keras_hub/src/models/gpt_oss/__init__.py +5 -0
  8. keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
  9. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +221 -0
  10. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
  11. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
  12. keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
  13. keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
  14. keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
  15. keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
  16. keras_hub/src/models/llama3/llama3_presets.py +1 -1
  17. keras_hub/src/models/parseq/parseq_decoder.py +21 -9
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
  19. keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
  20. keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
  21. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  22. keras_hub/src/version.py +1 -1
  23. keras_hub/tokenizers/__init__.py +3 -0
  24. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/METADATA +1 -1
  25. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/RECORD +27 -16
  26. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/WHEEL +0 -0
  27. {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gpt_oss/gpt_oss_attention.py
@@ -0,0 +1,330 @@
+ import math
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ class GptOssAttention(keras.layers.Layer):
+     """A cached attention layer with sliding window and sink tokens.
+
+     This layer implements the attention mechanism described in the GPT-OSS
+     paper. It includes grouped-query attention, rotary position embeddings,
+     sliding window attention, and sink tokens for improved performance on
+     long sequences.
+
+     Args:
+         num_query_heads: int. The number of query attention heads.
+         num_key_value_heads: int. The number of key and value attention
+             heads.
+         rope_max_wavelength: int. The maximum wavelength for the
+             rotary position embedding. Defaults to 10000.
+         rope_scaling_factor: float. The scaling factor for the
+             rotary position embedding. Defaults to 1.0.
+         kernel_initializer: str. The initializer for the kernel
+             weights. Defaults to "glorot_uniform".
+         sliding_window: int. The size of the sliding window.
+             Defaults to 4096.
+         dropout: float. The dropout rate. Defaults to 0.
+         head_dim: int. Head dimension for attention. If None,
+             calculated as hidden_dim // num_query_heads. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         num_query_heads,
+         num_key_value_heads,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1.0,
+         kernel_initializer="glorot_uniform",
+         sliding_window=4096,
+         dropout=0,
+         head_dim=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.sliding_window = sliding_window
+         self.dropout = dropout
+         self.head_dim = head_dim
+         self.rope_max_wavelength = rope_max_wavelength
+         self.rope_scaling_factor = rope_scaling_factor
+         self.num_key_value_groups = num_query_heads // num_key_value_heads
+         self._kernel_initializer = keras.initializers.get(
+             clone_initializer(kernel_initializer)
+         )
+
+     def build(self, inputs_shape):
+         # Einsum variables:
+         # b = batch size
+         # q = query length
+         # k = key/value length
+         # m = the model's hidden_dim
+         # u = num query heads
+         # v = num key/value heads
+         # h = head dim
+         self._hidden_dim = inputs_shape[-1]
+
+         if self.head_dim is not None:
+             self._head_dim = self.head_dim
+         else:
+             self._head_dim = self._hidden_dim // self.num_query_heads
+         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
+
+         self._rotary_dim = (self._head_dim // 2) * 2
+
+         self.query_dense = keras.layers.EinsumDense(
+             equation="bqm,muh->bquh",
+             output_shape=(None, self.num_query_heads, self._head_dim),
+             bias_axes="uh",
+             kernel_initializer=self._kernel_initializer,
+             bias_initializer="zeros",
+             dtype=self.dtype_policy,
+             name="query",
+         )
+         self.query_dense.build(inputs_shape)
+
+         self.key_dense = keras.layers.EinsumDense(
+             equation="bkm,mvh->bkvh",
+             output_shape=(
+                 None,
+                 self.num_key_value_heads,
+                 self._head_dim,
+             ),
+             bias_axes="vh",
+             kernel_initializer=self._kernel_initializer,
+             bias_initializer="zeros",
+             dtype=self.dtype_policy,
+             name="key",
+         )
+         self.key_dense.build(inputs_shape)
+
+         self.value_dense = keras.layers.EinsumDense(
+             equation="bkm,mvh->bkvh",
+             output_shape=(
+                 None,
+                 self.num_key_value_heads,
+                 self._head_dim,
+             ),
+             bias_axes="vh",
+             kernel_initializer=self._kernel_initializer,
+             bias_initializer="zeros",
+             dtype=self.dtype_policy,
+             name="value",
+         )
+         self.value_dense.build(inputs_shape)
+
+         self.dropout_layer = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+         )
+
+         self.output_dense = keras.layers.EinsumDense(
+             equation="bquh,uhm->bqm",
+             output_shape=(None, self._hidden_dim),
+             bias_axes="m",
+             kernel_initializer=self._kernel_initializer,
+             bias_initializer="zeros",
+             dtype=self.dtype_policy,
+             name="attention_output",
+         )
+         self.output_dense.build(
+             (None, None, self.num_query_heads, self._head_dim)
+         )
+
+         self.rotary_embedding_layer = RotaryEmbedding(
+             max_wavelength=self.rope_max_wavelength,
+             scaling_factor=self.rope_scaling_factor,  # YaRN scaling factor
+             rope_type="yarn",
+             beta_fast=32.0,
+             beta_slow=1.0,
+             original_max_position_embeddings=4096,
+             dtype=self.dtype_policy,
+         )
+
+         self.sinks = self.add_weight(
+             shape=(self.num_query_heads,),
+             initializer="random_normal",
+             dtype=self.dtype,
+             name="sinks",
+         )
+
+         self._dot_product_equation = "bquh,bkuh->buqk"
+         self._combine_equation = "buqk,bkuh->bquh"
+
+         self.built = True
+
+     def call(
+         self,
+         hidden_states,
+         attention_mask=None,
+         cache=None,
+         cache_update_index=None,
+         training=None,
+     ):
+         start_index = (
+             cache_update_index if cache_update_index is not None else 0
+         )
+
+         query = self.query_dense(hidden_states)
+
+         # Compute RoPE for queries (only apply to the first
+         # _rotary_dim dimensions).
+         if self._rotary_dim < self._head_dim:
+             query_rot = query[..., : self._rotary_dim]
+             query_rot = self.rotary_embedding_layer(
+                 query_rot, start_index=start_index
+             )
+             query = ops.concatenate(
+                 [query_rot, query[..., self._rotary_dim :]], axis=-1
+             )
+         else:
+             query = self.rotary_embedding_layer(query, start_index=start_index)
+
+         def _compute_key_value(x):
+             key, value = self.key_dense(x), self.value_dense(x)
+             # Compute RoPE for keys (only apply to first _rotary_dim dimensions)
+             if self._rotary_dim < self._head_dim:
+                 key_rot = key[..., : self._rotary_dim]
+                 key_rot = self.rotary_embedding_layer(
+                     key_rot, start_index=start_index
+                 )
+                 key = ops.concatenate(
+                     [key_rot, key[..., self._rotary_dim :]], axis=-1
+                 )
+             else:
+                 key = self.rotary_embedding_layer(key, start_index=start_index)
+             return key, value
+
+         if cache is not None:
+             key_cache = cache[:, 0, ...]
+             value_cache = cache[:, 1, ...]
+             if cache_update_index is None:
+                 key = key_cache
+                 value = value_cache
+             else:
+                 key_update, value_update = _compute_key_value(hidden_states)
+                 start = [0, cache_update_index, 0, 0]
+                 key = ops.slice_update(key_cache, start, key_update)
+                 value = ops.slice_update(value_cache, start, value_update)
+                 cache = ops.stack((key, value), axis=1)
+         else:
+             if cache_update_index is not None:
+                 raise ValueError(
+                     "`cache_update_index` should not be set if `cache` is "
+                     f"`None`. Received: cache={cache}, "
+                     f"cache_update_index={cache_update_index}"
+                 )
+             key, value = _compute_key_value(hidden_states)
+
+         # [batch_shape, seq_len, num_key_value_heads, head_dim]
+         # -> [batch_shape, seq_len, num_heads, head_dim]
+         key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+         value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+         attention_output = self._compute_attention(
+             query, key, value, attention_mask, start_index
+         )
+
+         attention_output = self.dropout_layer(
+             attention_output, training=training
+         )
+
+         attention_output = self.output_dense(attention_output)
+
+         if cache is not None:
+             return attention_output, cache
+         return attention_output
+
+     def _compute_attention(
+         self, query, key, value, attention_mask=None, start_index=0
+     ):
+         attention_scores = ops.einsum(self._dot_product_equation, query, key)
+         attention_scores = ops.multiply(
+             attention_scores,
+             ops.cast(self._inv_norm_factor, self.compute_dtype),
+         )
+
+         # Apply sliding window mask if specified
+         if self.sliding_window is not None and self.sliding_window > 0:
+             q_len = ops.shape(attention_scores)[-2]
+             kv_len = ops.shape(attention_scores)[-1]
+
+             # Query positions are offset by start_index during generation
+             q_positions = ops.arange(q_len) + start_index
+             kv_positions = ops.arange(kv_len)
+
+             # The mask is True for keys inside the window; keys with
+             # kv_pos < q_pos - sliding_window are masked out below.
+             mask = (
+                 kv_positions[None, :]
+                 >= q_positions[:, None] - self.sliding_window
+             )
+             if self.compute_dtype == "float32":
+                 sliding_adder = ops.cast(-1e9, self.compute_dtype)
+             else:
+                 sliding_adder = ops.cast(-1e4, self.compute_dtype)
+             attention_scores = ops.where(
+                 mask[None, None, :, :], attention_scores, sliding_adder
+             )
+
+         if attention_mask is not None:
+             # The attention mask is True for positions that may be attended
+             # to. Masked positions are filled with a large negative value so
+             # they receive near-zero weight after the softmax.
+             if self.compute_dtype == "float32":
+                 adder = ops.cast(-1e9, self.compute_dtype)
+             else:
+                 adder = ops.cast(-1e4, self.compute_dtype)
+             attention_scores = ops.where(
+                 attention_mask[:, None, :, :], attention_scores, adder
+             )
+
+         # Handle sink tokens by concatenating them to the logits.
+         b = ops.shape(attention_scores)[0]
+         q = ops.shape(attention_scores)[2]
+
+         sinks = ops.reshape(self.sinks, (1, self.num_query_heads, 1, 1))
+         sinks = ops.broadcast_to(sinks, (b, self.num_query_heads, q, 1))
+         # attention_scores shape: [b, num_heads, q, k]
+         # sinks shape: [b, num_heads, q, 1]
+         # Concatenate along the last (key) dimension.
+         combined_logits = ops.concatenate([attention_scores, sinks], axis=-1)
+
+         # Subtract the max logit for numerical stability before the softmax.
+         max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
+         max_logits = ops.stop_gradient(max_logits)
+         combined_logits = combined_logits - max_logits
+
+         probs = ops.softmax(combined_logits, axis=-1)
+
+         # Remove the sink probabilities before computing the output.
+         attention_scores = probs[..., :-1]
+         attention_scores = ops.cast(attention_scores, self.compute_dtype)
+
+         attention_output = ops.einsum(
+             self._combine_equation, attention_scores, value
+         )
+
+         return attention_output
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_query_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self._kernel_initializer
+                 ),
+                 "sliding_window": self.sliding_window,
+                 "dropout": self.dropout,
+                 "head_dim": self.head_dim,
+             }
+         )
+         return config
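
To illustrate the sink-token trick in `_compute_attention` above, here is a minimal NumPy sketch (not part of the package, with made-up logit values): the learned per-head sink logit is appended to each row of attention scores, the softmax is taken over the extended row, and the sink column is then dropped, so the kept attention weights can sum to less than one.

```python
import numpy as np

# Hypothetical values for one query row of one attention head.
scores = np.array([2.0, 1.0, 0.5])  # logits for 3 key positions
sink = np.array([3.0])              # learned sink logit for this head

# Append the sink, stabilize, and softmax, as in the layer above.
combined = np.concatenate([scores, sink])
combined = combined - combined.max()
probs = np.exp(combined) / np.exp(combined).sum()

# Drop the sink column; the kept weights sum to less than 1 (~0.37 here),
# so a head can effectively attend to "nothing" when the sink dominates.
attention_weights = probs[:-1]
print(attention_weights.sum())
```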
keras_hub/src/models/gpt_oss/gpt_oss_backbone.py
@@ -0,0 +1,221 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.modeling.reversible_embedding import (
+     ReversibleEmbedding,
+ )
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.gpt_oss.gpt_oss_decoder import (
+     GptOssTransformerDecoder,
+ )
+ from keras_hub.src.models.gpt_oss.gpt_oss_layer_norm import (
+     GptOssLayerNormalization,
+ )
+
+
+ def _gpt_oss_kernel_initializer(stddev=0.02):
+     return keras.initializers.RandomNormal(stddev=stddev)
+
+
+ @keras_hub_export("keras_hub.models.GptOssBackbone")
+ class GptOssBackbone(Backbone):
+     """A GPT-style Transformer with a Mixture of Experts.
+
+     This network implements a GPT-style decoder network with Mixture of
+     Experts (MoE) layers, similar to the architecture described in
+     ["Mixtral of Experts"](https://arxiv.org/pdf/2401.04088) but with
+     customizations found in some open-source GPT models. It includes the
+     embedding lookups and transformer layers.
+
+     The default constructor gives a fully customizable, randomly initialized
+     GptOss model with any number of layers, heads, and embedding
+     dimensions. To load preset architectures and weights, use the
+     `from_preset` constructor.
+
+     Args:
+         vocabulary_size: int. The size of the token vocabulary.
+         num_layers: int. The number of transformer layers.
+         num_query_heads: int. The number of query attention heads for
+             each transformer.
+         hidden_dim: int. The size of the transformer encoding and pooling
+             layers.
+         intermediate_dim: int. The output dimension of the first Dense layer
+             in a three-layer feedforward network for each transformer.
+         num_key_value_heads: int. The number of key and value attention heads
+             for each transformer.
+         num_experts: int. The number of experts for the MoE layers.
+         top_k: int. The number of experts to use for each token.
+             Defaults to `2`.
+         rope_max_wavelength: int. The maximum angular wavelength of
+             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+         rope_scaling_factor: float. The scaling factor for the
+             calculation of the rotary embedding. Defaults to `1.0`.
+         layer_norm_epsilon: float. Epsilon for the layer
+             normalization layers in the transformer decoder. Defaults to `1e-6`.
+         sliding_window: int. The sliding window for the attention
+             layers. This controls the maximum cache size for the attention
+             layers in each transformer decoder. Only `sliding_window` number
+             of tokens are saved in the cache and used to generate the next
+             token. Defaults to `4096`.
+         head_dim: int. The head dimension for the attention layers. If
+             `None`, the head dimension is calculated dynamically as
+             `hidden_dim // num_query_heads` from the hidden size and the
+             number of query heads. Defaults to `None`.
+         dropout: float. Attention dropout probability. Defaults to `0`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for model computations and weights. Note that some computations,
+             such as softmax and layer normalization, will always be done at
+             `float32` precision regardless of dtype.
+
+     Examples:
+
+     ```python
+     import numpy as np
+     import keras_hub
+
+     # Load a pretrained GptOss backbone from a preset.
+     model = keras_hub.models.GptOssBackbone.from_preset("gpt_oss_20b_en")
+
+     input_data = {
+         "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+         "padding_mask": np.array(
+             [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32"
+         ),
+     }
+
+     model(input_data)
+
+     # Randomly initialized GptOss decoder with custom config.
+     model = keras_hub.models.GptOssBackbone(
+         vocabulary_size=10,
+         hidden_dim=512,
+         num_layers=2,
+         num_query_heads=32,
+         num_key_value_heads=8,
+         intermediate_dim=1024,
+         num_experts=4,
+         top_k=2,
+         sliding_window=256,
+         layer_norm_epsilon=1e-6,
+         dtype="float32",
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         num_layers,
+         num_query_heads,
+         hidden_dim,
+         intermediate_dim,
+         num_key_value_heads,
+         num_experts,
+         top_k=2,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1.0,
+         layer_norm_epsilon=1e-6,
+         sliding_window=4096,
+         head_dim=None,
+         dropout=0,
+         output_router_logits=False,
+         dtype=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.token_embedding = ReversibleEmbedding(
+             input_dim=vocabulary_size,
+             output_dim=hidden_dim,
+             tie_weights=False,
+             embeddings_initializer=_gpt_oss_kernel_initializer(stddev=0.01),
+             dtype=dtype,
+             name="token_embedding",
+         )
+         self.transformer_layers = []
+         for i in range(num_layers):
+             layer = GptOssTransformerDecoder(
+                 intermediate_dim=intermediate_dim,
+                 num_query_heads=num_query_heads,
+                 num_key_value_heads=num_key_value_heads,
+                 num_experts=num_experts,
+                 top_k=top_k,
+                 output_router_logits=output_router_logits,
+                 rope_max_wavelength=rope_max_wavelength,
+                 rope_scaling_factor=rope_scaling_factor,
+                 layer_norm_epsilon=layer_norm_epsilon,
+                 kernel_initializer=_gpt_oss_kernel_initializer(stddev=0.02),
+                 # GPT-OSS uses sliding window attention in every other layer.
+                 sliding_window=sliding_window if i % 2 == 1 else None,
+                 dropout=dropout,
+                 head_dim=head_dim,
+                 dtype=dtype,
+                 name=f"transformer_layer_{i}",
+             )
+             self.transformer_layers.append(layer)
+         self.layer_norm = GptOssLayerNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=dtype,
+             name="sequence_output_layernorm",
+         )
+
+         # === Functional Model ===
+         token_id_input = keras.Input(
+             shape=(None,), dtype="int32", name="token_ids"
+         )
+         padding_mask_input = keras.Input(
+             shape=(None,), dtype="int32", name="padding_mask"
+         )
+         x = self.token_embedding(token_id_input)
+         for transformer_layer in self.transformer_layers:
+             x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
+         sequence_output = self.layer_norm(x)
+         super().__init__(
+             inputs={
+                 "token_ids": token_id_input,
+                 "padding_mask": padding_mask_input,
+             },
+             outputs=sequence_output,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.num_layers = num_layers
+         self.num_query_heads = num_query_heads
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.num_key_value_heads = num_key_value_heads
+         self.num_experts = num_experts
+         self.top_k = top_k
+         self.rope_max_wavelength = rope_max_wavelength
+         self.rope_scaling_factor = rope_scaling_factor
+         self.sliding_window = sliding_window
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.dropout = dropout
+         self.output_router_logits = output_router_logits
+         self.head_dim = head_dim
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "num_layers": self.num_layers,
+                 "num_query_heads": self.num_query_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_experts": self.num_experts,
+                 "top_k": self.top_k,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "sliding_window": self.sliding_window,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "dropout": self.dropout,
+                 "output_router_logits": self.output_router_logits,
+                 "head_dim": self.head_dim,
+             }
+         )
+         return config
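
As a quick illustration of the `sliding_window if i % 2 == 1 else None` rule in the constructor above, the following sketch (not part of the package; `num_layers` and `sliding_window` are example values, not a preset configuration) shows which decoder layers receive a sliding window.

```python
# Per-layer sliding window assignment used in GptOssBackbone.__init__.
num_layers = 6
sliding_window = 4096

per_layer = [
    sliding_window if i % 2 == 1 else None for i in range(num_layers)
]
print(per_layer)
# [None, 4096, None, 4096, None, 4096]
# Even-indexed layers attend over the full causal context; odd-indexed
# layers restrict attention to the last `sliding_window` tokens.
```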