keras_hub-0.20.0.dev1-py3-none-any.whl → keras_hub-0.21.0.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. keras_hub/__init__.py +15 -33
  2. keras_hub/layers/__init__.py +134 -0
  3. keras_hub/metrics/__init__.py +11 -0
  4. keras_hub/models/__init__.py +642 -0
  5. keras_hub/samplers/__init__.py +18 -0
  6. keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
  7. keras_hub/src/layers/preprocessing/image_converter.py +1 -0
  8. keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
  9. keras_hub/src/layers/preprocessing/random_swap.py +1 -1
  10. keras_hub/src/models/audio_to_text.py +66 -0
  11. keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
  12. keras_hub/src/models/backbone.py +5 -2
  13. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  14. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -1
  16. keras_hub/src/models/gemma/gemma_presets.py +10 -10
  17. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
  18. keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
  19. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  20. keras_hub/src/models/llama/llama_attention.py +24 -6
  21. keras_hub/src/models/llama/llama_backbone.py +50 -16
  22. keras_hub/src/models/llama/llama_decoder.py +20 -3
  23. keras_hub/src/models/llama/llama_presets.py +3 -3
  24. keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
  25. keras_hub/src/models/llama3/llama3_backbone.py +10 -2
  26. keras_hub/src/models/llama3/llama3_presets.py +84 -2
  27. keras_hub/src/models/mistral/mistral_presets.py +3 -3
  28. keras_hub/src/models/mixtral/__init__.py +5 -0
  29. keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
  30. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  31. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  32. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  33. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  34. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  35. keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
  36. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  37. keras_hub/src/models/moonshine/__init__.py +5 -0
  38. keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
  39. keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
  40. keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
  42. keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
  43. keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
  44. keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
  45. keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
  46. keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
  47. keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
  48. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
  49. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
  50. keras_hub/src/models/qwen/__init__.py +4 -0
  51. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  52. keras_hub/src/models/qwen/qwen_backbone.py +8 -1
  53. keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
  54. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
  55. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  56. keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
  57. keras_hub/src/models/qwen_moe/__init__.py +5 -0
  58. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
  59. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  60. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  61. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
  65. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  66. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  67. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  68. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  69. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
  71. keras_hub/src/models/task.py +5 -2
  72. keras_hub/src/models/xception/__init__.py +5 -0
  73. keras_hub/src/models/xception/xception_backbone.py +188 -0
  74. keras_hub/src/models/xception/xception_image_classifier.py +12 -0
  75. keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
  76. keras_hub/src/models/xception/xception_image_converter.py +8 -0
  77. keras_hub/src/models/xception/xception_presets.py +14 -0
  78. keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
  79. keras_hub/src/utils/coco/__init__.py +0 -0
  80. keras_hub/src/utils/coco/coco_utils.py +133 -0
  81. keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
  82. keras_hub/src/utils/keras_utils.py +11 -0
  83. keras_hub/src/utils/preset_utils.py +70 -10
  84. keras_hub/src/utils/tensor_utils.py +27 -1
  85. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  86. keras_hub/src/utils/timm/preset_loader.py +6 -6
  87. keras_hub/src/utils/transformers/convert_llama3.py +21 -1
  88. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  89. keras_hub/src/utils/transformers/convert_qwen.py +1 -0
  90. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  91. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  92. keras_hub/src/{version_utils.py → version.py} +1 -1
  93. keras_hub/tokenizers/__init__.py +117 -0
  94. keras_hub/utils/__init__.py +21 -0
  95. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
  96. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
  97. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
  98. keras_hub/api/__init__.py +0 -15
  99. keras_hub/api/layers/__init__.py +0 -86
  100. keras_hub/api/metrics/__init__.py +0 -11
  101. keras_hub/api/models/__init__.py +0 -416
  102. keras_hub/api/samplers/__init__.py +0 -16
  103. keras_hub/api/tokenizers/__init__.py +0 -58
  104. keras_hub/api/utils/__init__.py +0 -9
  105. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mixtral/mixtral_attention.py (new file, +252 lines)
@@ -0,0 +1,252 @@
+ import inspect
+ import math
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+ from keras_hub.src.utils.keras_utils import clone_initializer
+ from keras_hub.src.utils.keras_utils import fused_attention_op_available
+ from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+ from keras_hub.src.utils.keras_utils import running_on_gpu
+ from keras_hub.src.utils.keras_utils import running_on_tpu
+
+
+ class CachedMixtralAttention(keras.layers.Layer):
+     """A cached grouped query attention layer with sliding window."""
+
+     def __init__(
+         self,
+         num_query_heads,
+         num_key_value_heads,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1.0,
+         kernel_initializer="glorot_uniform",
+         sliding_window=512,
+         dropout=0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.sliding_window = sliding_window
+         self.dropout = dropout
+
+         self.num_key_value_groups = num_query_heads // num_key_value_heads
+         self.rope_max_wavelength = rope_max_wavelength
+
+         self._kernel_initializer = keras.initializers.get(
+             clone_initializer(kernel_initializer)
+         )
+
+         self.rope_scaling_factor = rope_scaling_factor
+
+     def build(self, inputs_shape):
+         # Einsum variables:
+         # b = batch size
+         # q = query length
+         # k = key/value length
+         # m = model dim
+         # u = num query heads
+         # v = num key/value heads
+         # h = head dim
+         self._hidden_dim = inputs_shape[-1]
+         self._head_dim = self._hidden_dim // self.num_query_heads
+         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
+
+         self.query_dense = keras.layers.EinsumDense(
+             equation="bqm,muh->bquh",
+             output_shape=(None, self.num_query_heads, self._head_dim),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="query",
+         )
+         self.query_dense.build(inputs_shape)
+
+         self.key_dense = keras.layers.EinsumDense(
+             equation="bkm,mvh->bkvh",
+             output_shape=(
+                 None,
+                 self.num_key_value_heads,
+                 self._head_dim,
+             ),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="key",
+         )
+         self.key_dense.build(inputs_shape)
+
+         self.value_dense = keras.layers.EinsumDense(
+             equation="bkm,mvh->bkvh",
+             output_shape=(
+                 None,
+                 self.num_key_value_heads,
+                 self._head_dim,
+             ),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="value",
+         )
+         self.value_dense.build(inputs_shape)
+
+         self.softmax = keras.layers.Softmax(
+             axis=-1,
+             dtype="float32",
+             name="attention_softmax",
+         )
+
+         self.dropout_layer = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+         )
+
+         self.output_dense = keras.layers.EinsumDense(
+             equation="bquh,uhm->bqm",
+             output_shape=(None, self._hidden_dim),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="attention_output",
+         )
+         self.output_dense.build(
+             (None, None, self.num_query_heads, self._head_dim)
+         )
+
+         self.rotary_embedding_layer = RotaryEmbedding(
+             max_wavelength=self.rope_max_wavelength,
+             scaling_factor=self.rope_scaling_factor,
+             dtype=self.dtype_policy,
+         )
+
+         self._dot_product_equation = "bquh,bkuh->buqk"
+         self._combine_equation = "buqk,bkuh->bquh"
+
+         self.built = True
+
+     def call(
+         self,
+         hidden_states,
+         attention_mask=None,
+         cache=None,
+         cache_update_index=None,
+         training=None,
+     ):
+         start_index = (
+             cache_update_index if cache_update_index is not None else 0
+         )
+
+         query = self.query_dense(hidden_states)
+
+         # Compute RoPE for queries.
+         query = self.rotary_embedding_layer(query, start_index=start_index)
+
+         def _compute_key_value(x):
+             key, value = self.key_dense(x), self.value_dense(x)
+             # Compute RoPE for keys.
+             key = self.rotary_embedding_layer(key, start_index=start_index)
+             return key, value
+
+         if cache is not None:
+             key_cache = cache[:, 0, ...]
+             value_cache = cache[:, 1, ...]
+             if cache_update_index is None:
+                 key = key_cache
+                 value = value_cache
+             else:
+                 key_update, value_update = _compute_key_value(hidden_states)
+                 start = [0, cache_update_index, 0, 0]
+                 key = ops.slice_update(key_cache, start, key_update)
+                 value = ops.slice_update(value_cache, start, value_update)
+                 cache = ops.stack((key, value), axis=1)
+         else:
+             if cache_update_index is not None:
+                 raise ValueError(
+                     "`cache_update_index` should not be set if `cache` is "
+                     f"`None`. Received: cache={cache}, "
+                     f"cache_update_index={cache_update_index}"
+                 )
+             key, value = _compute_key_value(hidden_states)
+
+         # [batch_shape, seq_len, num_key_value_heads, head_dim]
+         # -> [batch_shape, seq_len, num_heads, head_dim]
+         key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+         value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+         attention_output = self._compute_attention(
+             query, key, value, attention_mask
+         )
+
+         attention_output = self.dropout_layer(
+             attention_output, training=training
+         )
+
+         attention_output = self.output_dense(attention_output)
+
+         if cache is not None:
+             return attention_output, cache
+         return attention_output
+
+     def _masked_softmax(self, attention_scores, attention_mask=None):
+         if attention_mask is not None:
+             return self.softmax(attention_scores, attention_mask[:, None, :, :])
+         return self.softmax(attention_scores)
+
+     def _use_fused_attention_op(self):
+         if not fused_attention_op_available():
+             return False
+         if self.dropout > 0.0:
+             return False
+         if running_on_gpu():
+             return gpu_supports_fused_attention_op()
+         elif running_on_tpu():
+             # TPU supports softcap on Keras >= 3.10.
+             sig = inspect.signature(ops.dot_product_attention)
+             return "attn_logits_soft_cap" in sig.parameters
+         else:
+             return False
+
+     def _compute_attention(self, query, key, value, attention_mask=None):
+         if self._use_fused_attention_op():
+             if attention_mask is not None:
+                 attention_mask = ops.expand_dims(attention_mask, axis=1)
+                 attention_mask = ops.cast(attention_mask, dtype="bool")
+
+             attention_output = ops.dot_product_attention(
+                 query,
+                 key,
+                 value,
+                 mask=attention_mask,
+                 scale=self._inv_norm_factor,
+             )
+             return attention_output
+
+         attention_scores = ops.einsum(self._dot_product_equation, query, key)
+         attention_scores = ops.multiply(
+             attention_scores,
+             ops.cast(self._inv_norm_factor, self.compute_dtype),
+         )
+         attention_scores = self._masked_softmax(
+             attention_scores, attention_mask
+         )
+         attention_scores = ops.cast(attention_scores, self.compute_dtype)
+         attention_output = ops.einsum(
+             self._combine_equation, attention_scores, value
+         )
+
+         return attention_output
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_query_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self._kernel_initializer
+                 ),
+                 "sliding_window": self.sliding_window,
+                 "dropout": self.dropout,
+             }
+         )
+         return config
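
The cache handling in `call` above packs keys and values into a single tensor whose axis 1 separates keys (index 0) from values (index 1), writes new entries at `cache_update_index` via `ops.slice_update`, and then repeats the key/value heads `num_key_value_groups` times so they line up with the query heads (grouped query attention). The sketch below is a minimal illustration of that flow, assuming this dev wheel is installed and importing the layer from the internal path added in this diff; the batch, length, and head sizes are illustrative, not preset values.

```python
import numpy as np
from keras import ops

from keras_hub.src.models.mixtral.mixtral_attention import (
    CachedMixtralAttention,
)

batch, max_len, hidden_dim = 1, 16, 64
num_query_heads, num_key_value_heads = 8, 2
head_dim = hidden_dim // num_query_heads  # 8

layer = CachedMixtralAttention(
    num_query_heads=num_query_heads,
    num_key_value_heads=num_key_value_heads,
)

# The cache packs keys (index 0) and values (index 1) along axis 1:
# [batch, 2, max_len, num_key_value_heads, head_dim].
cache = ops.zeros(
    (batch, 2, max_len, num_key_value_heads, head_dim), dtype="float32"
)

# Prefill: write the prompt's keys/values starting at position 0.
prompt = np.random.rand(batch, 4, hidden_dim).astype("float32")
output, cache = layer(prompt, cache=cache, cache_update_index=0)

# Decode: write one new token's keys/values at the next position.
# (A real decode loop would also pass an `attention_mask` so unused
# cache slots are ignored.)
next_token = np.random.rand(batch, 1, hidden_dim).astype("float32")
output, cache = layer(next_token, cache=cache, cache_update_index=4)
print(output.shape)  # (1, 1, 64)
```

Because a `cache` is passed, each call returns an `(attention_output, cache)` pair, which is what the prefill and decode steps above unpack.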
keras_hub/src/models/mixtral/mixtral_backbone.py (new file, +207 lines)
@@ -0,0 +1,207 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.modeling.reversible_embedding import (
+     ReversibleEmbedding,
+ )
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.mixtral.mixtral_decoder import (
+     MixtralTransformerDecoder,
+ )
+ from keras_hub.src.models.mixtral.mixtral_layer_norm import (
+     MixtralLayerNormalization,
+ )
+
+
+ def _mixtral_kernel_initializer(stddev=0.02):
+     return keras.initializers.RandomNormal(stddev=stddev)
+
+
+ @keras_hub_export("keras_hub.models.MixtralBackbone")
+ class MixtralBackbone(Backbone):
+     """The Mixtral Transformer core architecture with hyperparameters.
+
+     This network implements a Mixture of Experts (MoE) based decoder network,
+     Mixtral, as described in
+     ["Mixtral of Experts"](https://arxiv.org/pdf/2401.04088).
+     It includes the embedding lookups and transformer layers.
+
+     The default constructor gives a fully customizable, randomly initialized
+     Mixtral model with any number of layers, heads, and embedding
+     dimensions. To load preset architectures and weights, use the `from_preset`
+     constructor.
+
+     Args:
+         vocabulary_size (int): The size of the token vocabulary.
+         num_layers (int): The number of transformer layers.
+         num_query_heads (int): The number of query attention heads for
+             each transformer.
+         hidden_dim (int): The size of the transformer encoding and pooling
+             layers.
+         intermediate_dim (int): The output dimension of the first Dense layer
+             in a three-layer feedforward network for each transformer.
+         num_key_value_heads (int): The number of key and value attention heads
+             for each transformer.
+         rope_max_wavelength (int, optional): The maximum angular wavelength of
+             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+         rope_scaling_factor (float, optional): The scaling factor for
+             calculation of rotary embedding. Defaults to `1.0`.
+         layer_norm_epsilon (float, optional): Epsilon for the layer
+             normalization layers in the transformer decoder. Defaults to `1e-6`.
+         sliding_window (int, optional): The sliding window for the Mixtral
+             attention layers. This controls the maximum cache size for the
+             attention layers in each transformer decoder. Only `sliding_window`
+             number of tokens are saved in the cache and used to generate the
+             next token. Defaults to `512`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for model computations and weights. Note that some computations,
+             such as softmax and layer normalization, will always be done at
+             float32 precision regardless of dtype.
+
+     Examples:
+
+     ```python
+     input_data = {
+         "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+         "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+     }
+
+     # Pretrained Mixtral decoder.
+     model = keras_hub.models.MixtralBackbone.from_preset("mixtral7b_base_en")
+     model(input_data)
+
+     # Randomly initialized Mixtral decoder with custom config.
+     model = keras_hub.models.MixtralBackbone(
+         vocabulary_size=10,
+         hidden_dim=512,
+         num_layers=2,
+         num_query_heads=32,
+         num_key_value_heads=8,
+         intermediate_dim=1024,
+         num_experts=8,
+         sliding_window=512,
+         layer_norm_epsilon=1e-6,
+         dtype="float32"
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         num_layers,
+         num_query_heads,
+         hidden_dim,
+         intermediate_dim,
+         num_key_value_heads,
+         num_experts,
+         top_k=2,
+         router_jitter_noise=0.0,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1.0,
+         layer_norm_epsilon=1e-6,
+         router_aux_loss_coef=0.02,
+         sliding_window=512,
+         dropout=0,
+         dtype=None,
+         output_router_logits=False,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.token_embedding = ReversibleEmbedding(
+             input_dim=vocabulary_size,
+             output_dim=hidden_dim,
+             tie_weights=False,
+             embeddings_initializer=_mixtral_kernel_initializer(stddev=0.01),
+             dtype=dtype,
+             name="token_embedding",
+         )
+         self.transformer_layers = []
+         for i in range(num_layers):
+             layer = MixtralTransformerDecoder(
+                 intermediate_dim=intermediate_dim,
+                 num_query_heads=num_query_heads,
+                 num_key_value_heads=num_key_value_heads,
+                 num_experts=num_experts,
+                 top_k=top_k,
+                 router_jitter_noise=router_jitter_noise,
+                 output_router_logits=output_router_logits,
+                 rope_max_wavelength=rope_max_wavelength,
+                 rope_scaling_factor=rope_scaling_factor,
+                 layer_norm_epsilon=layer_norm_epsilon,
+                 activation=ops.silu,
+                 router_aux_loss_coef=router_aux_loss_coef,
+                 kernel_initializer=_mixtral_kernel_initializer(stddev=0.02),
+                 sliding_window=sliding_window,
+                 dropout=dropout,
+                 dtype=dtype,
+                 name=f"transformer_layer_{i}",
+             )
+             self.transformer_layers.append(layer)
+         self.layer_norm = MixtralLayerNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=dtype,
+             name="sequence_output_layernorm",
+         )
+
+         # === Functional Model ===
+         token_id_input = keras.Input(
+             shape=(None,), dtype="int32", name="token_ids"
+         )
+         padding_mask_input = keras.Input(
+             shape=(None,), dtype="int32", name="padding_mask"
+         )
+         x = self.token_embedding(token_id_input)
+         for transformer_layer in self.transformer_layers:
+             x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
+         sequence_output = self.layer_norm(x)
+         super().__init__(
+             inputs={
+                 "token_ids": token_id_input,
+                 "padding_mask": padding_mask_input,
+             },
+             outputs=sequence_output,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.num_layers = num_layers
+         self.num_query_heads = num_query_heads
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.num_key_value_heads = num_key_value_heads
+         self.num_experts = num_experts
+         self.top_k = top_k
+         self.router_jitter_noise = router_jitter_noise
+         self.rope_max_wavelength = rope_max_wavelength
+         self.router_aux_loss_coef = router_aux_loss_coef
+         self.rope_scaling_factor = rope_scaling_factor
+         self.sliding_window = sliding_window
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.dropout = dropout
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "num_layers": self.num_layers,
+                 "num_query_heads": self.num_query_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_experts": self.num_experts,
+                 "top_k": self.top_k,
+                 "router_jitter_noise": self.router_jitter_noise,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "router_aux_loss_coef": self.router_aux_loss_coef,
+                 "sliding_window": self.sliding_window,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "dropout": self.dropout,
+             }
+         )
+         return config
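
As a quick end-to-end check of the wiring above, the sketch below builds a toy backbone through the public `keras_hub.models.MixtralBackbone` export registered by the decorator and then projects the final hidden states back to vocabulary logits with the untied `ReversibleEmbedding`, mirroring how keras-hub causal LM tasks typically consume this output. The sizes, the random inputs, and the `reverse=True` projection step are illustrative assumptions, not part of this file.

```python
import numpy as np

import keras_hub

# Toy configuration; real Mixtral presets are far larger.
backbone = keras_hub.models.MixtralBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    hidden_dim=64,
    intermediate_dim=128,
    num_key_value_heads=2,
    num_experts=4,
    top_k=2,
)

# Inputs match the functional model's `token_ids` and `padding_mask`.
input_data = {
    "token_ids": np.random.randint(0, 1000, size=(1, 12)).astype("int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}

# Final hidden states from the backbone.
hidden_states = backbone(input_data)  # shape (1, 12, 64)

# Project back to vocabulary space with the untied reversible embedding,
# the way a causal LM head would.
logits = backbone.token_embedding(hidden_states, reverse=True)  # (1, 12, 1000)
```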