keras-hub-nightly 0.20.0.dev202504010407__py3-none-any.whl → 0.20.0.dev202504030357__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (27)
  1. keras_hub/api/models/__init__.py +18 -0
  2. keras_hub/api/tokenizers/__init__.py +3 -0
  3. keras_hub/src/models/gemma/gemma_attention.py +26 -17
  4. keras_hub/src/models/gemma3/gemma3_attention.py +2 -2
  5. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -2
  6. keras_hub/src/models/llama/llama_attention.py +2 -2
  7. keras_hub/src/models/mistral/mistral_attention.py +2 -2
  8. keras_hub/src/models/phi3/phi3_attention.py +2 -2
  9. keras_hub/src/models/qwen/qwen_attention.py +2 -2
  10. keras_hub/src/models/roformer_v2/__init__.py +0 -0
  11. keras_hub/src/models/roformer_v2/roformer_v2_attention.py +212 -0
  12. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +198 -0
  13. keras_hub/src/models/roformer_v2/roformer_v2_encoder.py +128 -0
  14. keras_hub/src/models/roformer_v2/roformer_v2_masked_lm.py +173 -0
  15. keras_hub/src/models/roformer_v2/roformer_v2_masked_lm_preprocessor.py +125 -0
  16. keras_hub/src/models/roformer_v2/roformer_v2_presets.py +0 -0
  17. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +121 -0
  18. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier_preprocessor.py +128 -0
  19. keras_hub/src/models/roformer_v2/roformer_v2_tokenizer.py +62 -0
  20. keras_hub/src/models/stable_diffusion_3/mmdit.py +2 -2
  21. keras_hub/src/utils/keras_utils.py +44 -1
  22. keras_hub/src/utils/preset_utils.py +2 -1
  23. keras_hub/src/version_utils.py +1 -1
  24. {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/METADATA +1 -1
  25. {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/RECORD +27 -17
  26. {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/WHEEL +0 -0
  27. {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/top_level.txt +0 -0
keras_hub/api/models/__init__.py
@@ -323,6 +323,24 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
     RobertaTextClassifierPreprocessor as RobertaPreprocessor,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
+    RoformerV2Backbone as RorformerV2Backbone,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import (
+    RoformerV2MaskedLM,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import (
+    RoformerV2MaskedLMPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import (
+    RorformerV2TextClassifier,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import (
+    RoformerV2TextClassifierPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
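For reference, a minimal sketch of how these new symbols surface publicly, assuming `keras_hub.api.models` is exposed as `keras_hub.models` as it is for the existing model families; the "Rorformer" spelling of two of the names comes from the aliases added above.

```python
# Illustrative only: resolving the new RoFormerV2 symbols through the public
# namespace added by this diff. The "Rorformer" spelling mirrors the aliases
# in keras_hub/api/models/__init__.py.
import keras_hub

backbone_cls = keras_hub.models.RorformerV2Backbone
classifier_cls = keras_hub.models.RorformerV2TextClassifier
masked_lm_cls = keras_hub.models.RoformerV2MaskedLM
tokenizer_cls = keras_hub.models.RoformerV2Tokenizer
```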
keras_hub/api/tokenizers/__init__.py
@@ -35,6 +35,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as Qwen2Tokenizer,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
keras_hub/src/models/gemma/gemma_attention.py
@@ -6,7 +6,9 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
@@ -106,17 +108,22 @@ class CachedGemmaAttention(keras.layers.Layer):
         )
         return x
 
-    def _can_use_flash_attention(self):
-        if not has_flash_attention_support():
+    def _use_fused_attention_op(self):
+        if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
-        if self.logit_soft_cap is None:
-            return True
-        sig = inspect.signature(ops.dot_product_attention)
-        # We can currently only run soft capped attention for keras >= 3.10
-        # and only on TPU.
-        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap with on keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
 
     def _compute_attention(
         self,
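The TPU branch above gates soft-capped attention on whether the installed Keras exposes an `attn_logits_soft_cap` argument on the fused op. A minimal sketch of that probe in isolation, assuming a Keras 3 install:

```python
import inspect

from keras import ops


def softcap_supported_in_fused_op():
    # True only when keras.ops.dot_product_attention accepts
    # `attn_logits_soft_cap` (Keras >= 3.10, per the comment in the diff).
    sig = inspect.signature(ops.dot_product_attention)
    return "attn_logits_soft_cap" in sig.parameters
```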
@@ -133,7 +140,14 @@ class CachedGemmaAttention(keras.layers.Layer):
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
-        if self._can_use_flash_attention():
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
+        if self._use_fused_attention_op():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
@@ -172,13 +186,8 @@ class CachedGemmaAttention(keras.layers.Layer):
                 ops.tanh(attention_logits), self.logit_soft_cap
             )
 
-        if self.use_sliding_window_attention:
-            attention_mask = self._mask_sliding_window(
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
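The two hunks above move the sliding-window masking ahead of the fused/fallback branch, so both code paths now see the same banded mask. For intuition only, a standalone sketch of a causal sliding-window mask; the window size here is illustrative, Gemma's actual value and the cache handling come from `_mask_sliding_window` in the layer itself.

```python
import numpy as np


def sliding_window_causal_mask(seq_len, window):
    # Position j is visible from position i only if it is not in the future
    # and lies within the last `window` positions.
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (i - j < window)


print(sliding_window_causal_mask(6, window=3).astype(int))
```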
keras_hub/src/models/gemma3/gemma3_attention.py
@@ -7,7 +7,7 @@ from keras import ops
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
@@ -140,7 +140,7 @@ class CachedGemma3Attention(keras.layers.Layer):
         return x
 
     def _can_use_flash_attention(self):
-        if not has_flash_attention_support():
+        if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class GPTNeoXAttention(keras.layers.Layer):
@@ -125,7 +125,7 @@ class GPTNeoXAttention(keras.layers.Layer):
     def _compute_attention(
         self, query, key, value, attention_mask=None, training=None
     ):
-        if has_flash_attention_support() and self.dropout == 0:
+        if fused_attention_op_available() and self.dropout == 0:
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/llama/llama_attention.py
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class LlamaAttention(keras.layers.Layer):
@@ -185,7 +185,7 @@ class LlamaAttention(keras.layers.Layer):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/mistral/mistral_attention.py
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 # This is just a self-attention layer in Mistral. But it can be generalized
@@ -196,7 +196,7 @@ class CachedMistralAttention(keras.layers.Layer):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/phi3/phi3_attention.py
@@ -8,7 +8,7 @@ from keras_hub.src.models.phi3.phi3_rotary_embedding import (
     Phi3SuScaledRotaryEmbedding,
 )
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class Phi3Attention(keras.layers.Layer):
@@ -217,7 +217,7 @@ class Phi3Attention(keras.layers.Layer):
         return self.softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/qwen/qwen_attention.py
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class QwenAttention(keras.layers.Layer):
@@ -263,7 +263,7 @@ class QwenAttention(keras.layers.Layer):
        Returns:
            attention_output: Output tensor after applying attention.
        """
-        if has_flash_attention_support():
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
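All of the `_compute_attention` methods touched above fall through to `keras.ops.dot_product_attention` when the fused op is available. A minimal standalone call, with illustrative shapes and under the assumption of a Keras 3 backend:

```python
import numpy as np
from keras import ops

batch, seq, heads, head_dim = 2, 8, 4, 16
q = np.random.uniform(size=(batch, seq, heads, head_dim)).astype("float32")
k = np.random.uniform(size=(batch, seq, heads, head_dim)).astype("float32")
v = np.random.uniform(size=(batch, seq, heads, head_dim)).astype("float32")
# Boolean mask broadcastable to (batch, heads, query_len, key_len).
mask = np.ones((batch, 1, seq, seq), dtype=bool)

out = ops.dot_product_attention(q, k, v, mask=mask)
print(ops.shape(out))  # (2, 8, 4, 16)
```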
keras_hub/src/models/roformer_v2/__init__.py: no content changes to display (empty file).
keras_hub/src/models/roformer_v2/roformer_v2_attention.py
@@ -0,0 +1,212 @@
+import keras
+from keras import initializers
+from keras import ops
+
+
+class RoformerNorm(keras.layers.Layer):
+    """A normalization layer for Roformer that implements RMS normalization."""
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        dim = input_shape[-1]
+        self.scale = self.add_weight(
+            name="scale",
+            trainable=True,
+            shape=(dim,),
+            initializer="ones",
+            dtype=self.variable_dtype,
+        )
+        self.built = True
+
+    def call(self, x):
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x * self.scale, self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
+
+
+class RoformrPositionalEmbedding(keras.layers.Layer):
+    """Native rotary implement by jianlin su
+    from native implement
+    https://github.com/bojone/bert4keras
+
+    """
+
+    def __init__(self, output_dim, max_wavelength=10000, **kwargs):
+        super().__init__(**kwargs)
+        self.max_wavelength = max_wavelength
+        self.output_dim = output_dim
+
+    def call(self, tensors):
+        input_shape = ops.shape(tensors[0])
+        seq_len = input_shape[1]
+        position_ids = ops.arange(0, seq_len, dtype=tensors[0].dtype)[None]
+        embeddings = self.sinusoidal_embeddings(
+            position_ids, self.output_dim, self.max_wavelength
+        )
+        embeddings = ops.cast(embeddings, self.compute_dtype)
+
+        ndim = ops.ndim(tensors[0])
+        sinusoidal = self.align(embeddings, [0, 1, -1], ndim)
+        cos_pos = ops.repeat(sinusoidal[..., 1::2], 2, -1)
+        sin_pos = ops.repeat(sinusoidal[..., ::2], 2, -1)
+        outputs = []
+        for tensor in tensors:
+            tensor2 = ops.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = ops.reshape(tensor2, ops.shape(tensor))
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def align(self, tensor, axes, ndim=None):
+        ndim = ndim or max(axes) + 1
+        indices = [None] * ndim
+        for i in axes:
+            indices[i] = slice(None)
+        if keras.config.backend() == "jax":
+            return tensor[tuple(indices)]
+        return tensor[indices]
+
+    def sinusoidal_embeddings(self, pos, dim, base=10000):
+        if dim % 2 != 0:
+            raise ("Dimension must be even")
+
+        indices = ops.arange(0, dim // 2, dtype="float32")
+        indices = ops.power(ops.cast(base, dtype="float32"), -2 * indices / dim)
+        embeddings = ops.einsum("...,d->...d", pos, indices)
+        embeddings = ops.stack(
+            [ops.sin(embeddings), ops.cos(embeddings)], axis=-1
+        )
+        shape = list(ops.shape(embeddings))
+        embeddings = ops.reshape(embeddings, shape[:-2] + [-1])
+        return embeddings
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "out_dim": self.out_dim,
+                "max_wavelength": self.max_wavelength,
+            }
+        )
+        return config
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class RoformerAttention(keras.layers.Layer):
+    """MultiHeadAttention by roformerV2
+
+    modifity from native implement
+    https://github.com/bojone/bert4keras
+    """
+
+    def __init__(
+        self,
+        heads,
+        head_size,
+        out_dim=None,
+        use_bias=False,
+        max_wavelength=10000,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.heads = heads
+        self.head_size = head_size
+        self.out_dim = out_dim or heads * head_size
+        self.use_bias = use_bias
+        self.kernel_initializer = initializers.get(kernel_initializer)
+        self.max_wavelength = max_wavelength
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.q_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="q_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.q_dense.build(input_shape)
+
+        self.k_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="k_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.k_dense.build(input_shape)
+
+        self.v_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="v_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.v_dense.build(input_shape)
+
+        self.o_dense = keras.layers.Dense(
+            units=self.out_dim,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="o_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.o_dense.build([None, None, self.head_size * self.heads])
+
+        self.rotary_embedding_layer = RoformrPositionalEmbedding(
+            self.head_size, self.max_wavelength, dtype=self.dtype_policy
+        )
+        self.rotary_embedding_layer.build([])
+
+    def call(self, x, attention_mask=None):
+        qw = self.q_dense(x)
+        kw = self.k_dense(x)
+        vw = self.v_dense(x)
+
+        b, s = ops.shape(qw)[:2]
+        qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
+        kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
+        vw = ops.reshape(vw, (b, s, self.heads, self.head_size))
+
+        qw, kw = self.rotary_embedding_layer([qw, kw])
+        if keras.__version__ < "3.6":
+            raise ("Please make sure your Keras version is >=3.6.")
+        flash_attention = keras.config.is_flash_attention_enabled()
+        attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
+        if keras.config.backend() == "torch":
+            attention_mask = ops.repeat(attention_mask, s, -1)
+            attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
+        o = ops.dot_product_attention(
+            qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
+        )
+
+        return self.o_dense(ops.reshape(o, [b, s, -1]))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "heads": self.heads,
+                "head_size": self.head_size,
+                "out_dim": self.out_dim,
+                "use_bias": self.use_bias,
+                "max_wavelength": self.max_wavelength,
+                "kernel_initializer": initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
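A quick smoke test of the new layer as defined above, assuming Keras >= 3.6 (the version guard in `call`) and a backend whose `dot_product_attention` accepts a broadcastable boolean mask; shapes follow the reshape logic in `call`.

```python
import numpy as np
from keras_hub.src.models.roformer_v2.roformer_v2_attention import (
    RoformerAttention,
)

layer = RoformerAttention(heads=4, head_size=16)
x = np.random.uniform(size=(2, 10, 64)).astype("float32")
mask = np.ones((2, 10), dtype=bool)
y = layer(x, attention_mask=mask)
print(y.shape)  # (2, 10, 64); out_dim defaults to heads * head_size
```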
keras_hub/src/models/roformer_v2/roformer_v2_backbone.py
@@ -0,0 +1,198 @@
+import keras
+from keras import activations
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerNorm
+from keras_hub.src.models.roformer_v2.roformer_v2_encoder import (
+    RoformerV2Encoder,
+)
+
+
+def roformer_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.RorformerV2Backbone")
+class RoformerV2Backbone(Backbone):
+    """A RoformerV2 encoder network.
+
+    This class implements a bi-directional Transformer-based encoder as
+    described in ["Roformer"](https://github.com/ZhuiyiTechnology/roformer).
+    It includes the
+    embedding lookups and transformer layers, but not the masked language model
+    or next sentence prediction heads.
+
+    The default constructor gives a fully customizable, randomly initialized
+    RoformerV2 encoder with any number of layers, heads, and embed dim.To
+    load preset architectures and weights, use the `from_preset()` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The size of the transformer encoding and pooler layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the Transformer encoder.
+        num_segments: int. The number of types that the 'segment_ids' input can
+            take.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained RoformerV2 encoder.
+    model = keras_hub.models.RoformerV2Backbone.from_preset("roformer_v2_base")
+    model(input_data)
+
+    # Randomly initialized RoformerV2 encoder with a custom config.
+    model = keras_hub.models.RoformerV2Backbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        head_size = 64,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        head_size,
+        use_bias=False,
+        activation="relu",
+        dropout=0.1,
+        num_segments=2,
+        dtype=None,
+        max_wavelength=10000,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=roformer_kernel_initializer(),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.segment_embedding = keras.layers.Embedding(
+            input_dim=num_segments,
+            output_dim=hidden_dim,
+            embeddings_initializer=roformer_kernel_initializer(),
+            dtype=dtype,
+            name="segment_embedding",
+        )
+        self.embeddings_add = keras.layers.Add(
+            dtype=dtype,
+            name="embeddings_add",
+        )
+        self.embeddings_layer_norm = RoformerNorm(
+            epsilon=keras.backend.epsilon(),
+            dtype=dtype,
+            name="embeddings_layer_norm",
+        )
+        self.embeddings_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=dtype,
+            name="embeddings_dropout",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = RoformerV2Encoder(
+                heads=num_heads,
+                head_size=head_size,
+                intermediate_size=intermediate_dim,
+                use_bias=use_bias,
+                max_wavelength=max_wavelength,
+                dropout=dropout,
+                activation=activation,
+                kernel_initializer=roformer_kernel_initializer(),
+                dtype=dtype,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        segment_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="segment_ids"
+        )
+        attention_mask = keras.ops.not_equal(token_id_input, 0)
+        # Embed tokens, positions, and segment ids.
+        tokens = self.token_embedding(token_id_input)
+        segments = self.segment_embedding(segment_id_input)
+        # Sum, normalize and apply dropout to embeddings.
+        x = self.embeddings_add((tokens, segments))
+        x = self.embeddings_layer_norm(x)
+        x = self.embeddings_dropout(x)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, attention_mask=attention_mask)
+
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "segment_ids": segment_id_input,
+            },
+            outputs=x,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.num_segments = num_segments
+        self.max_wavelength = max_wavelength
+        self.head_size = head_size
+        self.dropout = dropout
+        self.activation = activations.get(activation)
+        self.use_bias = use_bias
+        self.start_token_index = 0
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "num_segments": self.num_segments,
+                "max_wavelength": self.max_wavelength,
+                "head_size": self.head_size,
+                "use_bias": self.use_bias,
+                "activation": activations.serialize(self.activation),
+            }
+        )
+        return config
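One detail worth noting from the functional wiring above: the attention mask is derived inside the model from `token_ids != 0`, and only `token_ids` and `segment_ids` are declared as inputs. A minimal sketch along the lines of the docstring example, with small illustrative dimensions; whether `hidden_dim` must equal `num_heads * head_size` depends on `RoformerV2Encoder`, which is not shown in this diff, so the values below are chosen to satisfy that assumption.

```python
import numpy as np
from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
    RoformerV2Backbone,
)

# Small, illustrative configuration (not a released preset).
backbone = RoformerV2Backbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    head_size=16,
)
out = backbone(
    {
        "token_ids": np.array([[5, 9, 3, 0, 0]], dtype="int32"),
        "segment_ids": np.zeros((1, 5), dtype="int32"),
    }
)
print(out.shape)  # (1, 5, 64)
```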