keras-hub-nightly 0.23.0.dev202510080414__py3-none-any.whl → 0.24.0.dev202511080419__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +6 -0
- keras_hub/models/__init__.py +36 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +6 -0
- keras_hub/src/models/causal_lm.py +5 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +38 -1
- keras_hub/src/models/dinov2/dinov2_layers.py +3 -1
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_presets.py +22 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/utils/preset_utils.py +9 -2
- keras_hub/src/utils/tensor_utils.py +3 -1
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/RECORD +52 -24
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/top_level.txt +0 -0
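The listing above adds three new model families (DINOv3, MobileNetV5, SmolLM3) plus new presets for Gemma, Gemma 3, Qwen3-MoE, PARSeq, SigLIP, Stable Diffusion 3, and Depth Anything. As rough orientation only: architectures registered in keras_hub/models/__init__.py are normally reachable through the library's from_preset API, and the new convert_smollm3.py converter suggests SmolLM3 checkpoints can be loaded from the Hugging Face Hub. The snippet below is a hedged sketch of that flow; the hf:// handle and the dtype are illustrative assumptions, not presets shipped in this diff.

import keras_hub

# Hypothetical usage sketch. The Hugging Face handle below is an assumption
# for illustration; it is not a preset name taken from this diff.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "hf://HuggingFaceTB/SmolLM3-3B",
    dtype="bfloat16",
)
causal_lm.generate("The capital of France is", max_length=32)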
keras_hub/src/models/smollm3/smollm3_layers.py (new file, +757 lines)

@@ -0,0 +1,757 @@
+import math
+
+from keras import activations
+from keras import initializers
+from keras import layers
+from keras import ops
+
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
+from keras_hub.src.models.smollm3.smollm3_utils import rope_init
+
+
+class SmolLM3Attention(layers.Layer):
+    """Multi-head attention layer for SmolLM3 model.
+
+    Args:
+        hidden_size: int. The hidden size of the attention layer.
+        num_attention_heads: int. The number of attention heads.
+        num_key_value_heads: int. The number of key-value heads.
+        attention_bias: bool. Whether to use bias in attention projections.
+        attention_dropout: float. Dropout rate for attention weights.
+        rope_layer_enabled_list: list of bool. List indicating if RoPE is
+            enabled for each layer.
+        layer_types: list of str. List of layer types.
+        layer_idx: int. Index of the current layer.
+        max_position_embeddings: int. Maximum sequence length for position
+            embeddings. Defaults to 2048.
+        rope_theta: float. The theta value for RoPE. Defaults to 10000.0.
+        partial_rotary_factor: float. The factor for partial rotary embedding.
+            Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        num_key_value_heads,
+        attention_bias,
+        attention_dropout,
+        rope_layer_enabled_list,
+        layer_types,
+        layer_idx,
+        max_position_embeddings=2048,
+        rope_theta=10000.0,
+        partial_rotary_factor=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.rope_layer_enabled_list = rope_layer_enabled_list
+        self.layer_types = layer_types
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.head_dim = hidden_size // self.num_attention_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+
+        self.layer_idx = layer_idx
+        self.num_key_value_groups = (
+            self.num_attention_heads // self.num_key_value_heads
+        )
+        self.scaling = self.head_dim**-0.5
+        self.is_causal = True
+
+        self.q_proj = layers.Dense(
+            self.num_attention_heads * self.head_dim,
+            use_bias=self.attention_bias,
+            name="q_proj",
+        )
+        self.k_proj = layers.Dense(
+            self.num_key_value_heads * self.head_dim,
+            use_bias=self.attention_bias,
+            name="k_proj",
+        )
+        self.v_proj = layers.Dense(
+            self.num_key_value_heads * self.head_dim,
+            use_bias=self.attention_bias,
+            name="v_proj",
+        )
+        self.o_proj = layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, self.hidden_size),
+            name="o_proj",
+        )
+        self.o_proj.build((None, None, self.num_attention_heads, self.head_dim))
+
+        self.use_rope = (
+            self.rope_layer_enabled_list[self.layer_idx]
+            if self.layer_idx < len(self.rope_layer_enabled_list)
+            else True
+        )  # Default to True if index out of bounds
+
+        self.rotary_embedding = SmolLM3RotaryEmbedding(
+            hidden_size=self.hidden_size,
+            num_attention_heads=self.num_attention_heads,
+            max_position_embeddings=self.max_position_embeddings,
+            rope_theta=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            name="rotary_emb",
+        )
+
+        self._softmax = layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
+    def build(self, input_shape):
+        """Builds the internal Dense layers.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [hidden_states_shape, position_embeddings_shape_tuple,
+                attention_mask_shape]
+                - hidden_states_shape: (batch_size, seq_len,
+                  hidden_size)
+        """
+        # The input shape to the Dense layers (q_proj, k_proj, v_proj, o_proj)
+        # is the same as the hidden_states input to SmolLM3Attention.
+        hidden_states_shape = input_shape[0]
+        self.q_proj.build(hidden_states_shape)
+        self.k_proj.build(hidden_states_shape)
+        self.v_proj.build(hidden_states_shape)
+        super().build(input_shape)
+
+    def call(
+        self,
+        hidden_states,
+        training=False,
+        attention_mask=None,
+        **kwargs,
+    ):
+        """Forward pass for SmolLM3Attention.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_len,
+                hidden_size).
+            position_embeddings: Tuple of (cos, sin) tensors for RoPE.
+            attention_mask: Attention mask tensor.
+            training: Whether the layer is in training mode.
+        """
+        self.training = training
+        self_attention_cache = kwargs.get("self_attention_cache", None)
+        self_attention_cache_update_index = kwargs.get(
+            "self_attention_cache_update_index", None
+        )
+        start_index = (
+            self_attention_cache_update_index
+            if self_attention_cache_update_index is not None
+            else 0
+        )
+
+        input_shape = ops.shape(hidden_states)[:-1]
+        hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim)
+
+        query = ops.reshape(self.q_proj(hidden_states), hidden_shape)
+
+        def _compute_kv_values(x_input):
+            kv_hidden_shape = (
+                *input_shape,
+                self.num_key_value_heads,
+                self.head_dim,
+            )
+
+            key = ops.reshape(self.k_proj(x_input), kv_hidden_shape)
+            value = ops.reshape(self.v_proj(x_input), kv_hidden_shape)
+
+            return key, value
+
+        if self_attention_cache is not None:
+            key_cache = self_attention_cache[:, 0, ...]
+            value_cache = self_attention_cache[:, 1, ...]
+
+            if self_attention_cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_kv_values(hidden_states)
+
+                # Apply RoPE to key_update BEFORE caching
+                if self.use_rope:
+                    cos, sin = self.rotary_embedding(
+                        query, start_index=start_index
+                    )
+                    query_rope, key_update = apply_rotary_pos_emb(
+                        query, key_update, cos, sin, expansion_axis=2
+                    )
+                    query = query_rope
+
+                start = (0, self_attention_cache_update_index, 0, 0)
+
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                self_attention_cache = ops.stack((key, value), axis=1)
+        else:
+            if self_attention_cache_update_index is not None:
+                raise ValueError(
+                    "`self_attention_cache_update_index` should not be set "
+                    "if `self_attention_cache` is `None`. Received: "
+                    f"self_attention_cache={self_attention_cache}, "
+                    "self_attention_cache_update_index="
+                    f"{self_attention_cache_update_index}"
+                )
+            key, value = _compute_kv_values(hidden_states)
+
+            # Apply RoPE when not using cache
+            if self.use_rope:
+                cos, sin = self.rotary_embedding(query, start_index=start_index)
+                query, key = apply_rotary_pos_emb(
+                    query, key, cos, sin, expansion_axis=2
+                )
+
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attn_output = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            cache_update_index=self_attention_cache_update_index,
+        )
+
+        attn_output = self.o_proj(attn_output)
+
+        if self_attention_cache is not None:
+            return attn_output, self_attention_cache
+
+        return attn_output
+
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [hidden_states_shape, position_embeddings_shape_tuple,
+                attention_mask_shape]
+                - hidden_states_shape: (batch_size, seq_len,
+                  hidden_size)
+                - position_embeddings_shape_tuple: (cos_shape,
+                  sin_shape) where cos_shape/sin_shape is
+                  (batch_size, seq_len, head_dim)
+                - attention_mask_shape: (batch_size, 1, seq_len,
+                  seq_len)
+
+        Returns:
+            A list of output shapes: [output_attn_output_shape,
+            output_attn_weights_shape]
+        """
+        hidden_states_shape = input_shape[0]
+
+        batch_size = hidden_states_shape[0]
+        seq_len = hidden_states_shape[1]
+
+        output_attn_output_shape = (batch_size, seq_len, self.hidden_size)
+
+        output_attn_weights_shape = (
+            batch_size,
+            self.num_attention_heads,
+            seq_len,
+            seq_len,
+        )
+
+        return [output_attn_output_shape, output_attn_weights_shape]
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        """Applies softmax with optional masking.
+
+        Args:
+            attention_scores: Attention score tensor.
+            attention_mask: Optional mask tensor.
+
+        Returns:
+            Masked softmax attention weights.
+        """
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, cache_update_index=None
+    ):
+        """Computes attention using query, key, and value tensors.
+
+        Uses Flash Attention when available for better performance.
+
+        Args:
+            query: Query tensor.
+            key: Key tensor.
+            value: Value tensor.
+            attention_mask: Optional mask tensor.
+            cache_update_index: Index for sliding window computation.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+        """
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+
+class SmolLM3MLP(layers.Layer):
+    """Multi-layer perceptron (MLP) block for SmolLM3 model.
+
+    Args:
+        hidden_size: int. The hidden size of the MLP.
+        intermediate_size: int. The intermediate size of the MLP.
+        mlp_bias: bool. Whether to use bias in MLP dense layers.
+    """
+
+    def __init__(self, hidden_size, intermediate_size, mlp_bias, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.mlp_bias = mlp_bias
+
+        self.gate_proj = layers.Dense(
+            self.intermediate_size, use_bias=self.mlp_bias, name="gate_proj"
+        )
+        self.up_proj = layers.Dense(
+            self.intermediate_size, use_bias=self.mlp_bias, name="up_proj"
+        )
+        self.down_proj = layers.Dense(
+            self.hidden_size, use_bias=self.mlp_bias, name="down_proj"
+        )
+
+    def build(self, input_shape):
+        """
+        Builds the internal Dense layers.
+        Args:
+            input_shape: The shape of the input to this layer
+                (batch_size, seq_len, hidden_size).
+        """
+        self.gate_proj.build(input_shape)
+        self.up_proj.build(input_shape)
+        # The down_proj takes intermediate_output, which has shape
+        # (batch_size, seq_len, intermediate_size)
+        down_proj_input_shape = (
+            input_shape[0],
+            input_shape[1],
+            self.intermediate_size,
+        )
+        self.down_proj.build(down_proj_input_shape)
+        super().build(input_shape)
+
+    def call(self, x):
+        """
+        Forward pass for SmolLM3MLP.
+
+        Args:
+            x: Input tensor of shape (batch_size, seq_len, hidden_size).
+        """
+        gate_output = activations.silu(self.gate_proj(x))
+        up_output = self.up_proj(x)
+        intermediate_output = gate_output * up_output
+        down_proj_output = self.down_proj(intermediate_output)
+        return down_proj_output
+
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+        Returns:
+            The output shape, which is the same as the input shape:
+            (batch_size, seq_len, hidden_size).
+        """
+        return input_shape
+
+
+class SmolLM3DecoderLayer(layers.Layer):
+    """Decoder layer for SmolLM3 model, combining self-attention and MLP.
+
+    Args:
+        hidden_size: int. The hidden size of the layer.
+        num_attention_heads: int. The number of attention heads.
+        num_key_value_heads: int. The number of key-value heads.
+        attention_bias: bool. Whether to use bias in attention projections.
+        attention_dropout: float. Dropout rate for attention weights.
+        rope_layer_enabled_list: list of bool. List indicating if RoPE is
+            enabled for each layer.
+        layer_types: list of str. List of layer types.
+        layer_idx: int. Index of the current layer.
+        intermediate_size: int. The intermediate size of the MLP.
+        mlp_bias: bool. Whether to use bias in MLP dense layers.
+        layer_norm_epsilon: float. Epsilon for RMSNormalization.
+        max_position_embeddings: int. Maximum sequence length for position
+            embeddings. Defaults to 2048.
+        rope_theta: float. The theta value for RoPE. Defaults to 10000.0.
+        partial_rotary_factor: float. The factor for partial rotary embedding.
+            Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        num_key_value_heads,
+        attention_bias,
+        attention_dropout,
+        rope_layer_enabled_list,
+        layer_types,
+        layer_idx,
+        intermediate_size,
+        mlp_bias,
+        layer_norm_epsilon,
+        max_position_embeddings=2048,
+        rope_theta=10000.0,
+        partial_rotary_factor=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.layer_idx = layer_idx
+
+        self.self_attn = SmolLM3Attention(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            attention_bias=attention_bias,
+            attention_dropout=attention_dropout,
+            rope_layer_enabled_list=rope_layer_enabled_list,
+            layer_types=layer_types,
+            layer_idx=layer_idx,
+            max_position_embeddings=max_position_embeddings,
+            rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
+            name="self_attn",
+        )
+
+        self.mlp = SmolLM3MLP(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            mlp_bias=mlp_bias,
+            name="mlp",
+        )
+
+        self.input_layernorm = layers.RMSNormalization(
+            epsilon=layer_norm_epsilon, axis=-1, name="input_layernorm"
+        )
+        self.post_attention_layernorm = layers.RMSNormalization(
+            epsilon=layer_norm_epsilon, axis=-1, name="post_attention_layernorm"
+        )
+
+        self.attention_type = layer_types[layer_idx]
+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        self_attention_cache,
+        self_attention_cache_update_index,
+    ):
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        batch_size = ops.shape(decoder_sequence)[0]
+        input_length = output_length = ops.shape(decoder_sequence)[1]
+        # We need to handle a rectangular causal mask when doing cached
+        # decoding. For generative inference, `decoder_sequence` will
+        # generally be length 1, and `cache` will be the full generation length.
+        if self_attention_cache is not None:
+            input_length = ops.shape(self_attention_cache)[2]
+
+        cache_update_index = (
+            0
+            if self_attention_cache_update_index is None
+            else self_attention_cache_update_index
+        )
+
+        causal_mask = compute_causal_mask(
+            batch_size, input_length, output_length, cache_update_index
+        )
+
+        return (
+            ops.minimum(decoder_mask, causal_mask)
+            if decoder_mask is not None
+            else causal_mask
+        )
+
+    def build(self, input_shape):
+        """
+        Builds the sub-layers based on the input shape.
+
+        Args:
+            input_shape: The input shape to the decoder layer
+                (batch_size, seq_len, hidden_size).
+        """
+        # input_shape for SmolLM3DecoderLayer: (batch_size, seq_len,
+        # hidden_size)
+        batch_size = input_shape[0]
+        seq_len = input_shape[1]
+
+        head_dim = self.self_attn.head_dim
+        pos_emb_shape = (batch_size, seq_len, head_dim)
+
+        attn_mask_shape = (batch_size, 1, seq_len, seq_len)
+
+        # Pass the correct input shape to self_attn's build method
+        # The input_shape for self_attn.build is a list:
+        # [hidden_states_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
+        self.self_attn.build(
+            [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
+        )
+
+        self.mlp.build(input_shape)
+        self.input_layernorm.build(input_shape)
+        self.post_attention_layernorm.build(input_shape)
+
+        super().build(input_shape)
+
+    def call(
+        self,
+        hidden_states,
+        training=False,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
+        **kwargs,
+    ):
+        """
+        Forward pass for SmolLM3DecoderLayer.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size,
+                seq_len, hidden_size).
+            position_embeddings: Optional tuple of (cos, sin)
+                tensors for RoPE.
+            training: Whether the layer is in training mode.
+        """
+        self_attention_cache = kwargs.get("self_attention_cache", None)
+        self_attention_cache_update_index = kwargs.get(
+            "self_attention_cache_update_index", None
+        )
+
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=hidden_states,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )
+
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        x = self.self_attn(
+            hidden_states=hidden_states,
+            training=training,
+            attention_mask=self_attention_mask,
+            **kwargs,
+        )
+
+        if isinstance(x, tuple):
+            attn_output, self_attention_cache = x
+        else:
+            attn_output = x
+
+        hidden_states = ops.add(residual, attn_output)
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = ops.add(residual, hidden_states)
+
+        if self_attention_cache is not None:
+            return hidden_states, self_attention_cache
+        else:
+            return hidden_states
+
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+        Returns:
+            The output shape, which is the same as the input shape:
+            (batch_size, seq_len, hidden_size).
+        """
+        return input_shape
+
+
+class SmolLM3RotaryEmbedding(layers.Layer):
+    """Rotary Position Embedding (RoPE) layer for SmolLM3 model.
+
+    Args:
+        hidden_size: int. The hidden size of the model.
+        num_attention_heads: int. The number of attention heads.
+        max_position_embeddings: int. The maximum sequence length for position
+            embeddings.
+        rope_theta: float. The theta value for RoPE.
+        partial_rotary_factor: float. The factor for partial rotary embedding.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        max_position_embeddings,
+        rope_theta,
+        partial_rotary_factor,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+        inv_freq_tensor, self.attention_scaling = rope_init(
+            self.rope_theta, self.partial_rotary_factor, self.head_dim
+        )
+
+        self.inv_freq = self.add_weight(
+            name="inv_freq",
+            shape=ops.shape(inv_freq_tensor),
+            dtype=inv_freq_tensor.dtype,
+            initializer=initializers.Constant(
+                ops.convert_to_numpy(inv_freq_tensor)
+            ),
+            trainable=False,  # This weight is not trained
+        )
+        self.original_inv_freq = self.inv_freq
+
+    def build(self, input_shape):
+        """
+        Builds the layer. For SmolLM3RotaryEmbedding, this mainly
+        ensures that the parent layer's build is called.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [x_shape, position_ids_shape]
+                - x_shape: (batch_size, ..., head_dim)
+                - position_ids_shape: (batch_size, seq_len)
+        """
+        # No internal layers to explicitly build here, as inv_freq is
+        # added in __init__
+        super().build(input_shape)
+
+    def call(
+        self,
+        x,
+        start_index=0,
+    ):
+        """
+        Forward pass for SmolLM3RotaryEmbedding.
+
+        Args:
+            x: Input tensor, typically query or key states.
+                Shape can vary, but the last dimension is head_dim.
+            position_ids: Tensor of position IDs of shape (batch_size, seq_len).
+        """
+        batch_size = ops.shape(x)[0]
+        seq_len = ops.shape(x)[1]
+        positions = ops.arange(seq_len, dtype="float32")
+        positions = positions + ops.cast(start_index, dtype="float32")
+
+        # inv_freq: (inv_freq_dim,) -> (1, inv_freq_dim, 1)
+        # -> (batch, inv_freq_dim, 1)
+        inv_freq_expanded = ops.expand_dims(
+            ops.expand_dims(self.inv_freq, axis=0), axis=-1
+        )
+        inv_freq_expanded = ops.broadcast_to(
+            inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
+        )
+
+        # positions: (seq_len,) -> (1, 1, seq_len)
+        # -> (batch, 1, seq_len)
+        position_ids_expanded = ops.expand_dims(
+            ops.expand_dims(positions, axis=0), axis=0
+        )
+        position_ids_expanded = ops.broadcast_to(
+            position_ids_expanded, (batch_size, 1, seq_len)
+        )
+
+        # matmul: (batch, inv_freq_dim, 1) @ (batch, 1, seq_len)
+        # -> (batch, inv_freq_dim, seq_len)
+        freqs = ops.matmul(
+            ops.cast(inv_freq_expanded, "float32"),
+            ops.cast(position_ids_expanded, "float32"),
+        )
+
+        # transpose: (batch, inv_freq_dim, seq_len) ->
+        # (batch, seq_len, inv_freq_dim)
+        freqs = ops.transpose(freqs, axes=(0, 2, 1))
+
+        emb = ops.concatenate((freqs, freqs), axis=-1)
+
+        cos = ops.cos(emb) * self.attention_scaling
+        sin = ops.sin(emb) * self.attention_scaling
+
+        return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype)
+
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [x_shape, position_ids_shape]
+                - x_shape: (batch_size, ..., head_dim)
+                - position_ids_shape: (batch_size, seq_len)
+
+        Returns:
+            A list of output shapes for (cos, sin):
+            [(batch_size, seq_len, head_dim), (batch_size, seq_len, head_dim)]
+        """
+        if input_shape[1] is not None and len(input_shape[1]) >= 2:
+            batch_size = input_shape[1][0]
+            seq_len = input_shape[1][1]
+        else:
+            # Fallback if position_ids_shape is None or malformed.
+            # In this case, the batch_size and seq_len are unknown.
+            batch_size = None
+            seq_len = None
+
+        # The output cos and sin have shape (batch_size, seq_len, head_dim)
+        output_shape = (batch_size, seq_len, self.head_dim)
+
+        return [output_shape, output_shape]