sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +12 -0
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +24 -14
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  29. sglang/srt/layers/moe/ep_moe/layer.py +12 -6
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/topk.py +35 -12
  32. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  33. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  34. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  35. sglang/srt/layers/quantization/mxfp4.py +9 -4
  36. sglang/srt/layers/quantization/utils.py +13 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  38. sglang/srt/layers/rotary_embedding.py +28 -1
  39. sglang/srt/layers/sampler.py +29 -5
  40. sglang/srt/managers/cache_controller.py +62 -96
  41. sglang/srt/managers/detokenizer_manager.py +43 -2
  42. sglang/srt/managers/io_struct.py +27 -0
  43. sglang/srt/managers/mm_utils.py +5 -1
  44. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  45. sglang/srt/managers/scheduler.py +36 -2
  46. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  47. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  48. sglang/srt/managers/tokenizer_manager.py +86 -39
  49. sglang/srt/mem_cache/chunk_cache.py +1 -1
  50. sglang/srt/mem_cache/hicache_storage.py +20 -3
  51. sglang/srt/mem_cache/hiradix_cache.py +75 -68
  52. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  53. sglang/srt/mem_cache/memory_pool.py +4 -0
  54. sglang/srt/mem_cache/memory_pool_host.py +2 -4
  55. sglang/srt/mem_cache/radix_cache.py +5 -4
  56. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  57. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
  58. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  59. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  60. sglang/srt/model_executor/model_runner.py +5 -4
  61. sglang/srt/model_loader/loader.py +15 -24
  62. sglang/srt/model_loader/utils.py +12 -0
  63. sglang/srt/models/deepseek_v2.py +26 -10
  64. sglang/srt/models/gpt_oss.py +0 -14
  65. sglang/srt/models/llama_eagle3.py +4 -0
  66. sglang/srt/models/longcat_flash.py +1015 -0
  67. sglang/srt/models/longcat_flash_nextn.py +691 -0
  68. sglang/srt/models/qwen2.py +26 -3
  69. sglang/srt/models/qwen2_5_vl.py +65 -41
  70. sglang/srt/models/qwen2_moe.py +22 -2
  71. sglang/srt/models/transformers.py +1 -1
  72. sglang/srt/multimodal/processors/base_processor.py +4 -2
  73. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  74. sglang/srt/server_args.py +112 -55
  75. sglang/srt/speculative/eagle_worker.py +28 -8
  76. sglang/srt/utils.py +14 -0
  77. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  78. sglang/version.py +1 -1
  79. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
  80. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
  81. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  82. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
@@ -235,8 +235,9 @@ class ReplicatedLinear(LinearBase):
                 loaded_weight = loaded_weight[:1]
             else:
                 raise ValueError(f"{loaded_weight} are not all equal")
-
-        assert param.size() == loaded_weight.size()
+        assert (
+            param.size() == loaded_weight.size()
+        ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
         param.data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
     hidden_states: Optional[torch.Tensor] = None
 
     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    # The logprobs of the next tokens. shape: [#seq]
+    # The log probs of output tokens. If RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, after applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
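For orientation, a minimal usage sketch of the new helper; shapes, device, and expert counts are illustrative assumptions, not values from the diff. Note that the function mutates expert_indices and expert_scales in place for entries routed to "zero" experts.

    import torch
    from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton

    # 4 tokens, hidden size 512, top-2 routing over 8 real experts;
    # ids >= num_experts denote the extra "zero" (identity) experts.
    hidden_states = torch.randn(4, 512, device="cuda", dtype=torch.float32)
    expert_indices = torch.randint(0, 16, (4, 2), device="cuda")
    expert_scales = torch.rand(4, 2, device="cuda", dtype=torch.float32)

    identity_out = zero_experts_compute_triton(
        expert_indices,
        expert_scales,
        num_experts=8,
        zero_expert_type="identity",
        hidden_states=hidden_states,
    )
    # identity_out accumulates scale * hidden_state for the identity experts;
    # how it is combined with the routed-expert output is up to the caller.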
@@ -746,19 +746,25 @@ class DeepEPMoE(EPMoE):
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
             weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=seg_indptr,
-            output_dtype=output_dtype,
+            output_dtype=torch.int32,
         )[0]
 
         # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=self.w13_weight_scale.to(torch.float32),
+            activation_scale=pertoken_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=seg_indptr,
+            activate_left=True,
+            quant_mode=1,
+        )
 
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
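The separate swiglu and dynamic-quant ops are replaced by a single fused op that also folds in dequantization of the int32 grouped-matmul output. Roughly, for a single group, the math corresponds to the sketch below; the argument semantics (per-channel weight scale, per-token activation scale, gate half first) are my reading of the call site, not the torch_npu documentation.

    import torch
    import torch.nn.functional as F

    def dequant_swiglu_quant_reference(x_int32, weight_scale, per_token_scale):
        # Dequantize the int32 GEMM output: per-output-channel weight scale
        # times per-token activation scale (assumed layout: [tokens, 2 * d]).
        x = x_int32.to(torch.float32) * weight_scale * per_token_scale.unsqueeze(-1)
        gate, up = x.chunk(2, dim=-1)      # activate_left=True: gate half comes first
        y = F.silu(gate) * up              # swiglu
        # Dynamic per-token int8 re-quantization for the down_proj grouped matmul.
        out_scale = (y.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
        y_int8 = torch.clamp(torch.round(y / out_scale), -128, 127).to(torch.int8)
        return y_int8, out_scale.squeeze(-1)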
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    }
+}
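The new file is a Triton tuning table for the fused MoE kernel (E=129, N=352, B200, fp8_w8a8). The top-level keys are candidate batch sizes; at run time the launcher typically picks the entry whose key is nearest to the actual token count. A sketch of that lookup, assuming the JSON has been loaded into a dict (this follows the usual fused-MoE config convention, not a verbatim copy of sglang's selection code):

    import json

    def pick_config(configs: dict, num_tokens: int) -> dict:
        # Choose the tuning entry whose batch-size key is closest to num_tokens.
        best_key = min(configs, key=lambda k: abs(int(k) - num_tokens))
        return configs[best_key]

    with open("E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json") as f:
        configs = json.load(f)

    print(pick_config(configs, num_tokens=200))   # nearest key is "256"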
@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is True:
+        if global_num_experts == 256:
 
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
 
-            return torch_npu.npu_moe_gating_top_k(
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
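When renormalize is set, only the routed experts' weights are rescaled to sum to 1; with a fused shared expert, the last column is excluded from the sum (but divided by the same factor), matching the branch above. A small numeric illustration with made-up weights:

    import torch

    topk_weights = torch.tensor([[0.5, 0.3, 0.2]])   # last column: fused shared expert
    num_fused_shared_experts = 1

    topk_weights_sum = (
        topk_weights.sum(dim=-1, keepdim=True)
        if num_fused_shared_experts == 0
        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    )
    print(topk_weights / topk_weights_sum)   # tensor([[0.6250, 0.3750, 0.2500]])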
@@ -347,17 +357,28 @@ def fused_topk_torch_native(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: torch.Tensor = None,
 ):
-    assert (
-        hidden_states.shape[0] == gating_output.shape[0]
-    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
-    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if correction_bias is not None:
+        n_routed_experts = gating_output.shape[-1]
+        scores = gating_output.softmax(dim=-1)
+        scores_for_choice = scores.view(
+            -1, n_routed_experts
+        ) + correction_bias.unsqueeze(0)
+        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = scores.gather(1, topk_ids)
+    else:
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+        M, _ = hidden_states.shape
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        topk_weights = F.softmax(gating_output.float(), dim=-1)
+        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
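The new correction_bias branch resembles the bias-corrected routing used by DeepSeek-style models: the bias only influences which experts are selected, while the gathered weights still come from the unbiased softmax. A small illustration with made-up numbers:

    import torch

    gating_output = torch.tensor([[2.0, 1.9, 0.1, 0.0]])     # 1 token, 4 experts
    correction_bias = torch.tensor([0.0, 0.0, 0.5, 0.0])     # e.g. a load-balancing bias

    scores = gating_output.softmax(dim=-1)
    scores_for_choice = scores + correction_bias.unsqueeze(0)
    topk_ids = torch.topk(scores_for_choice, k=2, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_ids)   # weights taken from the unbiased scores

    # The bias pulls expert 2 into the top-2 even though its raw score is tiny.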
@@ -370,6 +391,7 @@ def fused_topk_cpu(
     renormalize: bool,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    correction_bias: torch.Tensor = None,
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -815,6 +837,7 @@ def select_experts(
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
@@ -93,7 +93,7 @@ def _maybe_compile_deep_gemm_one_type_all(
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may takes a long time (typically 10-20 mins) "
+            "It may take a long time (typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
             "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -11,9 +11,6 @@ def _compute_enable_deep_gemm():
     sm_version = get_device_sm()
     if sm_version < 90:
         return False
-    # TODO fix deepgemm cu129 fp8 issue
-    if torch.version.cuda == "12.9":
-        return False
 
     try:
         import deep_gemm
@@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig):
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             if re.fullmatch(regex_str, prefix):
                 return True
+
+            # Check if the last part of the excluded pattern is contained in the last part of the prefix
+            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+            pattern_last_part = pattern.split(".")[-1]
+            prefix_last_part = prefix.split(".")[-1]
+            if pattern_last_part in prefix_last_part:
+                return True
         return False
 
     def get_quant_method(
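The extra check catches fused modules whose runtime name no longer matches the exclusion regex exactly. For example (layer names here are illustrative, in the style of the fused_qkv_a_proj_with_mqa case mentioned in the comment):

    pattern = "model.layers.*.self_attn.kv_a_proj_with_mqa"        # excluded in the checkpoint config
    prefix = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"  # fused module name at runtime

    pattern_last_part = pattern.split(".")[-1]    # "kv_a_proj_with_mqa"
    prefix_last_part = prefix.split(".")[-1]      # "fused_qkv_a_proj_with_mqa"
    print(pattern_last_part in prefix_last_part)  # True -> the fused module is excluded as well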
@@ -66,10 +66,15 @@ _is_hip = is_hip()
 
 if _is_hip:
     # import aiter
-    from aiter import ActivationType, QuantType, dtypes
-    from aiter.fused_moe import fused_moe
-    from aiter.ops.triton.quant import dynamic_mxfp4_quant
-    from aiter.utility.fp4_utils import e8m0_shuffle
+    try:
+        from aiter import ActivationType, QuantType, dtypes
+        from aiter.fused_moe import fused_moe
+        from aiter.ops.triton.quant import dynamic_mxfp4_quant
+        from aiter.utility.fp4_utils import e8m0_shuffle
+    except ImportError as err:
+        ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
+            e8m0_shuffle
+        ) = err
 
 
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
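Binding the ImportError to the imported names is a deferred-import trick: the module still imports on ROCm builds without aiter, and the original error only surfaces if that code path is actually exercised. A generic sketch of the pattern (module and symbol names here are hypothetical):

    try:
        from optional_backend import fast_kernel   # hypothetical optional dependency
    except ImportError as err:
        fast_kernel = err                          # module import still succeeds

    def run(x):
        # Surface the original ImportError lazily, only when the optional path is used.
        if isinstance(fast_kernel, ImportError):
            raise fast_kernel
        return fast_kernel(x)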
@@ -77,6 +77,19 @@ def is_layer_skipped(
         )
     else:
         is_skipped = prefix in ignored_layers
+        if "gate_up_proj" in prefix:
+            prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
+            prefix_up = prefix.replace("gate_up_proj", "up_proj")
+            if prefix_gate in ignored_layers and prefix_up in ignored_layers:
+                is_skipped = True
+        elif "experts" in prefix:
+            is_skipped = any(
+                [
+                    prefix in layer_name
+                    for layer_name in ignored_layers
+                    if "experts" in layer_name
+                ]
+            )
 
     assert is_skipped is not None
     return is_skipped
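The gate_up_proj branch bridges the gap between checkpoint names (separate gate_proj and up_proj entries in the ignore list) and the fused runtime module. For instance (layer names are illustrative):

    ignored_layers = {
        "model.layers.0.mlp.gate_proj",
        "model.layers.0.mlp.up_proj",
    }
    prefix = "model.layers.0.mlp.gate_up_proj"   # fused module name at runtime

    prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
    prefix_up = prefix.replace("gate_up_proj", "up_proj")
    print(prefix_gate in ignored_layers and prefix_up in ignored_layers)  # True -> skip quantization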
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
     def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
-        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
         return params_dict
 
     @staticmethod
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
                 x,
-                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
                 torch.qint8,
                 -1,
-                True,
+                False,
             )
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
             layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
         )
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
@@ -1876,7 +1876,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
@@ -1899,6 +1899,33 @@
     return q_embed, k_embed
 
 
+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
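Because the backend is chosen once at import time and rebound to the original apply_rotary_pos_emb name, existing call sites need no changes. A generic sketch of this import-time dispatch pattern (names and the placeholder math are illustrative, not from the diff):

    import importlib.util

    def rope_native(q, k, cos, sin):
        return q * cos, k * sin             # placeholder math, not the real rotary formula

    def rope_fused(q, k, cos, sin):
        return rope_native(q, k, cos, sin)  # stand-in for a vendor-fused kernel

    # Pick the implementation once; callers always use `apply_rope`.
    _has_npu = importlib.util.find_spec("torch_npu") is not None
    apply_rope = rope_fused if _has_npu else rope_native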
@@ -27,6 +27,7 @@ if is_cuda():
 logger = logging.getLogger(__name__)
 
 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
 
 
 class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
         else:
+            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            if return_logprob and RETURN_ORIGINAL_LOGPROB:
+                logprobs = torch.softmax(logits, dim=-1)
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):
 
             if return_logprob:
                 # clamp to avoid -inf
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
+                if RETURN_ORIGINAL_LOGPROB:
+                    logprobs = torch.log(logprobs).clamp(
+                        min=torch.finfo(logprobs.dtype).min
+                    )
+                else:
+                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
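The flag is read with get_bool_env_var, so it is toggled through the environment (for example, setting RETURN_ORIGINAL_LOGPROB=true before launching the server). What it controls is whether next_token_logprobs reflects the distribution before or after temperature scaling, roughly:

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5]])
    temperature = 0.5

    original_logprobs = torch.log_softmax(logits, dim=-1)                   # RETURN_ORIGINAL_LOGPROB=true
    temperature_logprobs = torch.log_softmax(logits / temperature, dim=-1)  # default behaviour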
@@ -201,7 +212,10 @@
     return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
 
 
-def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+def get_top_logprobs(
+    logprobs: torch.Tensor,
+    top_logprobs_nums: List[int],
+):
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
     for i, k in enumerate(top_logprobs_nums):
         output_top_logprobs_val.append(values[i][:k])
         output_top_logprobs_idx.append(indices[i][:k])
-    return output_top_logprobs_val, output_top_logprobs_idx
+
+    return (
+        output_top_logprobs_val,
+        output_top_logprobs_idx,
+    )
 
 
-def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
+def get_token_ids_logprobs(
+    logprobs: torch.Tensor,
+    token_ids_logprobs: List[List[int]],
+):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
             output_token_ids_logprobs_val.append([])
             output_token_ids_logprobs_idx.append([])
 
-    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+    return (
+        output_token_ids_logprobs_val,
+        output_token_ids_logprobs_idx,
+    )
 
 
 def apply_custom_logit_processor(