sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3n_audio.py (new file)
@@ -0,0 +1,949 @@
import math
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Gemma3nAudioConfig, PreTrainedModel

from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm
from sglang.srt.utils import add_prefix, make_layers


class Gemma3nCumulativeGroupNorm(nn.Module):
    """Applies Group Normalization cumulatively over the time dimension.

    This layer normalizes the input by calculating the mean and variance
    cumulatively over the time dimension (dim 1). The statistics are computed
    over all feature dimensions (specified by `feature_dims` and `num_channels`)
    for elements marked as valid by the optional `mask`.

    If a `mask` is provided (True for valid, False for invalid/padded),
    invalid time steps do not contribute to the statistics calculation, and
    their corresponding output values are zeroed out.

    Scale and bias, if enabled, are applied per-channel (last dimension).
    This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
    and `cumulative=True`.
    """

    def __init__(
        self,
        num_channels: int,  # Number of channels (size of the last dimension)
        feature_dims: Sequence[
            int
        ],  # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
        eps: float = 1e-3,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.feature_dims = tuple(feature_dims)
        self.eps = eps

        # Scale parameter depends only on the channel dimension
        self.weight = nn.Parameter(torch.ones(num_channels))

        # Axes for normalization: all dimensions except Batch (0) and Time (1).
        # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
        self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Applies cumulative group norm, optionally using a mask.

        Args:
          x: Input tensor, shape [B, T, *feature_dims, C].
          mask: Optional boolean mask, shape [B, T]. True indicates a valid
            (non-padded) time step. If None, all time steps are considered valid.

        Returns:
          Normalized tensor with the same shape as x.
        """
        expected_input_suffix = self.feature_dims + (self.num_channels,)
        if x.shape[2:] != expected_input_suffix:
            raise ValueError(
                f"Input tensor shape suffix {x.shape[2:]} does not match expected"
                f" suffix (feature_dims + num_channels) {expected_input_suffix}"
            )

        input_dtype = x.dtype
        # Calculations are performed in float32 for numerical stability.
        calc_dtype = torch.float32
        x_calc = x.to(calc_dtype)

        # Prepare a broadcastable mask (`mask_calc`).
        # If no mask is provided, treat all elements as valid
        # (mask_calc is all ones).
        # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
        mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)

        # Cumulative Statistics Calculation
        # 1. Sum of values over reduction axes at each time step.
        sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
        # 2. Cumulative sum of values over time.
        cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)

        # 3. Count of valid elements in the normalization group at each time step.
        #    (A "group" here consists of all features at a given Batch, Time).
        elements_in_group_at_t = torch.sum(
            mask_calc, dim=self.reduction_axes, keepdim=True
        )
        # 4. Cumulative count of valid elements over time.
        cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
        # Avoid division by zero if all preceding elements were masked.
        safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)

        # 5. Cumulative mean.
        cum_mean = cum_sum_values / safe_cum_count_elements

        # 6. Sum of squared differences from the cumulative mean.
        #    Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
        #    Using x_calc here for the difference, as cum_mean already accounts for masking.
        squared_diff_from_mean = (x_calc - cum_mean).pow(2)
        sum_sq_diff_at_t = torch.sum(
            squared_diff_from_mean, dim=self.reduction_axes, keepdim=True
        )

        # 7. Cumulative sum of squared differences over time.
        cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)

        # 8. Cumulative variance.
        cum_variance = cum_sum_sq_diff / safe_cum_count_elements

        # Normalize the input using the calculated cumulative statistics:
        # (x - E[x]) / sqrt(Var[x] + eps)
        normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)

        # Apply affine transformation (scale and bias) if enabled.
        # Scale and bias are applied per-channel (last dimension).
        scale = self.weight.to(calc_dtype)
        # Reshape for broadcasting: [C] -> [1, ..., 1, C]
        scale_view_shape = [1] * (x.dim() - 1) + [self.num_channels]
        normalized_x = normalized_x * scale.view(scale_view_shape)

        # Zero out outputs for time steps that were originally masked (where mask_calc is 0).
        # This ensures padded/invalid positions in the input result in zero output.
        final_output = normalized_x * mask_calc

        return final_output.to(input_dtype)


class Gemma3nAudioRelativePositionEmbedding(nn.Module):
    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.channels = self.config.hidden_size
        self.head_dim = self.channels // self.num_heads
        self.max_backward = max(0, self.config.conf_attention_context_left - 1)
        self.max_forward = self.config.conf_attention_context_right

        self.pos_proj = ColumnParallelLinear(
            self.channels,
            self.num_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("pos_proj", prefix),
        )

        min_timescale = 1.0
        max_timescale = 1.0e4
        num_timescales = self.channels // 2
        log_timescale_increment = math.log(
            float(max_timescale) / float(min_timescale)
        ) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales) * -log_timescale_increment
        )
        self.register_buffer(
            "inv_timescales",
            inv_timescales.float().unsqueeze(0).unsqueeze(0),
            persistent=False,
        )

    def _get_timing_signal_1d_pos(
        self, position: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        assert position.ndim == 2
        position = position.float().unsqueeze(-1)
        scaled_time = position * self.inv_timescales.to(
            device=position.device, dtype=torch.float32
        )
        timing_signal = torch.cat(
            [torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1
        )
        return timing_signal.type(dtype)

    def _relative_shift(
        self,
        term_bd_before_shift: torch.Tensor,
        batch_size: int,
        num_heads: int,
        num_query_blocks: int,
        query_block_size: int,
        key_context_size: int,
        max_span_plus_1: int,
    ) -> torch.Tensor:
        """Performs the relative shift."""
        pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
        padding_tuple = (0, pad_amount_last_dim)

        term_bd_padded = F.pad(term_bd_before_shift, padding_tuple)
        term_bd_reshaped = term_bd_padded.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size * (key_context_size + 1),
            )
        )
        term_bd_sliced = term_bd_reshaped[
            :, :, :, : query_block_size * key_context_size
        ]
        term_bd_shifted = term_bd_sliced.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size,
                key_context_size,
            )
        )
        return term_bd_shifted

    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
            queries.shape
        )
        _, _, key_context_size, _, _ = keys.shape

        pos_indices = torch.arange(
            self.max_backward, -self.max_forward - 1, -1, device=queries.device
        ).unsqueeze(0)
        max_span_plus_1 = pos_indices.shape[1]

        sin_emb_timing_signal = self._get_timing_signal_1d_pos(
            pos_indices, dtype=queries.dtype
        )
        projected_sin_emb, _ = self.pos_proj(sin_emb_timing_signal)
        sin_emb = projected_sin_emb.reshape(
            1, max_span_plus_1, self.num_heads, self.head_dim
        ).squeeze(0)

        queries_p = queries.permute(0, 3, 1, 2, 4)
        keys_p_t = keys.permute(0, 3, 1, 4, 2)
        term_ac = torch.matmul(queries_p, keys_p_t)

        q_permuted = queries.permute(0, 3, 1, 2, 4)
        s_permuted = sin_emb.permute(1, 2, 0)
        q_reshaped = q_permuted.reshape(
            batch_size, num_heads, num_query_blocks * query_block_size, head_dim
        )
        term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
        term_bd_unshifed = term_bd_unshifed_matmul.reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            max_span_plus_1,
        )

        term_bd_shifted = self._relative_shift(
            term_bd_unshifed,
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
            max_span_plus_1,
        )

        return term_ac + term_bd_shifted


class Gemma3nAudioAttention(nn.Module):
    """Local dot product self-attention for audio."""

    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads

        self.chunk_size = self.config.conf_attention_chunk_size
        self.max_future_horizon = self.config.conf_attention_context_right
        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
        self.context_size = (
            self.chunk_size + self.max_past_horizon + self.max_future_horizon
        )

        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(
            config,
            quant_config,
            prefix=add_prefix("relative_position_embedding", prefix),
        )
        self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )

        q_scale = self.head_dim**-0.5
        r_softplus_0 = 1.0 / F.softplus(torch.tensor(0.0))
        self.register_buffer(
            "q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False
        )

        # Create local causal mask
        lower_causal_mask = torch.tril(
            torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
            diagonal=0,
        ).T
        upper_causal_mask = torch.tril(
            torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
            diagonal=self.max_past_horizon + self.max_future_horizon,
        )
        local_causal_valid_mask = torch.ones(
            (self.chunk_size, self.context_size), dtype=torch.bool
        )
        local_causal_valid_mask = (
            local_causal_valid_mask * lower_causal_mask * upper_causal_mask
        )
        self.register_buffer(
            "local_causal_valid_mask", local_causal_valid_mask, persistent=False
        )

        self.register_buffer(
            "softcap",
            torch.tensor(self.attention_logits_soft_cap).float(),
            persistent=False,
        )

    def _pad_dim1(
        self, x: torch.Tensor, dim10_val: int, dim11_val: int
    ) -> torch.Tensor:
        padding_tuple = [0] * x.ndim * 2
        dim_idx_from_end = x.ndim - 2
        start_idx_for_dim = 2 * dim_idx_from_end
        padding_tuple[start_idx_for_dim] = dim10_val
        padding_tuple[start_idx_for_dim + 1] = dim11_val
        return F.pad(x, tuple(padding_tuple))

    def _convert_to_block(self, x: torch.Tensor) -> torch.Tensor:
        """Turns a sequence to non overlapping blocks."""
        shape = x.shape
        b, t = shape[:2]
        num_blocks = (t + self.chunk_size - 1) // self.chunk_size

        if (padding_len := num_blocks * self.chunk_size - t) > 0:
            x = self._pad_dim1(x, 0, padding_len)

        permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
        x = x.reshape(permute_dims).contiguous()
        return x

    def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor:
        """Extracts temporal context for every block."""
        pad_left = self.max_past_horizon
        pad_right = self.max_future_horizon + self.chunk_size - 1
        x = self._pad_dim1(x, pad_left, pad_right)

        frame_len = self.context_size
        frame_step = self.chunk_size

        x_unfolded = x.unfold(dimension=1, size=frame_len, step=frame_step)

        if x.ndim > 2 and x_unfolded.ndim > 3:
            x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)

        return x_unfolded.contiguous()

    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        # Project to Q, K, V
        qkv, _ = self.qkv_proj(x)
        query_states, key_states, value_states = qkv.chunk(chunks=3, dim=-1)

        # Reshape
        query_states = query_states.reshape(
            *x.shape[:-1], self.num_heads, self.head_dim
        ).contiguous()
        key_states = key_states.reshape(
            *x.shape[:-1], self.num_heads, self.head_dim
        ).contiguous()
        value_states = value_states.reshape(
            *x.shape[:-1], self.num_heads, self.head_dim
        ).contiguous()

        # Apply per-dim scale
        per_dim_scale_sp = F.softplus(self.per_dim_scale)
        broadcast_shape = (1, 1, 1, self.head_dim)
        per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
        query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast

        batch_size, q_time = query_states.shape[:2]

        # Convert to blocks
        query_blocks = self._convert_to_block(query_states)
        key_blocks = self._extract_block_context(key_states)
        value_blocks = self._extract_block_context(value_states)
        num_query_blocks = query_blocks.shape[1]

        # Create mask for valid positions
        original_valid_mask = ~mask
        extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)

        if (
            extracted_valid_mask_blocks.ndim == 4
            and extracted_valid_mask_blocks.shape[0] == batch_size
            and extracted_valid_mask_blocks.shape[1] == num_query_blocks
            and extracted_valid_mask_blocks.shape[2]
            * extracted_valid_mask_blocks.shape[3]
            == self.context_size
        ):
            extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
                batch_size, num_query_blocks, self.context_size
            )

        condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(
            1
        ).unsqueeze(-2)
        condition_from_causality = (
            self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        )

        final_condition_for_where = torch.logical_and(
            condition_from_input_validity,
            condition_from_causality.to(condition_from_input_validity.device),
        )

        # Compute attention scores
        logits = self.relative_position_embedding(query_blocks, key_blocks)

        # Apply attention logit softcap
        softcap_val = self.softcap.to(logits.device)
        logits = logits / softcap_val
        logits = torch.tanh(logits)
        logits = logits * softcap_val

        # Apply the combined mask.
        # final_condition_for_where will broadcast with logits [B,N,U,W,C]
        logits = torch.where(
            final_condition_for_where, logits, torch.finfo(logits.dtype).min
        )

        probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to(
            dtype=value_blocks.dtype
        )

        # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        result_bmm = torch.bmm(prob_bun, v_bun)
        context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(
            0, 1, 3, 2, 4
        )
        context_vectors = context_vectors.reshape(
            (
                batch_size,
                num_query_blocks * self.chunk_size,
                self.num_heads,
                self.head_dim,
            )
        )
        context_vectors = context_vectors[:, :q_time]

        return context_vectors


class Gemma3nAudioSSCPConvBlock(nn.Module):
    """A single convolution block for the SubSampleConvProjection."""

    def __init__(
        self,
        config: Gemma3nAudioConfig,
        idx: int,
        input_freq_dim: int,
        manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.manual_padding = manual_padding

        in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
        out_channels = self.config.sscp_conv_channel_size[idx]
        kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
        stride_h, stride_w = self.config.sscp_conv_stride_size[idx]

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_h, kernel_w),
            stride=(stride_h, stride_w),
            padding=(0, 0),  # Manual padding is used
            bias=False,
        )

        f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
        f_out_conv = (f_in_padded - kernel_w) // stride_w + 1

        self.norm = Gemma3nCumulativeGroupNorm(
            num_channels=out_channels,
            feature_dims=(f_out_conv,),
            eps=self.config.sscp_conv_group_norm_eps,
        )

        self.activation = nn.ReLU()

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        audio_encodings_padded = F.pad(
            audio_encodings, self.manual_padding, mode="constant", value=0.0
        )
        audio_encodings_conv = self.conv(audio_encodings_padded)
        x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
        x_normed = self.norm(x_for_norm)
        audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
        return self.activation(audio_encodings_normed)


class Gemma3nAudioSubSampleConvProjection(nn.Module):
    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        current_f_for_block_input = config.input_feat_size
        calculated_block_padding = []
        calculated_f_out_dims = []

        for i in range(2):  # Assuming 2 conv layers
            kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
            stride_h, stride_w = config.sscp_conv_stride_size[i]

            # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
            pad_t_top = 0
            pad_t_bottom = kernel_h - 1

            # Frequency Padding (Width for Conv2d)
            pad_f_left = 1
            pad_f_right = 1

            manual_padding_tuple = (pad_f_left, pad_f_right, pad_t_top, pad_t_bottom)
            calculated_block_padding.append(manual_padding_tuple)

            f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
            f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1
            calculated_f_out_dims.append(f_out_after_conv)
            current_f_for_block_input = f_out_after_conv

        self.conv_0 = Gemma3nAudioSSCPConvBlock(
            idx=0,
            input_freq_dim=config.input_feat_size,
            config=config,
            manual_padding=calculated_block_padding[0],
            quant_config=quant_config,
            prefix=add_prefix("conv_0", prefix),
        )
        self.conv_1 = Gemma3nAudioSSCPConvBlock(
            idx=1,
            input_freq_dim=calculated_f_out_dims[0],
            config=config,
            manual_padding=calculated_block_padding[1],
            quant_config=quant_config,
            prefix=add_prefix("conv_1", prefix),
        )

        final_c_out = config.sscp_conv_channel_size[-1]
        final_f_out = calculated_f_out_dims[-1]
        self.input_proj_in_features = final_c_out * final_f_out

        self.input_proj_linear = RowParallelLinear(
            self.input_proj_in_features,
            self.config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("input_proj_linear", prefix),
        )

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        audio_encodings_reshaped = audio_encodings.unsqueeze(1)
        x = self.conv_0(audio_encodings_reshaped)
        x = self.conv_1(x)
        b, c_out, t_out, f_out = x.shape
        x_permuted = x.permute(0, 2, 3, 1).contiguous()
        output_flattened = x_permuted.view(b, t_out, f_out * c_out)
        output, _ = self.input_proj_linear(output_flattened)
        return output


class Gemma3nAudioConformerAttention(nn.Module):
    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        head_dim = self.config.hidden_size // self.config.conf_num_attention_heads
        self.post_in_shape = (self.config.conf_num_attention_heads, head_dim)
        self.post_in_features = self.config.hidden_size

        self.register_buffer(
            "gradient_clipping",
            torch.tensor(self.config.gradient_clipping),
            persistent=False,
        )

        self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
        self.attn = Gemma3nAudioAttention(
            config, quant_config, prefix=add_prefix("attn", prefix)
        )
        self.post = RowParallelLinear(
            self.post_in_features,
            self.config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("post", prefix),
        )
        self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)

    def forward(
        self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
    ) -> torch.Tensor:
        audio_encodings_input_to_attn = audio_encodings
        audio_encodings = torch.clamp(
            audio_encodings, -self.gradient_clipping, self.gradient_clipping
        )
        audio_encodings_norm = self.pre_attn_norm(audio_encodings)
        audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)

        b, t, num_heads, head_dim = audio_encodings_attn_out.shape
        audio_encodings_reshaped = audio_encodings_attn_out.reshape(
            b, t, num_heads * head_dim
        )

        audio_encodings, _ = self.post(audio_encodings_reshaped)
        audio_encodings = torch.clamp(
            audio_encodings, -self.gradient_clipping, self.gradient_clipping
        )
        return audio_encodings_input_to_attn + self.post_norm(audio_encodings)


class Gemma3nAudioConformerFeedForward(nn.Module):
    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.register_buffer(
            "gradient_clipping",
            torch.tensor(self.config.gradient_clipping),
            persistent=False,
        )

        self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
        self.ffw_layer_1 = ColumnParallelLinear(
            self.config.hidden_size,
            self.config.hidden_size * 4,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("ffw_layer_1", prefix),
        )
        self.ffw_layer_2 = RowParallelLinear(
            self.config.hidden_size * 4,
            self.config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("ffw_layer_2", prefix),
        )
        self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
        self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        residual = audio_encodings
        audio_encodings = torch.clamp(
            audio_encodings, -self.gradient_clipping, self.gradient_clipping
        )
        audio_encodings = self.pre_layer_norm(audio_encodings)
        audio_encodings, _ = self.ffw_layer_1(audio_encodings)
        audio_encodings = F.silu(audio_encodings)
        audio_encodings, _ = self.ffw_layer_2(audio_encodings)
        audio_encodings = torch.clamp(
            audio_encodings, -self.gradient_clipping, self.gradient_clipping
        )
        audio_encodings = self.post_layer_norm(audio_encodings)
        return residual + (audio_encodings * self.post_layer_scale)


class Gemma3nAudioConformerLightConv1d(nn.Module):
    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.pre_layer_norm = Gemma3nRMSNorm(
            self.config.hidden_size, eps=self.config.rms_norm_eps
        )
        self.linear_start = ColumnParallelLinear(
            self.config.hidden_size,
            self.config.hidden_size * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("linear_start", prefix),
        )

        self.depthwise_conv1d = nn.Conv1d(
            in_channels=self.config.hidden_size,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.conf_conv_kernel_size,
            stride=1,
            padding=0,  # Manual causal padding
            groups=self.config.hidden_size,  # Depthwise
            bias=False,
        )
        self.register_buffer(
            "gradient_clipping",
            torch.tensor(self.config.gradient_clipping),
            persistent=False,
        )
        self.conv_norm = Gemma3nRMSNorm(
            self.config.hidden_size, eps=self.config.rms_norm_eps
        )
        self.linear_end = RowParallelLinear(
            self.config.hidden_size,
            self.config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("linear_end", prefix),
        )

        self.causal_padding = self.config.conf_conv_kernel_size - 1

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        audio_encodings_residual = audio_encodings  # Save for residual connection

        audio_encodings = self.pre_layer_norm(audio_encodings)
        audio_encodings, _ = self.linear_start(audio_encodings)
        audio_encodings = F.glu(audio_encodings, dim=-1)

        # Permute for Conv1d: [B, T, D] -> [B, D, T]
        audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
        # Apply manual causal padding
        audio_encodings_permuted_padded = F.pad(
            audio_encodings_permuted, (self.causal_padding, 0)
        )
        audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
        # Permute back: [B, D, T_out] -> [B, T_out, D]
        audio_encodings = audio_encodings.permute(0, 2, 1)
        audio_encodings = torch.clamp(
            audio_encodings, -self.gradient_clipping, self.gradient_clipping
        )
        audio_encodings = self.conv_norm(audio_encodings)
        audio_encodings = F.silu(audio_encodings)
        audio_encodings, _ = self.linear_end(audio_encodings)
        output = audio_encodings + audio_encodings_residual
        return output


class Gemma3nAudioConformerBlock(nn.Module):
    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.ffw_layer_start = Gemma3nAudioConformerFeedForward(
            config, quant_config, prefix=add_prefix("ffw_layer_start", prefix)
        )
        self.attention = Gemma3nAudioConformerAttention(
            config, quant_config, prefix=add_prefix("attention", prefix)
        )
        self.lconv1d = Gemma3nAudioConformerLightConv1d(
            config, quant_config, prefix=add_prefix("lconv1d", prefix)
        )
        self.ffw_layer_end = Gemma3nAudioConformerFeedForward(
            config, quant_config, prefix=add_prefix("ffw_layer_end", prefix)
        )
        self.register_buffer(
            "gradient_clipping",
            torch.tensor(self.config.gradient_clipping),
            persistent=False,
        )
        self.norm = Gemma3nRMSNorm(self.config.hidden_size)

    def forward(
        self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
    ) -> torch.Tensor:
        audio_encodings = self.ffw_layer_start(audio_encodings)
        audio_encodings = self.attention(audio_encodings, audio_mel_mask)
        validity_mask_for_lconv = ~audio_mel_mask  # True for valid
        audio_encodings_for_lconv_input = (
            audio_encodings
            * validity_mask_for_lconv.unsqueeze(-1).to(audio_encodings.dtype)
        )
        audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)

        audio_encodings = self.ffw_layer_end(audio_encodings)
        audio_encodings = torch.clamp(
            audio_encodings, -self.gradient_clipping, self.gradient_clipping
        )
        output = self.norm(audio_encodings)
        return output


class Gemma3nAudioEncoder(PreTrainedModel):
    """A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""

    config_class = Gemma3nAudioConfig

    def __init__(
        self,
        config: Gemma3nAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(config)
        self.config = config

        self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(
            config, quant_config, prefix=add_prefix("subsample_conv_projection", prefix)
        )
        self.conformer = make_layers(
            config.conf_num_hidden_layers,
            lambda idx, prefix: Gemma3nAudioConformerBlock(
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("conformer", prefix),
        )

    def forward(
        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """Encodes a batch of MELs.

        Args:
            audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
            audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].

        Returns:
            audio_encodings: a torch.Tensor of shape
                `[batch_size, reduced_time_frames, hidden_size]`
            audio_mel_mask: a torch.BoolTensor of shape [batch, reduced_time_frames].
        """
        audio_encodings = self.subsample_conv_projection(
            audio_mel
        )  # audio_encodings: [B, T_sub, D]

        # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
        t_sub = audio_encodings.shape[1]

        time_stride_product = 1
        for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
            time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]

        # Create indices for gathering from the original mask.
        # These indices map to original time steps corresponding to the start of each
        # receptive field in the subsampled output.
        indices = (
            torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
        )
        indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)

        # Expand indices for batch compatibility if B > 1 and indices is 1D.
        if audio_mel_mask.ndim > 1 and indices.ndim == 1:
            indices = indices.unsqueeze(0).expand(
                audio_mel_mask.shape[0], -1
            )  # [B, T_sub]
        elif (
            audio_mel_mask.ndim == indices.ndim
            and audio_mel_mask.shape[0] == 1
            and indices.shape[0] != 1
            and t_sub == indices.shape[0]
        ):
            # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
            indices = indices.unsqueeze(0)

        current_mask = torch.gather(audio_mel_mask, 1, indices)  # [B, T_sub]

        # Fallback: Ensure mask length matches feature length after gather.
        if current_mask.shape[1] != t_sub:
            if current_mask.shape[1] > t_sub:
                current_mask = current_mask[:, :t_sub]
            else:  # current_mask.shape[1] < t_sub
                padding_needed = t_sub - current_mask.shape[1]
                current_mask = F.pad(
                    current_mask, (0, padding_needed), value=True
                )  # Pad with True (masked)

        for i, block in enumerate(self.conformer):
            audio_encodings = block(
                audio_encodings, current_mask
            )  # Pass the processed mask

        if self.config.conf_reduction_factor > 1:
            audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
            # Reduce the mask as well
            current_mask = current_mask[:, :: self.config.conf_reduction_factor]

        # Final masking of audio_encodings based on the final current_mask
        # Ensure current_mask length matches the finally reduced audio_encodings length
        if current_mask.shape[1] != audio_encodings.shape[1]:
            target_len = audio_encodings.shape[1]
            mask_current_len = current_mask.shape[1]
            if target_len > mask_current_len:
                padding_needed = target_len - mask_current_len
                current_mask = F.pad(current_mask, (0, padding_needed), value=True)
            elif mask_current_len > target_len:  # mask is longer
                current_mask = current_mask[:, :target_len]

        audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
        return audio_encodings, current_mask
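As a reading aid (not part of the packaged diff above), the snippet below is a minimal sketch of how the new `Gemma3nCumulativeGroupNorm` layer from this file can be exercised on its own. Unlike the parallel-linear modules, it needs no distributed setup, only PyTorch; the batch, time, feature, and channel sizes used here are arbitrary example values, and the module path assumes the file is importable as shipped in 0.4.8.post1.

import torch

from sglang.srt.models.gemma3n_audio import Gemma3nCumulativeGroupNorm

# Example sizes (arbitrary): batch=2, time=5, one feature dim of size 3, 8 channels.
norm = Gemma3nCumulativeGroupNorm(num_channels=8, feature_dims=(3,), eps=1e-3)

# Input shape [B, T, *feature_dims, C]; statistics accumulate over the time dim (dim 1).
x = torch.randn(2, 5, 3, 8)
out = norm(x)
print(out.shape)  # torch.Size([2, 5, 3, 8]) -- same shape as the input

The per-channel scale (`norm.weight`) starts at ones, so a freshly constructed layer simply whitens each prefix of the time axis with the running mean and variance computed in float32, then casts back to the input dtype.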