sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3n_audio.py (new file)
@@ -0,0 +1,949 @@
+import math
+from typing import Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import Gemma3nAudioConfig, PreTrainedModel
+
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class Gemma3nCumulativeGroupNorm(nn.Module):
+    """Applies Group Normalization cumulatively over the time dimension.
+
+    This layer normalizes the input by calculating the mean and variance
+    cumulatively over the time dimension (dim 1). The statistics are computed
+    over all feature dimensions (specified by `feature_dims` and `num_channels`)
+    for elements marked as valid by the optional `mask`.
+
+    If a `mask` is provided (True for valid, False for invalid/padded),
+    invalid time steps do not contribute to the statistics calculation, and
+    their corresponding output values are zeroed out.
+
+    Scale and bias, if enabled, are applied per-channel (last dimension).
+    This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
+    and `cumulative=True`.
+    """
+
+    def __init__(
+        self,
+        num_channels: int,  # Number of channels (size of the last dimension)
+        feature_dims: Sequence[
+            int
+        ],  # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
+        eps: float = 1e-3,
+    ):
+        super().__init__()
+        self.num_channels = num_channels
+        self.feature_dims = tuple(feature_dims)
+        self.eps = eps
+
+        # Scale parameter depends only on the channel dimension
+        self.weight = nn.Parameter(torch.ones(num_channels))
+
+        # Axes for normalization: all dimensions except Batch (0) and Time (1).
+        # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
+        self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Applies cumulative group norm, optionally using a mask.
+
+        Args:
+            x: Input tensor, shape [B, T, *feature_dims, C].
+            mask: Optional boolean mask, shape [B, T]. True indicates a valid
+                (non-padded) time step. If None, all time steps are considered valid.
+
+        Returns:
+            Normalized tensor with the same shape as x.
+        """
+        expected_input_suffix = self.feature_dims + (self.num_channels,)
+        if x.shape[2:] != expected_input_suffix:
+            raise ValueError(
+                f"Input tensor shape suffix {x.shape[2:]} does not match expected"
+                f" suffix (feature_dims + num_channels) {expected_input_suffix}"
+            )
+
+        input_dtype = x.dtype
+        # Calculations are performed in float32 for numerical stability.
+        calc_dtype = torch.float32
+        x_calc = x.to(calc_dtype)
+
+        # Prepare a broadcastable mask (`mask_calc`).
+        # If no mask is provided, treat all elements as valid
+        # (mask_calc is all ones).
+        # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
+        mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
+
+        # Cumulative Statistics Calculation
+        # 1. Sum of values over reduction axes at each time step.
+        sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
+        # 2. Cumulative sum of values over time.
+        cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
+
+        # 3. Count of valid elements in the normalization group at each time step.
+        #    (A "group" here consists of all features at a given Batch, Time).
+        elements_in_group_at_t = torch.sum(
+            mask_calc, dim=self.reduction_axes, keepdim=True
+        )
+        # 4. Cumulative count of valid elements over time.
+        cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
+        # Avoid division by zero if all preceding elements were masked.
+        safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
+
+        # 5. Cumulative mean.
+        cum_mean = cum_sum_values / safe_cum_count_elements
+
+        # 6. Sum of squared differences from the cumulative mean.
+        #    Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
+        #    Using x_calc here for the difference, as cum_mean already accounts for masking.
+        squared_diff_from_mean = (x_calc - cum_mean).pow(2)
+        sum_sq_diff_at_t = torch.sum(
+            squared_diff_from_mean, dim=self.reduction_axes, keepdim=True
+        )
+
+        # 7. Cumulative sum of squared differences over time.
+        cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
+
+        # 8. Cumulative variance.
+        cum_variance = cum_sum_sq_diff / safe_cum_count_elements
+
+        # Normalize the input using the calculated cumulative statistics:
+        # (x - E[x]) / sqrt(Var[x] + eps)
+        normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
+
+        # Apply affine transformation (scale and bias) if enabled.
+        # Scale and bias are applied per-channel (last dimension).
+        scale = self.weight.to(calc_dtype)
+        # Reshape for broadcasting: [C] -> [1, ..., 1, C]
+        scale_view_shape = [1] * (x.dim() - 1) + [self.num_channels]
+        normalized_x = normalized_x * scale.view(scale_view_shape)
+
+        # Zero out outputs for time steps that were originally masked (where mask_calc is 0).
+        # This ensures padded/invalid positions in the input result in zero output.
+        final_output = normalized_x * mask_calc
+
+        return final_output.to(input_dtype)
+
+
+class Gemma3nAudioRelativePositionEmbedding(nn.Module):
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.num_heads = self.config.conf_num_attention_heads
+        self.channels = self.config.hidden_size
+        self.head_dim = self.channels // self.num_heads
+        self.max_backward = max(0, self.config.conf_attention_context_left - 1)
+        self.max_forward = self.config.conf_attention_context_right
+
+        self.pos_proj = ColumnParallelLinear(
+            self.channels,
+            self.num_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("pos_proj", prefix),
+        )
+
+        min_timescale = 1.0
+        max_timescale = 1.0e4
+        num_timescales = self.channels // 2
+        log_timescale_increment = math.log(
+            float(max_timescale) / float(min_timescale)
+        ) / max(num_timescales - 1, 1)
+        inv_timescales = min_timescale * torch.exp(
+            torch.arange(num_timescales) * -log_timescale_increment
+        )
+        self.register_buffer(
+            "inv_timescales",
+            inv_timescales.float().unsqueeze(0).unsqueeze(0),
+            persistent=False,
+        )
+
+    def _get_timing_signal_1d_pos(
+        self, position: torch.Tensor, dtype: torch.dtype
+    ) -> torch.Tensor:
+        assert position.ndim == 2
+        position = position.float().unsqueeze(-1)
+        scaled_time = position * self.inv_timescales.to(
+            device=position.device, dtype=torch.float32
+        )
+        timing_signal = torch.cat(
+            [torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1
+        )
+        return timing_signal.type(dtype)
+
+    def _relative_shift(
+        self,
+        term_bd_before_shift: torch.Tensor,
+        batch_size: int,
+        num_heads: int,
+        num_query_blocks: int,
+        query_block_size: int,
+        key_context_size: int,
+        max_span_plus_1: int,
+    ) -> torch.Tensor:
+        """Performs the relative shift."""
+        pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
+        padding_tuple = (0, pad_amount_last_dim)
+
+        term_bd_padded = F.pad(term_bd_before_shift, padding_tuple)
+        term_bd_reshaped = term_bd_padded.reshape(
+            (
+                batch_size,
+                num_heads,
+                num_query_blocks,
+                query_block_size * (key_context_size + 1),
+            )
+        )
+        term_bd_sliced = term_bd_reshaped[
+            :, :, :, : query_block_size * key_context_size
+        ]
+        term_bd_shifted = term_bd_sliced.reshape(
+            (
+                batch_size,
+                num_heads,
+                num_query_blocks,
+                query_block_size,
+                key_context_size,
+            )
+        )
+        return term_bd_shifted
+
+    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
+        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
+            queries.shape
+        )
+        _, _, key_context_size, _, _ = keys.shape
+
+        pos_indices = torch.arange(
+            self.max_backward, -self.max_forward - 1, -1, device=queries.device
+        ).unsqueeze(0)
+        max_span_plus_1 = pos_indices.shape[1]
+
+        sin_emb_timing_signal = self._get_timing_signal_1d_pos(
+            pos_indices, dtype=queries.dtype
+        )
+        projected_sin_emb, _ = self.pos_proj(sin_emb_timing_signal)
+        sin_emb = projected_sin_emb.reshape(
+            1, max_span_plus_1, self.num_heads, self.head_dim
+        ).squeeze(0)
+
+        queries_p = queries.permute(0, 3, 1, 2, 4)
+        keys_p_t = keys.permute(0, 3, 1, 4, 2)
+        term_ac = torch.matmul(queries_p, keys_p_t)
+
+        q_permuted = queries.permute(0, 3, 1, 2, 4)
+        s_permuted = sin_emb.permute(1, 2, 0)
+        q_reshaped = q_permuted.reshape(
+            batch_size, num_heads, num_query_blocks * query_block_size, head_dim
+        )
+        term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
+        term_bd_unshifed = term_bd_unshifed_matmul.reshape(
+            batch_size,
+            num_heads,
+            num_query_blocks,
+            query_block_size,
+            max_span_plus_1,
+        )
+
+        term_bd_shifted = self._relative_shift(
+            term_bd_unshifed,
+            batch_size,
+            num_heads,
+            num_query_blocks,
+            query_block_size,
+            key_context_size,
+            max_span_plus_1,
+        )
+
+        return term_ac + term_bd_shifted
+
+
+class Gemma3nAudioAttention(nn.Module):
+    """Local dot product self-attention for audio."""
+
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.num_heads = self.config.conf_num_attention_heads
+        self.hidden_size = self.config.hidden_size
+        self.head_dim = self.hidden_size // self.num_heads
+
+        self.chunk_size = self.config.conf_attention_chunk_size
+        self.max_future_horizon = self.config.conf_attention_context_right
+        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
+        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
+        self.context_size = (
+            self.chunk_size + self.max_past_horizon + self.max_future_horizon
+        )
+
+        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(
+            config,
+            quant_config,
+            prefix=add_prefix("relative_position_embedding", prefix),
+        )
+        self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.num_heads,
+            self.num_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        q_scale = self.head_dim**-0.5
+        r_softplus_0 = 1.0 / F.softplus(torch.tensor(0.0))
+        self.register_buffer(
+            "q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False
+        )
+
+        # Create local causal mask
+        lower_causal_mask = torch.tril(
+            torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
+            diagonal=0,
+        ).T
+        upper_causal_mask = torch.tril(
+            torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
+            diagonal=self.max_past_horizon + self.max_future_horizon,
+        )
+        local_causal_valid_mask = torch.ones(
+            (self.chunk_size, self.context_size), dtype=torch.bool
+        )
+        local_causal_valid_mask = (
+            local_causal_valid_mask * lower_causal_mask * upper_causal_mask
+        )
+        self.register_buffer(
+            "local_causal_valid_mask", local_causal_valid_mask, persistent=False
+        )
+
+        self.register_buffer(
+            "softcap",
+            torch.tensor(self.attention_logits_soft_cap).float(),
+            persistent=False,
+        )
+
+    def _pad_dim1(
+        self, x: torch.Tensor, dim10_val: int, dim11_val: int
+    ) -> torch.Tensor:
+        padding_tuple = [0] * x.ndim * 2
+        dim_idx_from_end = x.ndim - 2
+        start_idx_for_dim = 2 * dim_idx_from_end
+        padding_tuple[start_idx_for_dim] = dim10_val
+        padding_tuple[start_idx_for_dim + 1] = dim11_val
+        return F.pad(x, tuple(padding_tuple))
+
+    def _convert_to_block(self, x: torch.Tensor) -> torch.Tensor:
+        """Turns a sequence to non overlapping blocks."""
+        shape = x.shape
+        b, t = shape[:2]
+        num_blocks = (t + self.chunk_size - 1) // self.chunk_size
+
+        if (padding_len := num_blocks * self.chunk_size - t) > 0:
+            x = self._pad_dim1(x, 0, padding_len)
+
+        permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
+        x = x.reshape(permute_dims).contiguous()
+        return x
+
+    def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor:
+        """Extracts temporal context for every block."""
+        pad_left = self.max_past_horizon
+        pad_right = self.max_future_horizon + self.chunk_size - 1
+        x = self._pad_dim1(x, pad_left, pad_right)
+
+        frame_len = self.context_size
+        frame_step = self.chunk_size
+
+        x_unfolded = x.unfold(dimension=1, size=frame_len, step=frame_step)
+
+        if x.ndim > 2 and x_unfolded.ndim > 3:
+            x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
+
+        return x_unfolded.contiguous()
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        # Project to Q, K, V
+        qkv, _ = self.qkv_proj(x)
+        query_states, key_states, value_states = qkv.chunk(chunks=3, dim=-1)
+
+        # Reshape
+        query_states = query_states.reshape(
+            *x.shape[:-1], self.num_heads, self.head_dim
+        ).contiguous()
+        key_states = key_states.reshape(
+            *x.shape[:-1], self.num_heads, self.head_dim
+        ).contiguous()
+        value_states = value_states.reshape(
+            *x.shape[:-1], self.num_heads, self.head_dim
+        ).contiguous()
+
+        # Apply per-dim scale
+        per_dim_scale_sp = F.softplus(self.per_dim_scale)
+        broadcast_shape = (1, 1, 1, self.head_dim)
+        per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
+        query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
+
+        batch_size, q_time = query_states.shape[:2]
+
+        # Convert to blocks
+        query_blocks = self._convert_to_block(query_states)
+        key_blocks = self._extract_block_context(key_states)
+        value_blocks = self._extract_block_context(value_states)
+        num_query_blocks = query_blocks.shape[1]
+
+        # Create mask for valid positions
+        original_valid_mask = ~mask
+        extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
+
+        if (
+            extracted_valid_mask_blocks.ndim == 4
+            and extracted_valid_mask_blocks.shape[0] == batch_size
+            and extracted_valid_mask_blocks.shape[1] == num_query_blocks
+            and extracted_valid_mask_blocks.shape[2]
+            * extracted_valid_mask_blocks.shape[3]
+            == self.context_size
+        ):
+            extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
+                batch_size, num_query_blocks, self.context_size
+            )
+
+        condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(
+            1
+        ).unsqueeze(-2)
+        condition_from_causality = (
+            self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+        )
+
+        final_condition_for_where = torch.logical_and(
+            condition_from_input_validity,
+            condition_from_causality.to(condition_from_input_validity.device),
+        )
+
+        # Compute attention scores
+        logits = self.relative_position_embedding(query_blocks, key_blocks)
+
+        # Apply attention logit softcap
+        softcap_val = self.softcap.to(logits.device)
+        logits = logits / softcap_val
+        logits = torch.tanh(logits)
+        logits = logits * softcap_val
+
+        # Apply the combined mask.
+        # final_condition_for_where will broadcast with logits [B,N,U,W,C]
+        logits = torch.where(
+            final_condition_for_where, logits, torch.finfo(logits.dtype).min
+        )
+
+        probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to(
+            dtype=value_blocks.dtype
+        )
+
+        # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
+        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
+        h_dim = value_blocks.shape[-1]
+        prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
+        v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
+        result_bmm = torch.bmm(prob_bun, v_bun)
+        context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(
+            0, 1, 3, 2, 4
+        )
+        context_vectors = context_vectors.reshape(
+            (
+                batch_size,
+                num_query_blocks * self.chunk_size,
+                self.num_heads,
+                self.head_dim,
+            )
+        )
+        context_vectors = context_vectors[:, :q_time]
+
+        return context_vectors
+
+
+class Gemma3nAudioSSCPConvBlock(nn.Module):
+    """A single convolution block for the SubSampleConvProjection."""
+
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        idx: int,
+        input_freq_dim: int,
+        manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.manual_padding = manual_padding
+
+        in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
+        out_channels = self.config.sscp_conv_channel_size[idx]
+        kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
+        stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
+
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(kernel_h, kernel_w),
+            stride=(stride_h, stride_w),
+            padding=(0, 0),  # Manual padding is used
+            bias=False,
+        )
+
+        f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
+        f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
+
+        self.norm = Gemma3nCumulativeGroupNorm(
+            num_channels=out_channels,
+            feature_dims=(f_out_conv,),
+            eps=self.config.sscp_conv_group_norm_eps,
+        )
+
+        self.activation = nn.ReLU()
+
+    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+        audio_encodings_padded = F.pad(
+            audio_encodings, self.manual_padding, mode="constant", value=0.0
+        )
+        audio_encodings_conv = self.conv(audio_encodings_padded)
+        x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
+        x_normed = self.norm(x_for_norm)
+        audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
+        return self.activation(audio_encodings_normed)
+
+
+class Gemma3nAudioSubSampleConvProjection(nn.Module):
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        current_f_for_block_input = config.input_feat_size
+        calculated_block_padding = []
+        calculated_f_out_dims = []
+
+        for i in range(2):  # Assuming 2 conv layers
+            kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
+            stride_h, stride_w = config.sscp_conv_stride_size[i]
+
+            # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
+            pad_t_top = 0
+            pad_t_bottom = kernel_h - 1
+
+            # Frequency Padding (Width for Conv2d)
+            pad_f_left = 1
+            pad_f_right = 1
+
+            manual_padding_tuple = (pad_f_left, pad_f_right, pad_t_top, pad_t_bottom)
+            calculated_block_padding.append(manual_padding_tuple)
+
+            f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
+            f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1
+            calculated_f_out_dims.append(f_out_after_conv)
+            current_f_for_block_input = f_out_after_conv
+
+        self.conv_0 = Gemma3nAudioSSCPConvBlock(
+            idx=0,
+            input_freq_dim=config.input_feat_size,
+            config=config,
+            manual_padding=calculated_block_padding[0],
+            quant_config=quant_config,
+            prefix=add_prefix("conv_0", prefix),
+        )
+        self.conv_1 = Gemma3nAudioSSCPConvBlock(
+            idx=1,
+            input_freq_dim=calculated_f_out_dims[0],
+            config=config,
+            manual_padding=calculated_block_padding[1],
+            quant_config=quant_config,
+            prefix=add_prefix("conv_1", prefix),
+        )
+
+        final_c_out = config.sscp_conv_channel_size[-1]
+        final_f_out = calculated_f_out_dims[-1]
+        self.input_proj_in_features = final_c_out * final_f_out
+
+        self.input_proj_linear = RowParallelLinear(
+            self.input_proj_in_features,
+            self.config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("input_proj_linear", prefix),
+        )
+
+    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+        audio_encodings_reshaped = audio_encodings.unsqueeze(1)
+        x = self.conv_0(audio_encodings_reshaped)
+        x = self.conv_1(x)
+        b, c_out, t_out, f_out = x.shape
+        x_permuted = x.permute(0, 2, 3, 1).contiguous()
+        output_flattened = x_permuted.view(b, t_out, f_out * c_out)
+        output, _ = self.input_proj_linear(output_flattened)
+        return output
+
+
+class Gemma3nAudioConformerAttention(nn.Module):
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        head_dim = self.config.hidden_size // self.config.conf_num_attention_heads
+        self.post_in_shape = (self.config.conf_num_attention_heads, head_dim)
+        self.post_in_features = self.config.hidden_size
+
+        self.register_buffer(
+            "gradient_clipping",
+            torch.tensor(self.config.gradient_clipping),
+            persistent=False,
+        )
+
+        self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
+        self.attn = Gemma3nAudioAttention(
+            config, quant_config, prefix=add_prefix("attn", prefix)
+        )
+        self.post = RowParallelLinear(
+            self.post_in_features,
+            self.config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("post", prefix),
+        )
+        self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+    def forward(
+        self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
+    ) -> torch.Tensor:
+        audio_encodings_input_to_attn = audio_encodings
+        audio_encodings = torch.clamp(
+            audio_encodings, -self.gradient_clipping, self.gradient_clipping
+        )
+        audio_encodings_norm = self.pre_attn_norm(audio_encodings)
+        audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
+
+        b, t, num_heads, head_dim = audio_encodings_attn_out.shape
+        audio_encodings_reshaped = audio_encodings_attn_out.reshape(
+            b, t, num_heads * head_dim
+        )
+
+        audio_encodings, _ = self.post(audio_encodings_reshaped)
+        audio_encodings = torch.clamp(
+            audio_encodings, -self.gradient_clipping, self.gradient_clipping
+        )
+        return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
+
+
+class Gemma3nAudioConformerFeedForward(nn.Module):
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.register_buffer(
+            "gradient_clipping",
+            torch.tensor(self.config.gradient_clipping),
+            persistent=False,
+        )
+
+        self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+        self.ffw_layer_1 = ColumnParallelLinear(
+            self.config.hidden_size,
+            self.config.hidden_size * 4,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("ffw_layer_1", prefix),
+        )
+        self.ffw_layer_2 = RowParallelLinear(
+            self.config.hidden_size * 4,
+            self.config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("ffw_layer_2", prefix),
+        )
+        self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+        self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+
+    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+        residual = audio_encodings
+        audio_encodings = torch.clamp(
+            audio_encodings, -self.gradient_clipping, self.gradient_clipping
+        )
+        audio_encodings = self.pre_layer_norm(audio_encodings)
+        audio_encodings, _ = self.ffw_layer_1(audio_encodings)
+        audio_encodings = F.silu(audio_encodings)
+        audio_encodings, _ = self.ffw_layer_2(audio_encodings)
+        audio_encodings = torch.clamp(
+            audio_encodings, -self.gradient_clipping, self.gradient_clipping
+        )
+        audio_encodings = self.post_layer_norm(audio_encodings)
+        return residual + (audio_encodings * self.post_layer_scale)
+
+
+class Gemma3nAudioConformerLightConv1d(nn.Module):
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.pre_layer_norm = Gemma3nRMSNorm(
+            self.config.hidden_size, eps=self.config.rms_norm_eps
+        )
+        self.linear_start = ColumnParallelLinear(
+            self.config.hidden_size,
+            self.config.hidden_size * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_start", prefix),
+        )
+
+        self.depthwise_conv1d = nn.Conv1d(
+            in_channels=self.config.hidden_size,
+            out_channels=self.config.hidden_size,
+            kernel_size=self.config.conf_conv_kernel_size,
+            stride=1,
+            padding=0,  # Manual causal padding
+            groups=self.config.hidden_size,  # Depthwise
+            bias=False,
+        )
+        self.register_buffer(
+            "gradient_clipping",
+            torch.tensor(self.config.gradient_clipping),
+            persistent=False,
+        )
+        self.conv_norm = Gemma3nRMSNorm(
+            self.config.hidden_size, eps=self.config.rms_norm_eps
+        )
+        self.linear_end = RowParallelLinear(
+            self.config.hidden_size,
+            self.config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_end", prefix),
+        )
+
+        self.causal_padding = self.config.conf_conv_kernel_size - 1
+
+    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+        audio_encodings_residual = audio_encodings  # Save for residual connection
+
+        audio_encodings = self.pre_layer_norm(audio_encodings)
+        audio_encodings, _ = self.linear_start(audio_encodings)
+        audio_encodings = F.glu(audio_encodings, dim=-1)
+
+        # Permute for Conv1d: [B, T, D] -> [B, D, T]
+        audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
+        # Apply manual causal padding
+        audio_encodings_permuted_padded = F.pad(
+            audio_encodings_permuted, (self.causal_padding, 0)
+        )
+        audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
+        # Permute back: [B, D, T_out] -> [B, T_out, D]
+        audio_encodings = audio_encodings.permute(0, 2, 1)
+        audio_encodings = torch.clamp(
+            audio_encodings, -self.gradient_clipping, self.gradient_clipping
+        )
+        audio_encodings = self.conv_norm(audio_encodings)
+        audio_encodings = F.silu(audio_encodings)
+        audio_encodings, _ = self.linear_end(audio_encodings)
+        output = audio_encodings + audio_encodings_residual
+        return output
+
+
+class Gemma3nAudioConformerBlock(nn.Module):
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.ffw_layer_start = Gemma3nAudioConformerFeedForward(
+            config, quant_config, prefix=add_prefix("ffw_layer_start", prefix)
+        )
+        self.attention = Gemma3nAudioConformerAttention(
+            config, quant_config, prefix=add_prefix("attention", prefix)
+        )
+        self.lconv1d = Gemma3nAudioConformerLightConv1d(
+            config, quant_config, prefix=add_prefix("lconv1d", prefix)
+        )
+        self.ffw_layer_end = Gemma3nAudioConformerFeedForward(
+            config, quant_config, prefix=add_prefix("ffw_layer_end", prefix)
+        )
+        self.register_buffer(
+            "gradient_clipping",
+            torch.tensor(self.config.gradient_clipping),
+            persistent=False,
+        )
+        self.norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+    def forward(
+        self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
+    ) -> torch.Tensor:
+        audio_encodings = self.ffw_layer_start(audio_encodings)
+        audio_encodings = self.attention(audio_encodings, audio_mel_mask)
+        validity_mask_for_lconv = ~audio_mel_mask  # True for valid
+        audio_encodings_for_lconv_input = (
+            audio_encodings
+            * validity_mask_for_lconv.unsqueeze(-1).to(audio_encodings.dtype)
+        )
+        audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
+
+        audio_encodings = self.ffw_layer_end(audio_encodings)
+        audio_encodings = torch.clamp(
+            audio_encodings, -self.gradient_clipping, self.gradient_clipping
+        )
+        output = self.norm(audio_encodings)
+        return output
+
+
+class Gemma3nAudioEncoder(PreTrainedModel):
+    """A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
+
+    config_class = Gemma3nAudioConfig
+
+    def __init__(
+        self,
+        config: Gemma3nAudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config)
+        self.config = config
+
+        self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(
+            config, quant_config, prefix=add_prefix("subsample_conv_projection", prefix)
+        )
+        self.conformer = make_layers(
+            config.conf_num_hidden_layers,
+            lambda idx, prefix: Gemma3nAudioConformerBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("conformer", prefix),
+        )
+
+    def forward(
+        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
+        """Encodes a batch of MELs.
+
+        Args:
+            audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
+            audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
+
+        Returns:
+            audio_encodings: a torch.Tensor of shape
+                `[batch_size, reduced_time_frames, hidden_size]`
+            audio_mel_mask: a torch.BoolTensor of shape [batch, reduced_time_frames].
+        """
+        audio_encodings = self.subsample_conv_projection(
+            audio_mel
+        )  # audio_encodings: [B, T_sub, D]
+
+        # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
+        t_sub = audio_encodings.shape[1]
+
+        time_stride_product = 1
+        for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
+            time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
+
+        # Create indices for gathering from the original mask.
+        # These indices map to original time steps corresponding to the start of each
+        # receptive field in the subsampled output.
+        indices = (
+            torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
+        )
+        indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)
+
+        # Expand indices for batch compatibility if B > 1 and indices is 1D.
+        if audio_mel_mask.ndim > 1 and indices.ndim == 1:
+            indices = indices.unsqueeze(0).expand(
+                audio_mel_mask.shape[0], -1
+            )  # [B, T_sub]
+        elif (
+            audio_mel_mask.ndim == indices.ndim
+            and audio_mel_mask.shape[0] == 1
+            and indices.shape[0] != 1
+            and t_sub == indices.shape[0]
+        ):
+            # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
+            indices = indices.unsqueeze(0)
+
+        current_mask = torch.gather(audio_mel_mask, 1, indices)  # [B, T_sub]
+
+        # Fallback: Ensure mask length matches feature length after gather.
+        if current_mask.shape[1] != t_sub:
+            if current_mask.shape[1] > t_sub:
+                current_mask = current_mask[:, :t_sub]
+            else:  # current_mask.shape[1] < t_sub
+                padding_needed = t_sub - current_mask.shape[1]
+                current_mask = F.pad(
+                    current_mask, (0, padding_needed), value=True
+                )  # Pad with True (masked)
+
+        for i, block in enumerate(self.conformer):
+            audio_encodings = block(
+                audio_encodings, current_mask
+            )  # Pass the processed mask
+
+        if self.config.conf_reduction_factor > 1:
+            audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+            # Reduce the mask as well
+            current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+
+        # Final masking of audio_encodings based on the final current_mask
+        # Ensure current_mask length matches the finally reduced audio_encodings length
+        if current_mask.shape[1] != audio_encodings.shape[1]:
+            target_len = audio_encodings.shape[1]
+            mask_current_len = current_mask.shape[1]
+            if target_len > mask_current_len:
+                padding_needed = target_len - mask_current_len
+                current_mask = F.pad(current_mask, (0, padding_needed), value=True)
+            elif mask_current_len > target_len:  # mask is longer
+                current_mask = current_mask[:, :target_len]
+
+        audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
+        return audio_encodings, current_mask
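
Note (not part of the diff above): the local attention in Gemma3nAudioAttention works on fixed-size chunks, and _extract_block_context builds one overlapping key/value window per chunk by padding the sequence with the past and future horizons and then calling unfold. The following minimal, self-contained sketch illustrates that windowing; the chunk and horizon sizes are toy values picked for readability, not the Gemma 3n config defaults.

# Illustrative sketch of _extract_block_context-style windowing (toy sizes).
import torch
import torch.nn.functional as F

chunk_size, max_past_horizon, max_future_horizon = 3, 1, 1
context_size = chunk_size + max_past_horizon + max_future_horizon  # 5

t = 7
x = torch.arange(1.0, t + 1).unsqueeze(0)  # [1, 7], frames numbered 1..7

# Same padding scheme as the model: past frames on the left,
# future frames plus chunk rounding on the right.
x_padded = F.pad(x, (max_past_horizon, max_future_horizon + chunk_size - 1))
blocks = x_padded.unfold(dimension=1, size=context_size, step=chunk_size)
print(blocks)
# tensor([[[0., 1., 2., 3., 4.],
#          [3., 4., 5., 6., 7.],
#          [6., 7., 0., 0., 0.]]])

Each row is the context window one query chunk attends to; the zeros come from the padding and are excluded later in forward by the combined validity and local causal masks.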