sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/bert.py
@@ -0,0 +1,398 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+BertConfig = None
+
+
+class BertEmbedding(nn.Module):
+
+    def __init__(self, config: BertConfig):
+
+        super().__init__()
+        self.size = config.hidden_size
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.position_embeddings = VocabParallelEmbedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+        self.token_type_embeddings = VocabParallelEmbedding(
+            config.type_vocab_size, config.hidden_size
+        )
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.position_ids = nn.Parameter(
+            torch.empty((1, config.max_position_embeddings)),
+        )
+
+        self.position_embedding_type = config.position_embedding_type
+        if self.position_embedding_type != "absolute":
+            raise ValueError(
+                "Only 'absolute' position_embedding_type" + " is supported"
+            )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        input_shape = input_ids.size()
+
+        # Input embeddings.
+        inputs_embeds = self.word_embeddings(input_ids)
+
+        # Position embeddings.
+        position_embeddings = self.position_embeddings(position_ids)
+
+        token_type_ids = torch.zeros(
+            input_shape, dtype=torch.long, device=inputs_embeds.device
+        )
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+
+class BertEncoder(nn.Module):
+
+    def __init__(
+        self,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.layer = nn.ModuleList(
+            [
+                BertLayer(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layer.{layer_idx}",
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        for layer in self.layer:
+            hidden_states = layer(hidden_states, forward_batch)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: BertConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.attention = BertAttention(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            layer_id=layer_id,
+            layer_norm_eps=config.layer_norm_eps,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.intermediate = BertIntermediate(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.intermediate",
+        )
+
+        self.output = BertOutput(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            layer_norm_eps=config.layer_norm_eps,
+            quant_config=quant_config,
+            prefix=f"{prefix}.output",
+        )
+
+    def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        attn_output = self.attention(hidden_states, forward_batch)
+        intermediate_output = self.intermediate(attn_output)
+        output = self.output(intermediate_output, attn_output)
+        return output
+
+
+class BertAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        layer_norm_eps: float,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.self_attn = BertSelfAttention(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.output",
+        )
+
+        self.output = BertSelfOutput(
+            hidden_size=hidden_size,
+            layer_norm_eps=layer_norm_eps,
+            quant_config=quant_config,
+            prefix=f"{prefix}.output",
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        self_output = self.self_attn(hidden_states, forward_batch)
+        return self.output(self_output, hidden_states)
+
+
+class BertSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=self.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        self.attn = RadixAttention(
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            scaling=self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+            attn_type=AttentionType.ENCODER_ONLY,
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        output = self.attn(q, k, v, forward_batch)
+        return output
+
+
+class BertSelfOutput(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        layer_norm_eps: float,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.dense = RowParallelLinear(
+            input_size=hidden_size,
+            output_size=hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertIntermediate(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.dense = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )
+        self.intermediate_act_fn = get_act_fn(hidden_act)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_norm_eps: float,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.dense = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )
+
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertModel(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.embeddings = BertEmbedding(config)
+        self.encoder = BertEncoder(
+            config=config, quant_config=quant_config, prefix=f"encoder"
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # self.pooler = BertPooler(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+        # Your tokenized IDs
+
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+        )
+
+        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "query", "q"),
+            ("qkv_proj", "key", "k"),
+            ("qkv_proj", "value", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            name = name.replace("self", "self_attn")
+            if "pooler" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+class Contriever(BertModel):
+    pass
+
+
+EntryClass = [BertModel, Contriever]
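The new `sglang/srt/models/bert.py` above registers `BertModel` and `Contriever` as encoder-only embedding models. For readers less familiar with the packed-QKV pattern that `BertSelfAttention.forward` relies on, here is a standalone sketch in plain PyTorch (tensor-parallel size assumed to be 1, sizes and names chosen for illustration only, not taken from the diff):

```python
import torch
from torch import nn

hidden_size, num_heads = 768, 12
head_dim = hidden_size // num_heads
q_size = kv_size = num_heads * head_dim  # BERT uses no GQA, so kv_size == q_size

# One fused projection produces Q, K, and V in a single matmul.
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=True)

x = torch.randn(4, hidden_size)  # 4 token embeddings
qkv = qkv_proj(x)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)  # same split as in the diff
assert q.shape == (4, q_size) and k.shape == v.shape == (4, kv_size)
```

Fusing the three projections into one GEMM is why the checkpoint's separate `query`/`key`/`value` weights are remapped onto `qkv_proj` shards in `load_weights` via `stacked_params_mapping`.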
sglang/srt/models/deepseek.py
@@ -170,7 +170,7 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
+        final_hidden_states = fused_moe.fused_moe(
             hidden_states,
             self.w1,
             self.w2,
sglang/srt/models/deepseek_nextn.py
@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -48,7 +48,7 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import awq_dequantize
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import awq_dequantize
 
 
 class DeepseekModelNextN(nn.Module):
@@ -91,6 +91,14 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=(
+                input_embeds.device if input_embeds is not None else input_ids.device
+            ),
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -108,7 +116,7 @@ class DeepseekModelNextN(nn.Module):
 
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
        )
 
         if not forward_batch.forward_mode.is_idle():
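The two hunks above thread a small `BumpAllocator` through the NextN forward pass so that downstream layers can hand out slices of one preallocated zero buffer instead of calling `torch.zeros` per layer. A minimal sketch of the bump-allocation idea (illustrative only; this is not the actual `sglang.srt.utils.BumpAllocator` implementation, and only the constructor arguments shown in the diff are assumed):

```python
import torch


class TinyBumpAllocator:
    def __init__(self, buffer_size: int, dtype: torch.dtype, device: torch.device):
        # One upfront allocation; callers receive views into it instead of new tensors.
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        assert self._offset + size <= self._buffer.numel(), "bump buffer exhausted"
        out = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return out


alloc = TinyBumpAllocator(buffer_size=2, dtype=torch.float32, device=torch.device("cpu"))
zero_scale = alloc.allocate(1)  # reusable zero tensor, no per-layer allocation
```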
@@ -262,79 +270,75 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 )
                 weight_loader(param, loaded_weight)
 
-        if not global_server_args_dict["disable_mla"]:
-            self_attn = self.model.decoder.self_attn
-            if hasattr(self_attn.kv_b_proj, "qweight"):
-                # AWQ compatible
-                if _is_cuda:
-                    w = awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                    ).T
-                else:
-                    w = ops.awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                        0,
-                        0,
-                        0,
-                    ).T
+        self_attn = self.model.decoder.self_attn
+        if hasattr(self_attn.kv_b_proj, "qweight"):
+            # AWQ compatible
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
             else:
-                w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                torch.float8_e4m3fn,
-                torch.float8_e4m3fnuz,
-            ):
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
+        else:
+            w = self_attn.kv_b_proj.weight
+        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+        # This may affect the accuracy of fp8 model.
+        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+        ):
+            weight_block_size = self.quant_config.weight_block_size
+            if weight_block_size is not None:
+                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                if _is_hip:
+                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=w,
+                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                        input_scale=None,
+                    )
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                w, scale = block_quant_to_tensor_quant(
+                    weight, weight_scale, weight_block_size
+                )
+                self_attn.w_scale = scale
+        if w.dtype == torch.int8:
+            if hasattr(self.quant_config, "weight_block_size"):
+                # block-wise int8 need it
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if _is_hip:
-                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                            weight=w,
-                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                            input_scale=None,
-                        )
-                    else:
-                        weight = w
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                    w, scale = block_quant_to_tensor_quant(
-                        weight, weight_scale, weight_block_size
-                    )
-                    self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(
-                        weight, weight_scale, weight_block_size
-                    ).to(torch.bfloat16)
-                else:
-                    # channel-wise int8 need it
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
                         torch.bfloat16
                     )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if (
-            hasattr(self_attn.kv_b_proj, "weight_scale")
-            and self_attn.w_scale is None
-        ):
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+            else:
+                # channel-wise int8 need it
+                assert hasattr(self_attn.kv_b_proj, "weight_scale")
+                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                    torch.bfloat16
+                )
+        w_kc, w_vc = w.unflatten(
+            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
+            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+            if _is_hip:
+                self_attn.w_scale *= 2.0
 
 
 EntryClass = [DeepseekV3ForCausalLMNextN]
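The restructured weight post-processing above dequantizes `kv_b_proj` before splitting it into `w_kc`/`w_vc`. As a rough reference for what block-wise dequantization means in this context, here is a minimal sketch (reference math only, with hypothetical 128x128 tiles; it is not the `int8_block_dequant` or fp8 kernels shipped in this release, and the tiling convention is an assumption):

```python
import torch


def block_dequant_ref(w_q: torch.Tensor, scale: torch.Tensor, block: tuple) -> torch.Tensor:
    # w_q: quantized weight [N, K]; scale: one fp32 scale per (block_n x block_k) tile.
    bn, bk = block
    scale_full = scale.repeat_interleave(bn, dim=0).repeat_interleave(bk, dim=1)
    scale_full = scale_full[: w_q.shape[0], : w_q.shape[1]]  # trim padding on edge tiles
    return w_q.to(torch.float32) * scale_full


w_q = torch.randint(-128, 127, (256, 512), dtype=torch.int8)
scale = torch.rand(2, 4)  # 2x4 tiles of 128x128 covering a 256x512 weight
w = block_dequant_ref(w_q, scale, (128, 128)).to(torch.bfloat16)
```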