keras-hub-nightly 0.20.0.dev202503310403__py3-none-any.whl → 0.20.0.dev202504020401__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. keras_hub/api/models/__init__.py +21 -0
  2. keras_hub/api/tokenizers/__init__.py +3 -0
  3. keras_hub/src/models/gemma/gemma_attention.py +9 -7
  4. keras_hub/src/models/roformer_v2/__init__.py +0 -0
  5. keras_hub/src/models/roformer_v2/roformer_v2_attention.py +212 -0
  6. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +198 -0
  7. keras_hub/src/models/roformer_v2/roformer_v2_encoder.py +128 -0
  8. keras_hub/src/models/roformer_v2/roformer_v2_masked_lm.py +173 -0
  9. keras_hub/src/models/roformer_v2/roformer_v2_masked_lm_preprocessor.py +125 -0
  10. keras_hub/src/models/roformer_v2/roformer_v2_presets.py +0 -0
  11. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +121 -0
  12. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier_preprocessor.py +128 -0
  13. keras_hub/src/models/roformer_v2/roformer_v2_tokenizer.py +62 -0
  14. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +4 -2
  15. keras_hub/src/models/text_to_image_preprocessor.py +35 -0
  16. keras_hub/src/utils/preset_utils.py +2 -1
  17. keras_hub/src/version_utils.py +1 -1
  18. {keras_hub_nightly-0.20.0.dev202503310403.dist-info → keras_hub_nightly-0.20.0.dev202504020401.dist-info}/METADATA +1 -1
  19. {keras_hub_nightly-0.20.0.dev202503310403.dist-info → keras_hub_nightly-0.20.0.dev202504020401.dist-info}/RECORD +21 -10
  20. {keras_hub_nightly-0.20.0.dev202503310403.dist-info → keras_hub_nightly-0.20.0.dev202504020401.dist-info}/WHEEL +0 -0
  21. {keras_hub_nightly-0.20.0.dev202503310403.dist-info → keras_hub_nightly-0.20.0.dev202504020401.dist-info}/top_level.txt +0 -0
keras_hub/api/models/__init__.py
@@ -323,6 +323,24 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
     RobertaTextClassifierPreprocessor as RobertaPreprocessor,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
+    RoformerV2Backbone as RorformerV2Backbone,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import (
+    RoformerV2MaskedLM,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import (
+    RoformerV2MaskedLMPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import (
+    RorformerV2TextClassifier,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import (
+    RoformerV2TextClassifierPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
@@ -369,6 +387,9 @@ from keras_hub.src.models.text_classifier_preprocessor import (
     TextClassifierPreprocessor,
 )
 from keras_hub.src.models.text_to_image import TextToImage
+from keras_hub.src.models.text_to_image_preprocessor import (
+    TextToImagePreprocessor,
+)
 from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
 from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
 from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
keras_hub/api/tokenizers/__init__.py
@@ -35,6 +35,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as Qwen2Tokenizer,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
keras_hub/src/models/gemma/gemma_attention.py
@@ -133,6 +133,13 @@ class CachedGemmaAttention(keras.layers.Layer):
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
         if self._can_use_flash_attention():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
@@ -172,13 +179,8 @@ class CachedGemmaAttention(keras.layers.Layer):
                 ops.tanh(attention_logits), self.logit_soft_cap
             )
 
-        if self.use_sliding_window_attention:
-            attention_mask = self._mask_sliding_window(
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
 
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
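The two gemma_attention.py hunks above move the sliding-window masking ahead of the flash-attention branch, so both the flash path and the plain softmax path receive the already-windowed mask, and both the masking call and the later rank expansion are now guarded against a `None` mask. The actual windowing lives in `CachedGemmaAttention._mask_sliding_window`, which is not part of this diff; the sketch below is only a hypothetical illustration of restricting a boolean `(batch, query_len, key_len)` mask to a sliding window, with `window_size` as an assumed parameter:

```python
from keras import ops


def sliding_window_mask(attention_mask, window_size, cache_update_index=0):
    """Hypothetical sketch: keep only keys within `window_size` positions
    behind each query in a boolean (batch, query_len, key_len) mask."""
    query_len = ops.shape(attention_mask)[1]
    key_len = ops.shape(attention_mask)[2]
    # Absolute positions of queries (offset by the cache index) and keys.
    query_pos = ops.arange(query_len)[:, None] + cache_update_index
    key_pos = ops.arange(key_len)[None, :]
    in_window = ops.logical_and(
        ops.less_equal(key_pos, query_pos),
        ops.greater(key_pos, query_pos - window_size),
    )
    return ops.logical_and(ops.cast(attention_mask, "bool"), in_window[None, :, :])
```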
keras_hub/src/models/roformer_v2/__init__.py (file without changes)
keras_hub/src/models/roformer_v2/roformer_v2_attention.py
@@ -0,0 +1,212 @@
+import keras
+from keras import initializers
+from keras import ops
+
+
+class RoformerNorm(keras.layers.Layer):
+    """A normalization layer for Roformer that implements RMS normalization."""
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        dim = input_shape[-1]
+        self.scale = self.add_weight(
+            name="scale",
+            trainable=True,
+            shape=(dim,),
+            initializer="ones",
+            dtype=self.variable_dtype,
+        )
+        self.built = True
+
+    def call(self, x):
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x * self.scale, self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
+
+
+class RoformrPositionalEmbedding(keras.layers.Layer):
+    """Native rotary implement by jianlin su
+    from native implement
+    https://github.com/bojone/bert4keras
+
+    """
+
+    def __init__(self, output_dim, max_wavelength=10000, **kwargs):
+        super().__init__(**kwargs)
+        self.max_wavelength = max_wavelength
+        self.output_dim = output_dim
+
+    def call(self, tensors):
+        input_shape = ops.shape(tensors[0])
+        seq_len = input_shape[1]
+        position_ids = ops.arange(0, seq_len, dtype=tensors[0].dtype)[None]
+        embeddings = self.sinusoidal_embeddings(
+            position_ids, self.output_dim, self.max_wavelength
+        )
+        embeddings = ops.cast(embeddings, self.compute_dtype)
+
+        ndim = ops.ndim(tensors[0])
+        sinusoidal = self.align(embeddings, [0, 1, -1], ndim)
+        cos_pos = ops.repeat(sinusoidal[..., 1::2], 2, -1)
+        sin_pos = ops.repeat(sinusoidal[..., ::2], 2, -1)
+        outputs = []
+        for tensor in tensors:
+            tensor2 = ops.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = ops.reshape(tensor2, ops.shape(tensor))
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def align(self, tensor, axes, ndim=None):
+        ndim = ndim or max(axes) + 1
+        indices = [None] * ndim
+        for i in axes:
+            indices[i] = slice(None)
+        if keras.config.backend() == "jax":
+            return tensor[tuple(indices)]
+        return tensor[indices]
+
+    def sinusoidal_embeddings(self, pos, dim, base=10000):
+        if dim % 2 != 0:
+            raise ("Dimension must be even")
+
+        indices = ops.arange(0, dim // 2, dtype="float32")
+        indices = ops.power(ops.cast(base, dtype="float32"), -2 * indices / dim)
+        embeddings = ops.einsum("...,d->...d", pos, indices)
+        embeddings = ops.stack(
+            [ops.sin(embeddings), ops.cos(embeddings)], axis=-1
+        )
+        shape = list(ops.shape(embeddings))
+        embeddings = ops.reshape(embeddings, shape[:-2] + [-1])
+        return embeddings
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "out_dim": self.out_dim,
+                "max_wavelength": self.max_wavelength,
+            }
+        )
+        return config
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class RoformerAttention(keras.layers.Layer):
+    """MultiHeadAttention by roformerV2
+
+    modifity from native implement
+    https://github.com/bojone/bert4keras
+    """
+
+    def __init__(
+        self,
+        heads,
+        head_size,
+        out_dim=None,
+        use_bias=False,
+        max_wavelength=10000,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.heads = heads
+        self.head_size = head_size
+        self.out_dim = out_dim or heads * head_size
+        self.use_bias = use_bias
+        self.kernel_initializer = initializers.get(kernel_initializer)
+        self.max_wavelength = max_wavelength
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.q_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="q_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.q_dense.build(input_shape)
+
+        self.k_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="k_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.k_dense.build(input_shape)
+
+        self.v_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="v_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.v_dense.build(input_shape)
+
+        self.o_dense = keras.layers.Dense(
+            units=self.out_dim,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="o_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.o_dense.build([None, None, self.head_size * self.heads])
+
+        self.rotary_embedding_layer = RoformrPositionalEmbedding(
+            self.head_size, self.max_wavelength, dtype=self.dtype_policy
+        )
+        self.rotary_embedding_layer.build([])
+
+    def call(self, x, attention_mask=None):
+        qw = self.q_dense(x)
+        kw = self.k_dense(x)
+        vw = self.v_dense(x)
+
+        b, s = ops.shape(qw)[:2]
+        qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
+        kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
+        vw = ops.reshape(vw, (b, s, self.heads, self.head_size))
+
+        qw, kw = self.rotary_embedding_layer([qw, kw])
+        if keras.__version__ < "3.6":
+            raise ("Please make sure your Keras version is >=3.6.")
+        flash_attention = keras.config.is_flash_attention_enabled()
+        attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
+        if keras.config.backend() == "torch":
+            attention_mask = ops.repeat(attention_mask, s, -1)
+            attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
+        o = ops.dot_product_attention(
+            qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
+        )
+
+        return self.o_dense(ops.reshape(o, [b, s, -1]))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "heads": self.heads,
+                "head_size": self.head_size,
+                "out_dim": self.out_dim,
+                "use_bias": self.use_bias,
+                "max_wavelength": self.max_wavelength,
+                "kernel_initializer": initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
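For readers skimming the new attention file: `RoformrPositionalEmbedding` applies the bert4keras-style rotary position embedding by pairing adjacent channels and rotating each pair by a position-dependent angle; the query and key projections are rotated this way before `ops.dot_product_attention`, so relative position enters attention only through their dot products. A minimal NumPy sketch of that transform (the function name and the single `(seq_len, dim)` input are illustrative, not part of the package):

```python
import numpy as np


def apply_rotary(x, max_wavelength=10000):
    # Minimal sketch of the transform in RoformrPositionalEmbedding.call,
    # assuming x has shape (seq_len, dim) with an even dim.
    seq_len, dim = x.shape
    pos = np.arange(seq_len, dtype="float32")
    inv_freq = max_wavelength ** (-2.0 * np.arange(dim // 2) / dim)
    angles = np.einsum("s,d->sd", pos, inv_freq)  # (seq_len, dim // 2)
    # Each (even, odd) channel pair shares one rotation angle.
    cos = np.repeat(np.cos(angles), 2, axis=-1)
    sin = np.repeat(np.sin(angles), 2, axis=-1)
    # Interleave (-x1, x0, -x3, x2, ...) to express the 2D rotation of each pair.
    rotated = np.stack([-x[..., 1::2], x[..., ::2]], axis=-1).reshape(x.shape)
    return x * cos + rotated * sin
```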
keras_hub/src/models/roformer_v2/roformer_v2_backbone.py
@@ -0,0 +1,198 @@
+import keras
+from keras import activations
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerNorm
+from keras_hub.src.models.roformer_v2.roformer_v2_encoder import (
+    RoformerV2Encoder,
+)
+
+
+def roformer_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.RorformerV2Backbone")
+class RoformerV2Backbone(Backbone):
+    """A RoformerV2 encoder network.
+
+    This class implements a bi-directional Transformer-based encoder as
+    described in ["Roformer"](https://github.com/ZhuiyiTechnology/roformer).
+    It includes the
+    embedding lookups and transformer layers, but not the masked language model
+    or next sentence prediction heads.
+
+    The default constructor gives a fully customizable, randomly initialized
+    RoformerV2 encoder with any number of layers, heads, and embed dim.To
+    load preset architectures and weights, use the `from_preset()` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The size of the transformer encoding and pooler layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the Transformer encoder.
+        num_segments: int. The number of types that the 'segment_ids' input can
+            take.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained RoformerV2 encoder.
+    model = keras_hub.models.RoformerV2Backbone.from_preset("roformer_v2_base")
+    model(input_data)
+
+    # Randomly initialized RoformerV2 encoder with a custom config.
+    model = keras_hub.models.RoformerV2Backbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        head_size = 64,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        head_size,
+        use_bias=False,
+        activation="relu",
+        dropout=0.1,
+        num_segments=2,
+        dtype=None,
+        max_wavelength=10000,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=roformer_kernel_initializer(),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.segment_embedding = keras.layers.Embedding(
+            input_dim=num_segments,
+            output_dim=hidden_dim,
+            embeddings_initializer=roformer_kernel_initializer(),
+            dtype=dtype,
+            name="segment_embedding",
+        )
+        self.embeddings_add = keras.layers.Add(
+            dtype=dtype,
+            name="embeddings_add",
+        )
+        self.embeddings_layer_norm = RoformerNorm(
+            epsilon=keras.backend.epsilon(),
+            dtype=dtype,
+            name="embeddings_layer_norm",
+        )
+        self.embeddings_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=dtype,
+            name="embeddings_dropout",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = RoformerV2Encoder(
+                heads=num_heads,
+                head_size=head_size,
+                intermediate_size=intermediate_dim,
+                use_bias=use_bias,
+                max_wavelength=max_wavelength,
+                dropout=dropout,
+                activation=activation,
+                kernel_initializer=roformer_kernel_initializer(),
+                dtype=dtype,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        segment_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="segment_ids"
+        )
+        attention_mask = keras.ops.not_equal(token_id_input, 0)
+        # Embed tokens, positions, and segment ids.
+        tokens = self.token_embedding(token_id_input)
+        segments = self.segment_embedding(segment_id_input)
+        # Sum, normalize and apply dropout to embeddings.
+        x = self.embeddings_add((tokens, segments))
+        x = self.embeddings_layer_norm(x)
+        x = self.embeddings_dropout(x)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, attention_mask=attention_mask)
+
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "segment_ids": segment_id_input,
+            },
+            outputs=x,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.num_segments = num_segments
+        self.max_wavelength = max_wavelength
+        self.head_size = head_size
+        self.dropout = dropout
+        self.activation = activations.get(activation)
+        self.use_bias = use_bias
+        self.start_token_index = 0
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "num_segments": self.num_segments,
+                "max_wavelength": self.max_wavelength,
+                "head_size": self.head_size,
+                "use_bias": self.use_bias,
+                "activation": activations.serialize(self.activation),
+            }
+        )
+        return config
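One detail worth noting in the backbone above: the functional model exposes only `token_ids` and `segment_ids` as inputs and derives its attention mask internally via `keras.ops.not_equal(token_id_input, 0)`, so padding is signalled by token id 0 rather than a separate `padding_mask` input. A small usage sketch under that reading (hyperparameters are illustrative, borrowed from the docstring example; the class is imported by its source path as added in this diff):

```python
import numpy as np

from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
    RoformerV2Backbone,
)

backbone = RoformerV2Backbone(
    vocabulary_size=30552,
    num_layers=2,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
    head_size=64,
)
outputs = backbone(
    {
        # Token id 0 marks padding; the mask is computed as token_ids != 0.
        "token_ids": np.array([[5, 9, 7, 3, 0, 0]], dtype="int32"),
        "segment_ids": np.zeros((1, 6), dtype="int32"),
    }
)
print(outputs.shape)  # (1, 6, 256): (batch, sequence, hidden_dim)
```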
keras_hub/src/models/roformer_v2/roformer_v2_encoder.py
@@ -0,0 +1,128 @@
+import keras
+from keras import activations
+from keras import initializers
+
+from keras_hub.src.models.roformer_v2.roformer_v2_attention import (
+    RoformerAttention,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerNorm
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class RoformerV2Encoder(keras.layers.Layer):
+    """A Transformer Encoder layer for the Roformer backbone."""
+
+    def __init__(
+        self,
+        heads,
+        head_size,
+        intermediate_size=None,
+        max_wavelength=10000,
+        dropout=0,
+        activation="relu",
+        use_bias=False,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.heads = heads
+        self.head_size = head_size
+        self.intermediate_size = intermediate_size
+        self.use_bias = use_bias
+        self.kernel_initializer = initializers.get(kernel_initializer)
+        self.max_wavelength = max_wavelength
+        self.dropout = dropout
+        self.activation = activations.get(activation)
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.attention_layer = RoformerAttention(
+            heads=self.heads,
+            head_size=self.head_size,
+            use_bias=self.use_bias,
+            max_wavelength=self.max_wavelength,
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_layer",
+        )
+        self.attention_layer.build(input_shape)
+
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
+        )
+        self.dropout_layer.build([])
+
+        # Feedforward layers.
+        self.feedforward_intermediate_dense = keras.layers.Dense(
+            self.intermediate_size,
+            kernel_initializer=self.kernel_initializer,
+            use_bias=self.use_bias,
+            dtype=self.dtype_policy,
+            activation=self.activation,
+            name="feedforward_intermediate_dense",
+        )
+        self.feedforward_intermediate_dense.build(input_shape)
+
+        self.feedforward_output_dense = keras.layers.Dense(
+            input_shape[-1],
+            kernel_initializer=self.kernel_initializer,
+            use_bias=self.use_bias,
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
+        )
+
+        self.feedforward_output_dense.build(
+            [None, None, self.intermediate_size]
+        )
+
+        self.attention_norm = RoformerNorm(
+            epsilon=keras.backend.epsilon(),
+            name="attention_norm",
+            dtype=self.dtype_policy,
+        )
+        self.attention_norm.build(input_shape)
+
+        self.feedforward_norm = RoformerNorm(
+            epsilon=keras.backend.epsilon(),
+            name="feedforward_norm",
+            dtype=self.dtype_policy,
+        )
+        self.feedforward_norm.build(input_shape)
+
+    def call(self, x, attention_mask=None):
+        attention_output = self.attention_layer(
+            x,
+            attention_mask=attention_mask,
+        )
+
+        residual = x + self.dropout_layer(attention_output)
+        x = self.attention_norm(residual)
+
+        intermediate_output = self.feedforward_intermediate_dense(x)
+        feedroward_output = self.feedforward_output_dense(intermediate_output)
+
+        residual = x + self.dropout_layer(feedroward_output)
+        return self.feedforward_norm(residual)
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "heads": self.heads,
+                "head_size": self.head_size,
+                "intermediate_size": self.intermediate_size,
+                "max_wavelength": self.max_wavelength,
+                "use_bias": self.use_bias,
+                "activation": activations.serialize(self.activation),
+                "dropout": self.dropout,
+                "kernel_initializer": initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config