keras-hub-nightly 0.23.0.dev202510160419__py3-none-any.whl → 0.23.0.dev202510180414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-hub-nightly might be problematic.

@@ -0,0 +1,757 @@
+ import math
+
+ from keras import activations
+ from keras import initializers
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
+ from keras_hub.src.models.smollm3.smollm3_utils import rope_init
+
+
+ class SmolLM3Attention(layers.Layer):
+     """Multi-head attention layer for SmolLM3 model.
+
+     Args:
+         hidden_size: int. The hidden size of the attention layer.
+         num_attention_heads: int. The number of attention heads.
+         num_key_value_heads: int. The number of key-value heads.
+         attention_bias: bool. Whether to use bias in attention projections.
+         attention_dropout: float. Dropout rate for attention weights.
+         rope_layer_enabled_list: list of bool. List indicating if RoPE is
+             enabled for each layer.
+         layer_types: list of str. List of layer types.
+         layer_idx: int. Index of the current layer.
+         max_position_embeddings: int. Maximum sequence length for position
+             embeddings. Defaults to 2048.
+         rope_theta: float. The theta value for RoPE. Defaults to 10000.0.
+         partial_rotary_factor: float. The factor for partial rotary embedding.
+             Defaults to 1.0.
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         num_attention_heads,
+         num_key_value_heads,
+         attention_bias,
+         attention_dropout,
+         rope_layer_enabled_list,
+         layer_types,
+         layer_idx,
+         max_position_embeddings=2048,
+         rope_theta=10000.0,
+         partial_rotary_factor=1.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.rope_layer_enabled_list = rope_layer_enabled_list
+         self.layer_types = layer_types
+         self.max_position_embeddings = max_position_embeddings
+         self.rope_theta = rope_theta
+         self.partial_rotary_factor = partial_rotary_factor
+
+         self._dot_product_equation = "bquh,bkuh->buqk"
+         self._combine_equation = "buqk,bkuh->bquh"
+
+         self.head_dim = hidden_size // self.num_attention_heads
+         self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+
+         self.layer_idx = layer_idx
+         self.num_key_value_groups = (
+             self.num_attention_heads // self.num_key_value_heads
+         )
+         self.scaling = self.head_dim**-0.5
+         self.is_causal = True
+
+         self.q_proj = layers.Dense(
+             self.num_attention_heads * self.head_dim,
+             use_bias=self.attention_bias,
+             name="q_proj",
+         )
+         self.k_proj = layers.Dense(
+             self.num_key_value_heads * self.head_dim,
+             use_bias=self.attention_bias,
+             name="k_proj",
+         )
+         self.v_proj = layers.Dense(
+             self.num_key_value_heads * self.head_dim,
+             use_bias=self.attention_bias,
+             name="v_proj",
+         )
+         self.o_proj = layers.EinsumDense(
+             equation="bquh,uhm->bqm",
+             output_shape=(None, self.hidden_size),
+             name="o_proj",
+         )
+         self.o_proj.build((None, None, self.num_attention_heads, self.head_dim))
+
+         self.use_rope = (
+             self.rope_layer_enabled_list[self.layer_idx]
+             if self.layer_idx < len(self.rope_layer_enabled_list)
+             else True
+         )  # Default to True if index out of bounds
+
+         self.rotary_embedding = SmolLM3RotaryEmbedding(
+             hidden_size=self.hidden_size,
+             num_attention_heads=self.num_attention_heads,
+             max_position_embeddings=self.max_position_embeddings,
+             rope_theta=self.rope_theta,
+             partial_rotary_factor=self.partial_rotary_factor,
+             name="rotary_emb",
+         )
+
+         self._softmax = layers.Softmax(
+             axis=-1,
+             dtype="float32",
+             name="attention_softmax",
+         )
+
+     def build(self, input_shape):
+         """Builds the internal Dense layers.
+
+         Args:
+             input_shape: A list/tuple of shapes for the inputs:
+                 [hidden_states_shape, position_embeddings_shape_tuple,
+                 attention_mask_shape]
+                 - hidden_states_shape: (batch_size, seq_len,
+                   hidden_size)
+         """
+         # The input shape to the Dense layers (q_proj, k_proj, v_proj, o_proj)
+         # is the same as the hidden_states input to SmolLM3Attention.
+         hidden_states_shape = input_shape[0]
+         self.q_proj.build(hidden_states_shape)
+         self.k_proj.build(hidden_states_shape)
+         self.v_proj.build(hidden_states_shape)
+         super().build(input_shape)
+
+     def call(
+         self,
+         hidden_states,
+         training=False,
+         attention_mask=None,
+         **kwargs,
+     ):
+         """Forward pass for SmolLM3Attention.
+
+         Rotary position embeddings are computed internally by the layer's
+         `rotary_embedding` sub-layer when RoPE is enabled for this layer.
+
+         Args:
+             hidden_states: Input tensor of shape (batch_size, seq_len,
+                 hidden_size).
+             training: Whether the layer is in training mode.
+             attention_mask: Optional attention mask tensor.
+             **kwargs: Optionally, `self_attention_cache` and
+                 `self_attention_cache_update_index` for cached decoding.
+         """
+         self.training = training
+         self_attention_cache = kwargs.get("self_attention_cache", None)
+         self_attention_cache_update_index = kwargs.get(
+             "self_attention_cache_update_index", None
+         )
+         start_index = (
+             self_attention_cache_update_index
+             if self_attention_cache_update_index is not None
+             else 0
+         )
+
+         input_shape = ops.shape(hidden_states)[:-1]
+         hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim)
+
+         query = ops.reshape(self.q_proj(hidden_states), hidden_shape)
+
+         def _compute_kv_values(x_input):
+             kv_hidden_shape = (
+                 *input_shape,
+                 self.num_key_value_heads,
+                 self.head_dim,
+             )
+
+             key = ops.reshape(self.k_proj(x_input), kv_hidden_shape)
+             value = ops.reshape(self.v_proj(x_input), kv_hidden_shape)
+
+             return key, value
+
+         if self_attention_cache is not None:
+             key_cache = self_attention_cache[:, 0, ...]
+             value_cache = self_attention_cache[:, 1, ...]
+
+             if self_attention_cache_update_index is None:
+                 key = key_cache
+                 value = value_cache
+             else:
+                 key_update, value_update = _compute_kv_values(hidden_states)
+
+                 # Apply RoPE to key_update BEFORE caching
+                 if self.use_rope:
+                     cos, sin = self.rotary_embedding(
+                         query, start_index=start_index
+                     )
+                     query_rope, key_update = apply_rotary_pos_emb(
+                         query, key_update, cos, sin, expansion_axis=2
+                     )
+                     query = query_rope
+
+                 start = (0, self_attention_cache_update_index, 0, 0)
+
+                 key = ops.slice_update(key_cache, start, key_update)
+                 value = ops.slice_update(value_cache, start, value_update)
+                 self_attention_cache = ops.stack((key, value), axis=1)
+         else:
+             if self_attention_cache_update_index is not None:
+                 raise ValueError(
+                     "`self_attention_cache_update_index` should not be set "
+                     "if `self_attention_cache` is `None`. Received: "
+                     f"self_attention_cache={self_attention_cache}, "
+                     "self_attention_cache_update_index="
+                     f"{self_attention_cache_update_index}"
+                 )
+             key, value = _compute_kv_values(hidden_states)
+
+             # Apply RoPE when not using cache
+             if self.use_rope:
+                 cos, sin = self.rotary_embedding(query, start_index=start_index)
+                 query, key = apply_rotary_pos_emb(
+                     query, key, cos, sin, expansion_axis=2
+                 )
+
+         key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+         value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+         attn_output = self._compute_attention(
+             query,
+             key,
+             value,
+             attention_mask,
+             cache_update_index=self_attention_cache_update_index,
+         )
+
+         attn_output = self.o_proj(attn_output)
+
+         if self_attention_cache is not None:
+             return attn_output, self_attention_cache
+
+         return attn_output
+
+     def compute_output_shape(self, input_shape):
+         """
+         Computes the output shape of the layer.
+
+         Args:
+             input_shape: A list/tuple of shapes for the inputs:
+                 [hidden_states_shape, position_embeddings_shape_tuple,
+                 attention_mask_shape]
+                 - hidden_states_shape: (batch_size, seq_len,
+                   hidden_size)
+                 - position_embeddings_shape_tuple: (cos_shape,
+                   sin_shape) where cos_shape/sin_shape is
+                   (batch_size, seq_len, head_dim)
+                 - attention_mask_shape: (batch_size, 1, seq_len,
+                   seq_len)
+
+         Returns:
+             A list of output shapes: [output_attn_output_shape,
+             output_attn_weights_shape]
+         """
+         hidden_states_shape = input_shape[0]
+
+         batch_size = hidden_states_shape[0]
+         seq_len = hidden_states_shape[1]
+
+         output_attn_output_shape = (batch_size, seq_len, self.hidden_size)
+
+         output_attn_weights_shape = (
+             batch_size,
+             self.num_attention_heads,
+             seq_len,
+             seq_len,
+         )
+
+         return [output_attn_output_shape, output_attn_weights_shape]
+
+     def _masked_softmax(self, attention_scores, attention_mask=None):
+         """Applies softmax with optional masking.
+
+         Args:
+             attention_scores: Attention score tensor.
+             attention_mask: Optional mask tensor.
+
+         Returns:
+             Masked softmax attention weights.
+         """
+         if attention_mask is not None:
+             return self._softmax(
+                 attention_scores, attention_mask[:, None, :, :]
+             )
+         return self._softmax(attention_scores)
+
+     def _compute_attention(
+         self, query, key, value, attention_mask=None, cache_update_index=None
+     ):
+         """Computes attention using query, key, and value tensors.
+
+         Uses an einsum-based scaled dot-product attention implementation.
+
+         Args:
+             query: Query tensor.
+             key: Key tensor.
+             value: Value tensor.
+             attention_mask: Optional mask tensor.
+             cache_update_index: Optional cache update index. Accepted for
+                 interface compatibility; not used in the computation itself.
+
+         Returns:
+             attention_output: Output tensor after applying attention.
+         """
+         attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+         attention_scores = ops.multiply(
+             attention_scores,
+             ops.cast(self._inv_norm_factor, self.compute_dtype),
+         )
+         attention_scores = self._masked_softmax(
+             attention_scores, attention_mask
+         )
+         attention_scores = ops.cast(attention_scores, self.compute_dtype)
+         attention_output = ops.einsum(
+             self._combine_equation, attention_scores, value
+         )
+
+         return attention_output
+
+
+ class SmolLM3MLP(layers.Layer):
+     """Multi-layer perceptron (MLP) block for SmolLM3 model.
+
+     Args:
+         hidden_size: int. The hidden size of the MLP.
+         intermediate_size: int. The intermediate size of the MLP.
+         mlp_bias: bool. Whether to use bias in MLP dense layers.
+     """
+
+     def __init__(self, hidden_size, intermediate_size, mlp_bias, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.mlp_bias = mlp_bias
+
+         self.gate_proj = layers.Dense(
+             self.intermediate_size, use_bias=self.mlp_bias, name="gate_proj"
+         )
+         self.up_proj = layers.Dense(
+             self.intermediate_size, use_bias=self.mlp_bias, name="up_proj"
+         )
+         self.down_proj = layers.Dense(
+             self.hidden_size, use_bias=self.mlp_bias, name="down_proj"
+         )
+
+     def build(self, input_shape):
+         """
+         Builds the internal Dense layers.
+
+         Args:
+             input_shape: The shape of the input to this layer
+                 (batch_size, seq_len, hidden_size).
+         """
+         self.gate_proj.build(input_shape)
+         self.up_proj.build(input_shape)
+         # The down_proj takes intermediate_output, which has shape
+         # (batch_size, seq_len, intermediate_size)
+         down_proj_input_shape = (
+             input_shape[0],
+             input_shape[1],
+             self.intermediate_size,
+         )
+         self.down_proj.build(down_proj_input_shape)
+         super().build(input_shape)
+
+     def call(self, x):
+         """
+         Forward pass for SmolLM3MLP.
+
+         Args:
+             x: Input tensor of shape (batch_size, seq_len, hidden_size).
+         """
+         gate_output = activations.silu(self.gate_proj(x))
+         up_output = self.up_proj(x)
+         intermediate_output = gate_output * up_output
+         down_proj_output = self.down_proj(intermediate_output)
+         return down_proj_output
+
+     def compute_output_shape(self, input_shape):
+         """
+         Computes the output shape of the layer.
+
+         Args:
+             input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+         Returns:
+             The output shape, which is the same as the input shape:
+             (batch_size, seq_len, hidden_size).
+         """
+         return input_shape
+
+
+ class SmolLM3DecoderLayer(layers.Layer):
+     """Decoder layer for SmolLM3 model, combining self-attention and MLP.
+
+     Args:
+         hidden_size: int. The hidden size of the layer.
+         num_attention_heads: int. The number of attention heads.
+         num_key_value_heads: int. The number of key-value heads.
+         attention_bias: bool. Whether to use bias in attention projections.
+         attention_dropout: float. Dropout rate for attention weights.
+         rope_layer_enabled_list: list of bool. List indicating if RoPE is
+             enabled for each layer.
+         layer_types: list of str. List of layer types.
+         layer_idx: int. Index of the current layer.
+         intermediate_size: int. The intermediate size of the MLP.
+         mlp_bias: bool. Whether to use bias in MLP dense layers.
+         layer_norm_epsilon: float. Epsilon for RMSNormalization.
+         max_position_embeddings: int. Maximum sequence length for position
+             embeddings. Defaults to 2048.
+         rope_theta: float. The theta value for RoPE. Defaults to 10000.0.
+         partial_rotary_factor: float. The factor for partial rotary embedding.
+             Defaults to 1.0.
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         num_attention_heads,
+         num_key_value_heads,
+         attention_bias,
+         attention_dropout,
+         rope_layer_enabled_list,
+         layer_types,
+         layer_idx,
+         intermediate_size,
+         mlp_bias,
+         layer_norm_epsilon,
+         max_position_embeddings=2048,
+         rope_theta=10000.0,
+         partial_rotary_factor=1.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.layer_idx = layer_idx
+
+         self.self_attn = SmolLM3Attention(
+             hidden_size=hidden_size,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             attention_bias=attention_bias,
+             attention_dropout=attention_dropout,
+             rope_layer_enabled_list=rope_layer_enabled_list,
+             layer_types=layer_types,
+             layer_idx=layer_idx,
+             max_position_embeddings=max_position_embeddings,
+             rope_theta=rope_theta,
+             partial_rotary_factor=partial_rotary_factor,
+             name="self_attn",
+         )
+
+         self.mlp = SmolLM3MLP(
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             mlp_bias=mlp_bias,
+             name="mlp",
+         )
+
+         self.input_layernorm = layers.RMSNormalization(
+             epsilon=layer_norm_epsilon, axis=-1, name="input_layernorm"
+         )
+         self.post_attention_layernorm = layers.RMSNormalization(
+             epsilon=layer_norm_epsilon, axis=-1, name="post_attention_layernorm"
+         )
+
+         self.attention_type = layer_types[layer_idx]
+
+     def _compute_self_attention_mask(
+         self,
+         decoder_sequence,
+         decoder_padding_mask,
+         decoder_attention_mask,
+         self_attention_cache,
+         self_attention_cache_update_index,
+     ):
+         decoder_mask = merge_padding_and_attention_mask(
+             decoder_sequence, decoder_padding_mask, decoder_attention_mask
+         )
+         batch_size = ops.shape(decoder_sequence)[0]
+         input_length = output_length = ops.shape(decoder_sequence)[1]
+         # We need to handle a rectangular causal mask when doing cached
+         # decoding. For generative inference, `decoder_sequence` will
+         # generally be length 1, and `cache` will be the full generation length.
+         if self_attention_cache is not None:
+             input_length = ops.shape(self_attention_cache)[2]
+
+         cache_update_index = (
+             0
+             if self_attention_cache_update_index is None
+             else self_attention_cache_update_index
+         )
+
+         causal_mask = compute_causal_mask(
+             batch_size, input_length, output_length, cache_update_index
+         )
+
+         return (
+             ops.minimum(decoder_mask, causal_mask)
+             if decoder_mask is not None
+             else causal_mask
+         )
+
+     def build(self, input_shape):
+         """
+         Builds the sub-layers based on the input shape.
+
+         Args:
+             input_shape: The input shape to the decoder layer
+                 (batch_size, seq_len, hidden_size).
+         """
+         # input_shape for SmolLM3DecoderLayer: (batch_size, seq_len,
+         # hidden_size)
+         batch_size = input_shape[0]
+         seq_len = input_shape[1]
+
+         head_dim = self.self_attn.head_dim
+         pos_emb_shape = (batch_size, seq_len, head_dim)
+
+         attn_mask_shape = (batch_size, 1, seq_len, seq_len)
+
+         # Pass the correct input shape to self_attn's build method
+         # The input_shape for self_attn.build is a list:
+         # [hidden_states_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
+         self.self_attn.build(
+             [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
+         )
+
+         self.mlp.build(input_shape)
+         self.input_layernorm.build(input_shape)
+         self.post_attention_layernorm.build(input_shape)
+
+         super().build(input_shape)
+
+     def call(
+         self,
+         hidden_states,
+         training=False,
+         decoder_padding_mask=None,
+         decoder_attention_mask=None,
+         **kwargs,
+     ):
+         """
+         Forward pass for SmolLM3DecoderLayer.
+
+         Args:
+             hidden_states: Input tensor of shape (batch_size,
+                 seq_len, hidden_size).
+             training: Whether the layer is in training mode.
+             decoder_padding_mask: Optional padding mask for the decoder
+                 sequence.
+             decoder_attention_mask: Optional attention mask for the decoder
+                 sequence.
+             **kwargs: Optionally, `self_attention_cache` and
+                 `self_attention_cache_update_index` for cached decoding.
+         """
+         self_attention_cache = kwargs.get("self_attention_cache", None)
+         self_attention_cache_update_index = kwargs.get(
+             "self_attention_cache_update_index", None
+         )
+
+         self_attention_mask = self._compute_self_attention_mask(
+             decoder_sequence=hidden_states,
+             decoder_padding_mask=decoder_padding_mask,
+             decoder_attention_mask=decoder_attention_mask,
+             self_attention_cache=self_attention_cache,
+             self_attention_cache_update_index=self_attention_cache_update_index,
+         )
+
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         x = self.self_attn(
+             hidden_states=hidden_states,
+             training=training,
+             attention_mask=self_attention_mask,
+             **kwargs,
+         )
+
+         if isinstance(x, tuple):
+             attn_output, self_attention_cache = x
+         else:
+             attn_output = x
+
+         hidden_states = ops.add(residual, attn_output)
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = ops.add(residual, hidden_states)
+
+         if self_attention_cache is not None:
+             return hidden_states, self_attention_cache
+         else:
+             return hidden_states
+
+     def compute_output_shape(self, input_shape):
+         """
+         Computes the output shape of the layer.
+
+         Args:
+             input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+         Returns:
+             The output shape, which is the same as the input shape:
+             (batch_size, seq_len, hidden_size).
+         """
+         return input_shape
+
+
+ class SmolLM3RotaryEmbedding(layers.Layer):
+     """Rotary Position Embedding (RoPE) layer for SmolLM3 model.
+
+     Args:
+         hidden_size: int. The hidden size of the model.
+         num_attention_heads: int. The number of attention heads.
+         max_position_embeddings: int. The maximum sequence length for position
+             embeddings.
+         rope_theta: float. The theta value for RoPE.
+         partial_rotary_factor: float. The factor for partial rotary embedding.
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         num_attention_heads,
+         max_position_embeddings,
+         rope_theta,
+         partial_rotary_factor,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.max_position_embeddings = max_position_embeddings
+         self.rope_theta = rope_theta
+         self.partial_rotary_factor = partial_rotary_factor
+
+         self.head_dim = self.hidden_size // self.num_attention_heads
+
+         inv_freq_tensor, self.attention_scaling = rope_init(
+             self.rope_theta, self.partial_rotary_factor, self.head_dim
+         )
+
+         self.inv_freq = self.add_weight(
+             name="inv_freq",
+             shape=ops.shape(inv_freq_tensor),
+             dtype=inv_freq_tensor.dtype,
+             initializer=initializers.Constant(
+                 ops.convert_to_numpy(inv_freq_tensor)
+             ),
+             trainable=False,  # This weight is not trained
+         )
+         self.original_inv_freq = self.inv_freq
+
+     def build(self, input_shape):
+         """
+         Builds the layer. For SmolLM3RotaryEmbedding, this mainly
+         ensures that the parent layer's build is called.
+
+         Args:
+             input_shape: A list/tuple of shapes for the inputs:
+                 [x_shape, position_ids_shape]
+                 - x_shape: (batch_size, ..., head_dim)
+                 - position_ids_shape: (batch_size, seq_len)
+         """
+         # No internal layers to explicitly build here, as inv_freq is
+         # added in __init__
+         super().build(input_shape)
+
+     def call(
+         self,
+         x,
+         start_index=0,
+     ):
+         """
+         Forward pass for SmolLM3RotaryEmbedding.
+
+         Args:
+             x: Input tensor, typically query or key states.
+                 Shape can vary, but the last dimension is head_dim.
+             start_index: int. Position offset of the first token in `x`,
+                 used during cached decoding. Defaults to 0.
+         """
+         batch_size = ops.shape(x)[0]
+         seq_len = ops.shape(x)[1]
+         positions = ops.arange(seq_len, dtype="float32")
+         positions = positions + ops.cast(start_index, dtype="float32")
+
+         # inv_freq: (inv_freq_dim,) -> (1, inv_freq_dim, 1)
+         # -> (batch, inv_freq_dim, 1)
+         inv_freq_expanded = ops.expand_dims(
+             ops.expand_dims(self.inv_freq, axis=0), axis=-1
+         )
+         inv_freq_expanded = ops.broadcast_to(
+             inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
+         )
+
+         # positions: (seq_len,) -> (1, 1, seq_len)
+         # -> (batch, 1, seq_len)
+         position_ids_expanded = ops.expand_dims(
+             ops.expand_dims(positions, axis=0), axis=0
+         )
+         position_ids_expanded = ops.broadcast_to(
+             position_ids_expanded, (batch_size, 1, seq_len)
+         )
+
+         # matmul: (batch, inv_freq_dim, 1) @ (batch, 1, seq_len)
+         # -> (batch, inv_freq_dim, seq_len)
+         freqs = ops.matmul(
+             ops.cast(inv_freq_expanded, "float32"),
+             ops.cast(position_ids_expanded, "float32"),
+         )
+
+         # transpose: (batch, inv_freq_dim, seq_len) ->
+         # (batch, seq_len, inv_freq_dim)
+         freqs = ops.transpose(freqs, axes=(0, 2, 1))
+
+         emb = ops.concatenate((freqs, freqs), axis=-1)
+
+         cos = ops.cos(emb) * self.attention_scaling
+         sin = ops.sin(emb) * self.attention_scaling
+
+         return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype)
+
+     def compute_output_shape(self, input_shape):
+         """
+         Computes the output shape of the layer.
+
+         Args:
+             input_shape: A list/tuple of shapes for the inputs:
+                 [x_shape, position_ids_shape]
+                 - x_shape: (batch_size, ..., head_dim)
+                 - position_ids_shape: (batch_size, seq_len)
+
+         Returns:
+             A list of output shapes for (cos, sin):
+             [(batch_size, seq_len, head_dim), (batch_size, seq_len, head_dim)]
+         """
+         if input_shape[1] is not None and len(input_shape[1]) >= 2:
+             batch_size = input_shape[1][0]
+             seq_len = input_shape[1][1]
+         else:
+             # Fallback if position_ids_shape is None or malformed.
+             # In this case, the batch_size and seq_len are unknown.
+             batch_size = None
+             seq_len = None
+
+         # The output cos and sin have shape (batch_size, seq_len, head_dim)
+         output_shape = (batch_size, seq_len, self.head_dim)
+
+         return [output_shape, output_shape]
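
For reference, the snippet below (not part of the diff) is a minimal sketch of how the added SmolLM3DecoderLayer could be exercised on its own. The configuration values are illustrative placeholders rather than a released SmolLM3 preset, and the snippet assumes the classes defined above are importable in the current scope.

import numpy as np

# Illustrative, hypothetical configuration (not a real SmolLM3 preset).
layer = SmolLM3DecoderLayer(
    hidden_size=64,
    num_attention_heads=8,
    num_key_value_heads=2,
    attention_bias=False,
    attention_dropout=0.0,
    rope_layer_enabled_list=[True],
    layer_types=["full_attention"],
    layer_idx=0,
    intermediate_size=128,
    mlp_bias=False,
    layer_norm_epsilon=1e-6,
)

# One forward pass over a random batch of shape (batch, seq_len, hidden_size).
hidden_states = np.random.rand(2, 10, 64).astype("float32")
outputs = layer(hidden_states)  # shape (2, 10, 64)

# When `self_attention_cache` and `self_attention_cache_update_index` are
# passed as keyword arguments, the layer instead returns
# (outputs, updated_cache), following the cached-decoding path above.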