keras-hub-nightly 0.21.0.dev202505040408__py3-none-any.whl → 0.21.0.dev202505060405__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. keras_hub/models/__init__.py +21 -0
  2. keras_hub/src/models/backbone.py +5 -2
  3. keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
  4. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  5. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  6. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  7. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  8. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  9. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  10. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  11. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  12. keras_hub/src/models/qwen_moe/__init__.py +0 -0
  13. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
  14. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  15. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  16. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  17. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  18. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  19. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  20. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  21. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  22. keras_hub/src/models/task.py +5 -2
  23. keras_hub/src/utils/keras_utils.py +11 -0
  24. keras_hub/src/utils/preset_utils.py +69 -9
  25. keras_hub/src/utils/tensor_utils.py +27 -1
  26. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  27. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  28. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  29. keras_hub/src/version.py +1 -1
  30. keras_hub/tokenizers/__init__.py +6 -0
  31. {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/METADATA +1 -1
  32. {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/RECORD +34 -16
  33. {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/WHEEL +1 -1
  34. {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/top_level.txt +0 -0
keras_hub/models/__init__.py
@@ -348,6 +348,18 @@ from keras_hub.src.models.mit.mit_image_classifier import (
 from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
     MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor,
 )
+from keras_hub.src.models.mixtral.mixtral_backbone import (
+    MixtralBackbone as MixtralBackbone,
+)
+from keras_hub.src.models.mixtral.mixtral_causal_lm import (
+    MixtralCausalLM as MixtralCausalLM,
+)
+from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import (
+    MixtralCausalLMPreprocessor as MixtralCausalLMPreprocessor,
+)
+from keras_hub.src.models.mixtral.mixtral_tokenizer import (
+    MixtralTokenizer as MixtralTokenizer,
+)
 from keras_hub.src.models.mobilenet.mobilenet_backbone import (
     MobileNetBackbone as MobileNetBackbone,
 )
@@ -420,6 +432,15 @@ from keras_hub.src.models.qwen.qwen_tokenizer import (
 from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as QwenTokenizer,
 )
+from keras_hub.src.models.qwen_moe.qwen_moe_backbone import (
+    QwenMoeBackbone as QwenMoeBackbone,
+)
+from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import (
+    QwenMoeCausalLM as QwenMoeCausalLM,
+)
+from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import (
+    QwenMoeCausalLMPreprocessor as QwenMoeCausalLMPreprocessor,
+)
 from keras_hub.src.models.resnet.resnet_backbone import (
     ResNetBackbone as ResNetBackbone,
 )
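The two hunks above only touch the public namespace: `keras_hub/models/__init__.py` re-exports the new Mixtral and Qwen-MoE classes so they can be used without importing from the private `keras_hub.src` package. A minimal sketch of the new surface, assuming this nightly wheel is installed (no weights are downloaded, only attribute lookups):

```python
import keras_hub

# New public re-exports added in this nightly; everything below resolves
# from the top-level API rather than from `keras_hub.src`.
new_exports = [
    keras_hub.models.MixtralBackbone,
    keras_hub.models.MixtralCausalLM,
    keras_hub.models.MixtralCausalLMPreprocessor,
    keras_hub.models.MixtralTokenizer,
    keras_hub.models.QwenMoeBackbone,
    keras_hub.models.QwenMoeCausalLM,
    keras_hub.models.QwenMoeCausalLMPreprocessor,
]
for cls in new_exports:
    # Print where each public alias actually lives.
    print(f"{cls.__name__}: {cls.__module__}")
```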
keras_hub/src/models/backbone.py
@@ -177,14 +177,17 @@ class Backbone(keras.Model):
             )
         return loader.load_backbone(backbone_cls, load_weights, **kwargs)

-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save backbone to a preset directory.

         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_backbone(self)
+        saver.save_backbone(self, max_shard_size=max_shard_size)

     def get_lora_target_names(self):
         """Returns list of layer names which are to be LoRA-fied.
keras_hub/src/models/mixtral/mixtral_attention.py (new file)
@@ -0,0 +1,263 @@
+import inspect
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
+from keras_hub.src.utils.keras_utils import running_on_tpu
+
+
+class CachedMixtralAttention(keras.layers.Layer):
+    """A cached grouped query attention layer with sliding window."""
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        kernel_initializer="glorot_uniform",
+        sliding_window=512,
+        dropout=0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._num_query_heads = num_query_heads
+        self._num_key_value_heads = num_key_value_heads
+        self._sliding_window = sliding_window
+        self._dropout = dropout
+
+        self._num_key_value_groups = num_query_heads // num_key_value_heads
+        self._rope_max_wavelength = rope_max_wavelength
+
+        self._kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+
+        self._rope_scaling_factor = rope_scaling_factor
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        self._hidden_dim = inputs_shape[-1]
+        self._head_dim = self._hidden_dim // self._num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
+
+        self.query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self._num_query_heads, self._head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self._num_key_value_heads,
+                self._head_dim,
+            ),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self._num_key_value_heads,
+                self._head_dim,
+            ),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        self._softmax = keras.layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self._dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self._output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, self._hidden_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self._output_dense.build(
+            (None, None, self._num_query_heads, self._head_dim)
+        )
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self._rope_max_wavelength,
+            scaling_factor=self._rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self.query_dense(hidden_states)
+
+        # Compute RoPE for queries
+        query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self.key_dense(x), self.value_dense(x)
+            # Compute RoPE for keys
+            key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query, key, value, attention_mask
+        )
+
+        attention_output = self._dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self._output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _use_fused_attention_op(self):
+        if not fused_attention_op_available():
+            return False
+        if self.dropout > 0.0:
+            return False
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap with Keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
+
+    def _compute_attention(self, query, key, value, attention_mask=None):
+        if self._use_fused_attention_op():
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+
+            if self.logit_soft_cap:
+                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
+            else:
+                kwargs = {}
+
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+                **kwargs,
+            )
+            return attention_output

+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self._num_query_heads,
+                "num_key_value_heads": self._num_key_value_heads,
+                "rope_max_wavelength": self._rope_max_wavelength,
+                "rope_scaling_factor": self._rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self._kernel_initializer
+                ),
+                "sliding_window": self._sliding_window,
+                "dropout": self._dropout,
+            }
+        )
+        return config
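`CachedMixtralAttention` is a grouped-query attention layer: it projects `num_query_heads` query heads but only `num_key_value_heads` key/value heads, and each key/value head is shared by `num_query_heads // num_key_value_heads` query heads. The `ops.repeat(..., axis=2)` calls in `call()` expand the (possibly cached) key/value tensors to the full query head count before the attention einsum. A standalone sketch of just that expansion, not part of the package code and with arbitrary shapes:

```python
import keras
from keras import ops

batch, seq_len, head_dim = 1, 6, 4
num_query_heads, num_key_value_heads = 8, 2
num_key_value_groups = num_query_heads // num_key_value_heads  # 4

# Key/value tensor as produced by the key/value EinsumDense projections:
# [batch, seq_len, num_key_value_heads, head_dim]
key = keras.random.normal((batch, seq_len, num_key_value_heads, head_dim))

# Each key/value head is repeated once per query-head group, giving
# [batch, seq_len, num_query_heads, head_dim] to match the query tensor.
expanded_key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
print(expanded_key.shape)  # (1, 6, 8, 4)
```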
keras_hub/src/models/mixtral/mixtral_backbone.py (new file)
@@ -0,0 +1,207 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.mixtral.mixtral_decoder import (
+    MixtralTransformerDecoder,
+)
+from keras_hub.src.models.mixtral.mixtral_layer_norm import (
+    MixtralLayerNormalization,
+)
+
+
+def _mixtral_kernel_initializer(stddev=0.02):
+    return keras.initializers.RandomNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.MixtralBackbone")
+class MixtralBackbone(Backbone):
+    """The Mixtral Transformer core architecture with hyperparameters.
+
+    This network implements a Mixture of Experts based decoder network,
+    Mixtral, as described in
+    ["Mixtral of Experts"](https://arxiv.org/pdf/2401.04088).
+    It includes the embedding lookups and transformer layers.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Mixtral model with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        vocabulary_size (int): The size of the token vocabulary.
+        num_layers (int): The number of transformer layers.
+        num_query_heads (int): The number of query attention heads for
+            each transformer.
+        hidden_dim (int): The size of the transformer encoding and pooling
+            layers.
+        intermediate_dim (int): The output dimension of the first Dense layer
+            in a three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads
+            for each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer
+            normalization layers in the transformer decoder. Defaults to `1e-6`.
+        sliding_window (int, optional): The sliding window for the Mixtral
+            attention layers. This controls the maximum cache size for the
+            attention layers in each transformer decoder. Only `sliding_window`
+            number of tokens are saved in the cache and used to generate the
+            next token. Defaults to `512`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Mixtral decoder.
+    model = keras_hub.models.MixtralBackbone.from_preset("mixtral7b_base_en")
+    model(input_data)
+
+    # Randomly initialized Mixtral decoder with custom config.
+    model = keras_hub.models.MixtralBackbone(
+        vocabulary_size=10,
+        hidden_dim=512,
+        num_layers=2,
+        num_query_heads=32,
+        num_key_value_heads=8,
+        intermediate_dim=1024,
+        sliding_window=512,
+        layer_norm_epsilon=1e-6,
+        dtype="float32"
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_query_heads,
+        hidden_dim,
+        intermediate_dim,
+        num_key_value_heads,
+        num_experts,
+        top_k=2,
+        router_jitter_noise=0.0,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        layer_norm_epsilon=1e-6,
+        router_aux_loss_coef=0.02,
+        sliding_window=512,
+        dropout=0,
+        dtype=None,
+        output_router_logits=False,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            tie_weights=False,
+            embeddings_initializer=_mixtral_kernel_initializer(stddev=0.01),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = MixtralTransformerDecoder(
+                intermediate_dim=intermediate_dim,
+                num_query_heads=num_query_heads,
+                num_key_value_heads=num_key_value_heads,
+                num_experts=num_experts,
+                top_k=top_k,
+                router_jitter_noise=router_jitter_noise,
+                output_router_logits=output_router_logits,
+                rope_max_wavelength=rope_max_wavelength,
+                rope_scaling_factor=rope_scaling_factor,
+                layer_norm_epsilon=layer_norm_epsilon,
+                activation=ops.silu,
+                router_aux_loss_coef=router_aux_loss_coef,
+                kernel_initializer=_mixtral_kernel_initializer(stddev=0.02),
+                sliding_window=sliding_window,
+                dropout=dropout,
+                dtype=dtype,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+        self.layer_norm = MixtralLayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="sequence_output_layernorm",
+        )
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+        x = self.token_embedding(token_id_input)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
+        sequence_output = self.layer_norm(x)
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=sequence_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_query_heads = num_query_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.router_jitter_noise = router_jitter_noise
+        self.rope_max_wavelength = rope_max_wavelength
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.rope_scaling_factor = rope_scaling_factor
+        self.sliding_window = sliding_window
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.dropout = dropout
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_query_heads": self.num_query_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_experts": self.num_experts,
+                "top_k": self.top_k,
+                "router_jitter_noise": self.router_jitter_noise,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "num_key_value_heads": self.num_key_value_heads,
+                "router_aux_loss_coef": self.router_aux_loss_coef,
+                "sliding_window": self.sliding_window,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
+            }
+        )
+        return config
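Because `get_config` returns exactly the constructor arguments, a `MixtralBackbone` instance can be rebuilt from its config dictionary. A minimal round-trip sketch, assuming the usual keras_hub `Backbone.from_config` behavior; the hyperparameters are arbitrary toy values:

```python
import keras_hub

model = keras_hub.models.MixtralBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    num_experts=4,
    top_k=2,
)

# Serialize the architecture (not the weights) and rebuild it.
config = model.get_config()
restored = keras_hub.models.MixtralBackbone.from_config(config)
assert restored.num_experts == model.num_experts
assert restored.top_k == model.top_k
```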