keras-hub-nightly 0.20.0.dev202504010407__py3-none-any.whl → 0.20.0.dev202504030357__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +18 -0
- keras_hub/api/tokenizers/__init__.py +3 -0
- keras_hub/src/models/gemma/gemma_attention.py +26 -17
- keras_hub/src/models/gemma3/gemma3_attention.py +2 -2
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -2
- keras_hub/src/models/llama/llama_attention.py +2 -2
- keras_hub/src/models/mistral/mistral_attention.py +2 -2
- keras_hub/src/models/phi3/phi3_attention.py +2 -2
- keras_hub/src/models/qwen/qwen_attention.py +2 -2
- keras_hub/src/models/roformer_v2/__init__.py +0 -0
- keras_hub/src/models/roformer_v2/roformer_v2_attention.py +212 -0
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +198 -0
- keras_hub/src/models/roformer_v2/roformer_v2_encoder.py +128 -0
- keras_hub/src/models/roformer_v2/roformer_v2_masked_lm.py +173 -0
- keras_hub/src/models/roformer_v2/roformer_v2_masked_lm_preprocessor.py +125 -0
- keras_hub/src/models/roformer_v2/roformer_v2_presets.py +0 -0
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +121 -0
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier_preprocessor.py +128 -0
- keras_hub/src/models/roformer_v2/roformer_v2_tokenizer.py +62 -0
- keras_hub/src/models/stable_diffusion_3/mmdit.py +2 -2
- keras_hub/src/utils/keras_utils.py +44 -1
- keras_hub/src/utils/preset_utils.py +2 -1
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/RECORD +27 -17
- {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504010407.dist-info → keras_hub_nightly-0.20.0.dev202504030357.dist-info}/top_level.txt +0 -0
keras_hub/api/models/__init__.py
CHANGED
@@ -323,6 +323,24 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
     RobertaTextClassifierPreprocessor as RobertaPreprocessor,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
+    RoformerV2Backbone as RorformerV2Backbone,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import (
+    RoformerV2MaskedLM,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import (
+    RoformerV2MaskedLMPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import (
+    RorformerV2TextClassifier,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import (
+    RoformerV2TextClassifierPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
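These exports make the RoFormerV2 family reachable from the public `keras_hub.models` namespace (note that the backbone alias is spelled `RorformerV2Backbone` in this build). A minimal usage sketch, not taken from the package itself; the argument values mirror the backbone docstring example further down in this diff:

```python
import numpy as np

import keras_hub

# Randomly initialized backbone, constructed through the alias exported above.
backbone = keras_hub.models.RorformerV2Backbone(
    vocabulary_size=30552,
    num_layers=4,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
    head_size=64,
)
# The backbone consumes token and segment ids; padding is inferred from id 0.
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "segment_ids": np.zeros((1, 12), dtype="int32"),
    }
)
```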
keras_hub/api/tokenizers/__init__.py
CHANGED
@@ -35,6 +35,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as Qwen2Tokenizer,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
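The tokenizer is exported under `keras_hub.tokenizers` as well. A hedged sketch of typical use, assuming the standard KerasHub tokenizer workflow; the preset name below is hypothetical and not taken from this diff:

```python
import keras_hub

# Hypothetical preset name, for illustration only.
tokenizer = keras_hub.tokenizers.RoformerV2Tokenizer.from_preset(
    "roformer_v2_base_zh"
)
token_ids = tokenizer("The quick brown fox jumped.")
```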
keras_hub/src/models/gemma/gemma_attention.py
CHANGED
@@ -6,7 +6,9 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
@@ -106,17 +108,22 @@ class CachedGemmaAttention(keras.layers.Layer):
         )
         return x
 
-    def
-        if not
+    def _use_fused_attention_op(self):
+        if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
-        if
-
-
-
-
-
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap with on keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
 
     def _compute_attention(
         self,
@@ -133,7 +140,14 @@ class CachedGemmaAttention(keras.layers.Layer):
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
-
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
+        if self._use_fused_attention_op():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
@@ -172,13 +186,8 @@ class CachedGemmaAttention(keras.layers.Layer):
                 ops.tanh(attention_logits), self.logit_soft_cap
             )
 
-        if
-            attention_mask =
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-            attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
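The TPU branch above feature-detects soft-cap support instead of pinning a Keras version. A minimal sketch of that check in isolation (not part of the diff):

```python
import inspect

from keras import ops

# True when the installed Keras exposes `attn_logits_soft_cap` on the fused
# attention op (Keras >= 3.10, per the comment in the diff above).
supports_soft_cap = (
    "attn_logits_soft_cap"
    in inspect.signature(ops.dot_product_attention).parameters
)
```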
keras_hub/src/models/gemma3/gemma3_attention.py
CHANGED
@@ -7,7 +7,7 @@ from keras import ops
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
@@ -140,7 +140,7 @@ class CachedGemma3Attention(keras.layers.Layer):
         return x
 
     def _can_use_flash_attention(self):
-        if not
+        if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class GPTNeoXAttention(keras.layers.Layer):
@@ -125,7 +125,7 @@ class GPTNeoXAttention(keras.layers.Layer):
     def _compute_attention(
         self, query, key, value, attention_mask=None, training=None
     ):
-        if
+        if fused_attention_op_available() and self.dropout == 0:
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/llama/llama_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class LlamaAttention(keras.layers.Layer):
@@ -185,7 +185,7 @@ class LlamaAttention(keras.layers.Layer):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/mistral/mistral_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 # This is just a self-attention layer in Mistral. But it can be generalized
@@ -196,7 +196,7 @@ class CachedMistralAttention(keras.layers.Layer):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/phi3/phi3_attention.py
CHANGED
@@ -8,7 +8,7 @@ from keras_hub.src.models.phi3.phi3_rotary_embedding import (
     Phi3SuScaledRotaryEmbedding,
 )
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class Phi3Attention(keras.layers.Layer):
@@ -217,7 +217,7 @@ class Phi3Attention(keras.layers.Layer):
         return self.softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/qwen/qwen_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class QwenAttention(keras.layers.Layer):
@@ -263,7 +263,7 @@ class QwenAttention(keras.layers.Layer):
         Returns:
             attention_output: Output tensor after applying attention.
         """
-        if
+        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
keras_hub/src/models/roformer_v2/__init__.py
File without changes
keras_hub/src/models/roformer_v2/roformer_v2_attention.py
ADDED
@@ -0,0 +1,212 @@
+import keras
+from keras import initializers
+from keras import ops
+
+
+class RoformerNorm(keras.layers.Layer):
+    """A normalization layer for Roformer that implements RMS normalization."""
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        dim = input_shape[-1]
+        self.scale = self.add_weight(
+            name="scale",
+            trainable=True,
+            shape=(dim,),
+            initializer="ones",
+            dtype=self.variable_dtype,
+        )
+        self.built = True
+
+    def call(self, x):
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x * self.scale, self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
+
+
+class RoformrPositionalEmbedding(keras.layers.Layer):
+    """Native rotary implement by jianlin su
+    from native implement
+    https://github.com/bojone/bert4keras
+
+    """
+
+    def __init__(self, output_dim, max_wavelength=10000, **kwargs):
+        super().__init__(**kwargs)
+        self.max_wavelength = max_wavelength
+        self.output_dim = output_dim
+
+    def call(self, tensors):
+        input_shape = ops.shape(tensors[0])
+        seq_len = input_shape[1]
+        position_ids = ops.arange(0, seq_len, dtype=tensors[0].dtype)[None]
+        embeddings = self.sinusoidal_embeddings(
+            position_ids, self.output_dim, self.max_wavelength
+        )
+        embeddings = ops.cast(embeddings, self.compute_dtype)
+
+        ndim = ops.ndim(tensors[0])
+        sinusoidal = self.align(embeddings, [0, 1, -1], ndim)
+        cos_pos = ops.repeat(sinusoidal[..., 1::2], 2, -1)
+        sin_pos = ops.repeat(sinusoidal[..., ::2], 2, -1)
+        outputs = []
+        for tensor in tensors:
+            tensor2 = ops.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = ops.reshape(tensor2, ops.shape(tensor))
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def align(self, tensor, axes, ndim=None):
+        ndim = ndim or max(axes) + 1
+        indices = [None] * ndim
+        for i in axes:
+            indices[i] = slice(None)
+        if keras.config.backend() == "jax":
+            return tensor[tuple(indices)]
+        return tensor[indices]
+
+    def sinusoidal_embeddings(self, pos, dim, base=10000):
+        if dim % 2 != 0:
+            raise ("Dimension must be even")
+
+        indices = ops.arange(0, dim // 2, dtype="float32")
+        indices = ops.power(ops.cast(base, dtype="float32"), -2 * indices / dim)
+        embeddings = ops.einsum("...,d->...d", pos, indices)
+        embeddings = ops.stack(
+            [ops.sin(embeddings), ops.cos(embeddings)], axis=-1
+        )
+        shape = list(ops.shape(embeddings))
+        embeddings = ops.reshape(embeddings, shape[:-2] + [-1])
+        return embeddings
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "out_dim": self.out_dim,
+                "max_wavelength": self.max_wavelength,
+            }
+        )
+        return config
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class RoformerAttention(keras.layers.Layer):
+    """MultiHeadAttention by roformerV2
+
+    modifity from native implement
+    https://github.com/bojone/bert4keras
+    """
+
+    def __init__(
+        self,
+        heads,
+        head_size,
+        out_dim=None,
+        use_bias=False,
+        max_wavelength=10000,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.heads = heads
+        self.head_size = head_size
+        self.out_dim = out_dim or heads * head_size
+        self.use_bias = use_bias
+        self.kernel_initializer = initializers.get(kernel_initializer)
+        self.max_wavelength = max_wavelength
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.q_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="q_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.q_dense.build(input_shape)
+
+        self.k_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="k_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.k_dense.build(input_shape)
+
+        self.v_dense = keras.layers.Dense(
+            units=self.head_size * self.heads,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="v_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.v_dense.build(input_shape)
+
+        self.o_dense = keras.layers.Dense(
+            units=self.out_dim,
+            use_bias=self.use_bias,
+            kernel_initializer=self.kernel_initializer,
+            name="o_dense_layer",
+            dtype=self.dtype_policy,
+        )
+        self.o_dense.build([None, None, self.head_size * self.heads])
+
+        self.rotary_embedding_layer = RoformrPositionalEmbedding(
+            self.head_size, self.max_wavelength, dtype=self.dtype_policy
+        )
+        self.rotary_embedding_layer.build([])
+
+    def call(self, x, attention_mask=None):
+        qw = self.q_dense(x)
+        kw = self.k_dense(x)
+        vw = self.v_dense(x)
+
+        b, s = ops.shape(qw)[:2]
+        qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
+        kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
+        vw = ops.reshape(vw, (b, s, self.heads, self.head_size))
+
+        qw, kw = self.rotary_embedding_layer([qw, kw])
+        if keras.__version__ < "3.6":
+            raise ("Please make sure your Keras version is >=3.6.")
+        flash_attention = keras.config.is_flash_attention_enabled()
+        attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
+        if keras.config.backend() == "torch":
+            attention_mask = ops.repeat(attention_mask, s, -1)
+            attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
+        o = ops.dot_product_attention(
+            qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
+        )
+
+        return self.o_dense(ops.reshape(o, [b, s, -1]))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "heads": self.heads,
+                "head_size": self.head_size,
+                "out_dim": self.out_dim,
+                "use_bias": self.use_bias,
+                "max_wavelength": self.max_wavelength,
+                "kernel_initializer": initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
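For reference, a small sketch (not from the package) of what `RoformerNorm.call` computes: RMS normalization without mean-centering, followed by a learned per-channel scale.

```python
import numpy as np

from keras import ops

x = ops.convert_to_tensor(np.random.randn(2, 4, 8).astype("float32"))
scale = ops.ones((8,))  # stands in for the layer's learned `scale` weight
epsilon = 1e-6

var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)  # mean of squares
normed = x * ops.rsqrt(var + epsilon) * scale  # x / RMS(x), then scaled
```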
keras_hub/src/models/roformer_v2/roformer_v2_backbone.py
ADDED
@@ -0,0 +1,198 @@
+import keras
+from keras import activations
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerNorm
+from keras_hub.src.models.roformer_v2.roformer_v2_encoder import (
+    RoformerV2Encoder,
+)
+
+
+def roformer_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.RorformerV2Backbone")
+class RoformerV2Backbone(Backbone):
+    """A RoformerV2 encoder network.
+
+    This class implements a bi-directional Transformer-based encoder as
+    described in ["Roformer"](https://github.com/ZhuiyiTechnology/roformer).
+    It includes the
+    embedding lookups and transformer layers, but not the masked language model
+    or next sentence prediction heads.
+
+    The default constructor gives a fully customizable, randomly initialized
+    RoformerV2 encoder with any number of layers, heads, and embed dim.To
+    load preset architectures and weights, use the `from_preset()` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The size of the transformer encoding and pooler layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the Transformer encoder.
+        num_segments: int. The number of types that the 'segment_ids' input can
+            take.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained RoformerV2 encoder.
+    model = keras_hub.models.RoformerV2Backbone.from_preset("roformer_v2_base")
+    model(input_data)
+
+    # Randomly initialized RoformerV2 encoder with a custom config.
+    model = keras_hub.models.RoformerV2Backbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        head_size = 64,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        head_size,
+        use_bias=False,
+        activation="relu",
+        dropout=0.1,
+        num_segments=2,
+        dtype=None,
+        max_wavelength=10000,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=roformer_kernel_initializer(),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.segment_embedding = keras.layers.Embedding(
+            input_dim=num_segments,
+            output_dim=hidden_dim,
+            embeddings_initializer=roformer_kernel_initializer(),
+            dtype=dtype,
+            name="segment_embedding",
+        )
+        self.embeddings_add = keras.layers.Add(
+            dtype=dtype,
+            name="embeddings_add",
+        )
+        self.embeddings_layer_norm = RoformerNorm(
+            epsilon=keras.backend.epsilon(),
+            dtype=dtype,
+            name="embeddings_layer_norm",
+        )
+        self.embeddings_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=dtype,
+            name="embeddings_dropout",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = RoformerV2Encoder(
+                heads=num_heads,
+                head_size=head_size,
+                intermediate_size=intermediate_dim,
+                use_bias=use_bias,
+                max_wavelength=max_wavelength,
+                dropout=dropout,
+                activation=activation,
+                kernel_initializer=roformer_kernel_initializer(),
+                dtype=dtype,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        segment_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="segment_ids"
+        )
+        attention_mask = keras.ops.not_equal(token_id_input, 0)
+        # Embed tokens, positions, and segment ids.
+        tokens = self.token_embedding(token_id_input)
+        segments = self.segment_embedding(segment_id_input)
+        # Sum, normalize and apply dropout to embeddings.
+        x = self.embeddings_add((tokens, segments))
+        x = self.embeddings_layer_norm(x)
+        x = self.embeddings_dropout(x)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, attention_mask=attention_mask)
+
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "segment_ids": segment_id_input,
+            },
+            outputs=x,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.num_segments = num_segments
+        self.max_wavelength = max_wavelength
+        self.head_size = head_size
+        self.dropout = dropout
+        self.activation = activations.get(activation)
+        self.use_bias = use_bias
+        self.start_token_index = 0
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "num_segments": self.num_segments,
+                "max_wavelength": self.max_wavelength,
+                "head_size": self.head_size,
+                "use_bias": self.use_bias,
+                "activation": activations.serialize(self.activation),
+            }
+        )
+        return config
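One detail worth noting from the functional graph above: the backbone derives its attention mask from token id 0, rather than consuming the `padding_mask` key shown in the docstring example. A small illustrative sketch (not from the package):

```python
from keras import ops

token_ids = ops.convert_to_tensor([[5, 9, 12, 0, 0]])
# Positions holding the padding id 0 are excluded from attention.
mask = ops.not_equal(token_ids, 0)  # [[True, True, True, False, False]]
```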