sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +12 -6
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +43 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +36 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +75 -68
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +2 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +26 -10
- sglang/srt/models/gpt_oss.py +0 -14
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +14 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py
CHANGED
@@ -235,8 +235,9 @@ class ReplicatedLinear(LinearBase):
                 loaded_weight = loaded_weight[:1]
             else:
                 raise ValueError(f"{loaded_weight} are not all equal")
-
-
+        assert (
+            param.size() == loaded_weight.size()
+        ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
         param.data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
sglang/srt/layers/logits_processor.py
CHANGED
@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
     hidden_states: Optional[torch.Tensor] = None
 
     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    #
+    # The log probs of output tokens. If RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, the log probs after applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+        normal_expert_mask = expert_indices >= num_experts
+        expert_indices[normal_expert_mask] = 0
+        expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
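For context, a minimal usage sketch of the new zero_experts_compute_triton helper added above. The shapes, expert counts, and the number of "zero" (identity) expert ids are illustrative assumptions, not taken from the diff; a CUDA device with Triton installed is assumed.

import torch
from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton

num_tokens, hidden_dim, top_k, num_experts = 8, 1024, 2, 64
hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float32)
# Router output: indices >= num_experts denote "zero experts" that act as identity.
expert_indices = torch.randint(0, num_experts + 8, (num_tokens, top_k), device="cuda")
expert_scales = torch.rand(num_tokens, top_k, device="cuda", dtype=torch.float32)

# Returns hidden_states scaled by the total weight routed to the identity ("zero") experts.
identity_out = zero_experts_compute_triton(
    expert_indices, expert_scales, num_experts, "identity", hidden_states
)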
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -746,19 +746,25 @@ class DeepEPMoE(EPMoE):
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
             weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=seg_indptr,
-            output_dtype=
+            output_dtype=torch.int32,
         )[0]
 
         # act_fn: swiglu
-        hidden_states = torch_npu.
-
-
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=self.w13_weight_scale.to(torch.float32),
+            activation_scale=pertoken_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=seg_indptr,
+            activate_left=True,
+            quant_mode=1,
+        )
 
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    }
+}
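The keys of this tuning file are candidate token counts (M), each mapped to a Triton launch configuration. A hedged sketch of how such an entry is typically selected at runtime; the closest-key rule shown here is an assumption for illustration, not read from this diff.

# Subset of the tuned entries above, keyed by token count M (illustrative).
configs = {
    "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128},
    "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
}

def pick_config(configs, num_tokens):
    # Choose the tuned entry whose token-count key is closest to num_tokens.
    return configs[min(configs, key=lambda k: abs(int(k) - num_tokens))]

print(pick_config(configs, num_tokens=200))  # -> the "256" entry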
sglang/srt/layers/moe/topk.py
CHANGED
@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256
+        if global_num_experts == 256:
 
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
 
-
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
@@ -347,17 +357,28 @@ def fused_topk_torch_native(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: torch.Tensor = None,
 ):
-
-
-
-
-
-
-
-
-
-
+    if correction_bias is not None:
+        n_routed_experts = gating_output.shape[-1]
+        scores = gating_output.softmax(dim=-1)
+        scores_for_choice = scores.view(
+            -1, n_routed_experts
+        ) + correction_bias.unsqueeze(0)
+        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = scores.gather(1, topk_ids)
+    else:
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+        M, _ = hidden_states.shape
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        topk_weights = F.softmax(gating_output.float(), dim=-1)
+        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
@@ -370,6 +391,7 @@ def fused_topk_cpu(
     renormalize: bool,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    correction_bias: torch.Tensor = None,
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -815,6 +837,7 @@ def select_experts(
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
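The correction-bias branch added to fused_topk_torch_native selects experts on biased scores but keeps the unbiased softmax probabilities as routing weights. A standalone sketch of that selection rule; the tensor sizes are illustrative assumptions.

import torch

def biased_topk(gating_output, correction_bias, topk):
    # Unbiased probabilities are kept as the routing weights.
    scores = gating_output.softmax(dim=-1)
    # The bias only influences which experts are chosen, not their weights.
    scores_for_choice = scores + correction_bias.unsqueeze(0)
    topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_ids)
    return topk_weights, topk_ids

gating_output = torch.randn(4, 16)   # 4 tokens, 16 experts (illustrative)
correction_bias = torch.zeros(16)
weights, ids = biased_topk(gating_output, correction_bias, topk=2)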
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
CHANGED
@@ -93,7 +93,7 @@ def _maybe_compile_deep_gemm_one_type_all(
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may
+            "It may take a long time (typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
             "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig):
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             if re.fullmatch(regex_str, prefix):
                 return True
+
+            # Check if the last part of the excluded pattern is contained in the last part of the prefix
+            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+            pattern_last_part = pattern.split(".")[-1]
+            prefix_last_part = prefix.split(".")[-1]
+            if pattern_last_part in prefix_last_part:
+                return True
         return False
 
     def get_quant_method(
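A minimal sketch of the fused-module fallback added to the exclusion check above, restated as a free function for illustration; the module names in the example are assumptions.

def last_part_excluded(pattern, prefix):
    # A fused module such as "fused_qkv_a_proj_with_mqa" should match an
    # excluded pattern ending in "q_a_proj" even though the full regex does not.
    return pattern.split(".")[-1] in prefix.split(".")[-1]

assert last_part_excluded(
    "model.layers.*.self_attn.q_a_proj",
    "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa",
)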
sglang/srt/layers/quantization/mxfp4.py
CHANGED
@@ -66,10 +66,15 @@ _is_hip = is_hip()
 
 if _is_hip:
     # import aiter
-
-
-
-
+    try:
+        from aiter import ActivationType, QuantType, dtypes
+        from aiter.fused_moe import fused_moe
+        from aiter.ops.triton.quant import dynamic_mxfp4_quant
+        from aiter.utility.fp4_utils import e8m0_shuffle
+    except ImportError as err:
+        ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
+            e8m0_shuffle
+        ) = err
 
 
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -77,6 +77,19 @@ def is_layer_skipped(
         )
     else:
         is_skipped = prefix in ignored_layers
+        if "gate_up_proj" in prefix:
+            prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
+            prefix_up = prefix.replace("gate_up_proj", "up_proj")
+            if prefix_gate in ignored_layers and prefix_up in ignored_layers:
+                is_skipped = True
+        elif "experts" in prefix:
+            is_skipped = any(
+                [
+                    prefix in layer_name
+                    for layer_name in ignored_layers
+                    if "experts" in layer_name
+                ]
+            )
 
     assert is_skipped is not None
     return is_skipped
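A small sketch of the new ignored-layer handling above, restated outside the original function for illustration; the layer names are assumptions.

def is_prefix_ignored(prefix, ignored_layers):
    is_skipped = prefix in ignored_layers
    if "gate_up_proj" in prefix:
        # A fused gate_up_proj is skipped only when both halves were ignored.
        gate = prefix.replace("gate_up_proj", "gate_proj")
        up = prefix.replace("gate_up_proj", "up_proj")
        if gate in ignored_layers and up in ignored_layers:
            is_skipped = True
    elif "experts" in prefix:
        # An expert layer is skipped if any ignored expert entry contains it.
        is_skipped = any(prefix in name for name in ignored_layers if "experts" in name)
    return is_skipped

ignored = {"model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"}
assert is_prefix_ignored("model.layers.0.mlp.gate_up_proj", ignored)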
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
     def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
-        params_dict["input_offset"] = torch.empty(1, dtype=
+        params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
         return params_dict
 
     @staticmethod
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
                 x,
-                layer.
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
                 torch.qint8,
                 -1,
-
+                False,
             )
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
             layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
         )
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1876,7 +1876,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
@@ -1899,6 +1899,33 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed
 
 
+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
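For reference, a self-contained sketch of the rotate_half-style rotation that apply_rotary_pos_emb_native applies on non-NPU devices; the (batch, heads, seq, head_dim) shapes and unsqueeze_dim=1 below are illustrative assumptions.

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_native(q, k, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the head dimension, then rotate q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, head_dim), illustrative
k = torch.randn(1, 8, 16, 64)
cos = torch.randn(1, 16, 64)
sin = torch.randn(1, 16, 64)
q_out, k_out = apply_rotary_native(q, k, cos, sin)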
sglang/srt/layers/sampler.py
CHANGED
@@ -27,6 +27,7 @@ if is_cuda():
 logger = logging.getLogger(__name__)
 
 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
 
 
 class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
         else:
+            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            if return_logprob and RETURN_ORIGINAL_LOGPROB:
+                logprobs = torch.softmax(logits, dim=-1)
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):
 
             if return_logprob:
                 # clamp to avoid -inf
-
+                if RETURN_ORIGINAL_LOGPROB:
+                    logprobs = torch.log(logprobs).clamp(
+                        min=torch.finfo(logprobs.dtype).min
+                    )
+                else:
+                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
     return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
 
 
-def get_top_logprobs(
+def get_top_logprobs(
+    logprobs: torch.Tensor,
+    top_logprobs_nums: List[int],
+):
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
     for i, k in enumerate(top_logprobs_nums):
         output_top_logprobs_val.append(values[i][:k])
         output_top_logprobs_idx.append(indices[i][:k])
-
+
+    return (
+        output_top_logprobs_val,
+        output_top_logprobs_idx,
+    )
 
 
-def get_token_ids_logprobs(
+def get_token_ids_logprobs(
+    logprobs: torch.Tensor,
+    token_ids_logprobs: List[List[int]],
+):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_val.append([])
             output_token_ids_logprobs_idx.append([])
 
-    return
+    return (
+        output_token_ids_logprobs_val,
+        output_token_ids_logprobs_idx,
+    )
 
 
 def apply_custom_logit_processor(
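A minimal sketch of what the new RETURN_ORIGINAL_LOGPROB switch changes: whether returned log probs come from the raw logits or from the temperature-scaled distribution the sampler actually draws from. The tensor values and shapes are illustrative assumptions.

import torch

logits = torch.randn(2, 32000)            # (batch, vocab), illustrative
temperatures = torch.tensor([[0.7], [1.0]])

# RETURN_ORIGINAL_LOGPROB=1: log probs of the unscaled logits.
original_logprobs = torch.nn.functional.log_softmax(logits, dim=-1)

# RETURN_ORIGINAL_LOGPROB=0 (default): log probs after temperature scaling,
# i.e. of the distribution used for sampling, clamped to avoid -inf.
scaled_probs = torch.softmax(logits / temperatures, dim=-1)
sampled_logprobs = torch.log(scaled_probs).clamp(min=torch.finfo(scaled_probs.dtype).min)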