sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
  self.moe_tp_rank = get_moe_tensor_parallel_rank()
  assert num_experts % self.moe_ep_size == 0
  self.num_local_experts = num_experts // self.moe_ep_size
+ self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+ self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
  if self.moe_ep_size > 1:
  # TODO(ch-wan): support shared experts fusion
  # Create a tensor of size num_experts filled with -1
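The two added attributes give each expert-parallel rank a contiguous slice of the global expert range. A minimal sketch of the arithmetic (plain Python; the values are illustrative, not taken from any real configuration):

```python
# Hypothetical sizes: 8 global experts split across 4 expert-parallel (EP) ranks.
num_experts = 8
moe_ep_size = 4

num_local_experts = num_experts // moe_ep_size  # 2 experts per rank
for moe_ep_rank in range(moe_ep_size):
    start_expert_id = moe_ep_rank * num_local_experts
    end_expert_id = start_expert_id + num_local_experts - 1
    print(f"rank {moe_ep_rank}: experts {start_expert_id}..{end_expert_id}")
# rank 0: experts 0..1, rank 1: experts 2..3, rank 2: experts 4..5, rank 3: experts 6..7
```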
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):

  if (
  "compressed" in self.quant_method.__class__.__name__.lower()
- and param.data[expert_id] != 1
- and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+ or "w4afp8" in self.quant_config.get_name()
+ and (param.data[expert_id] != 1).any()
+ and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
  ):
  raise ValueError(
  "input_scales of w1 and w3 of a layer "
@@ -0,0 +1,87 @@
+ from __future__ import annotations
+
+ from typing import Tuple
+
+ import torch
+ import triton
+
+ from sglang.srt.utils import is_cuda, is_hip
+
+ _is_cuda = is_cuda()
+ _is_hip = is_hip()
+
+ if _is_cuda or _is_hip:
+ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+ def moe_align_block_size(
+ topk_ids: torch.Tensor, block_size: int, num_experts: int
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Aligns the token distribution across experts to be compatible with block
+ size for matrix multiplication.
+
+ Parameters:
+ - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+ top-k expert indices for each token.
+ - block_size: The block size used in block matrix multiplication.
+ - num_experts: The total number of experts.
+
+ Returns:
+ - sorted_token_ids: A tensor containing the sorted token indices according
+ to their allocated expert.
+ - expert_ids: A tensor indicating the assigned expert index for each block.
+ - num_tokens_post_padded: The total number of tokens after padding,
+ ensuring divisibility by block_size.
+
+ This function pads the number of tokens that each expert needs to process
+ so that it is divisible by block_size.
+ Padding ensures that during block matrix multiplication, the dimensions
+ align correctly.
+
+ Example:
+ Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+ block_size = 4, and num_experts = 4:
+ - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+ with each expert needing to process 3 tokens.
+ - As block_size is 4, we pad 1 token for each expert.
+ - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+ - Then append padding tokens [12, 12, 12, 12] for each block.
+ - After sorting by expert index, we obtain token_ids
+ [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+ Tokens 12 are non-existent (padding) and are ignored in
+ the subsequent matrix multiplication.
+ - The padding ensures that the total number of tokens is now divisible
+ by block_size for proper block matrix operations.
+ """
+ max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+ sorted_ids = torch.empty(
+ (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+ )
+ max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+ expert_ids = torch.empty(
+ (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+ )
+ num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+
+ # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
+ cumsum_buffer = torch.empty(
+ (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+ )
+
+ # Threshold based on benchmark results
+ fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+ if not fuse_sorted_ids_padding:
+ sorted_ids.fill_(topk_ids.numel())
+
+ sgl_moe_align_block_size(
+ topk_ids,
+ num_experts + 1,
+ block_size,
+ sorted_ids,
+ expert_ids,
+ num_tokens_post_pad,
+ cumsum_buffer,
+ fuse_sorted_ids_padding,
+ )
+ return sorted_ids, expert_ids, num_tokens_post_pad
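The new module delegates the actual work to the `sgl_kernel` op, but the padding scheme described in the docstring can be reproduced in plain PyTorch. The following is a rough, unoptimized reference sketch of the sort-by-expert-and-pad-to-block_size idea (my own illustration, not the kernel's implementation):

```python
import torch


def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """Unoptimized illustration of the padding scheme from the docstring above."""
    flat = topk_ids.flatten()
    pad_value = flat.numel()  # padding entries point one past the last real token
    sorted_chunks, expert_chunks = [], []
    for expert in range(num_experts):
        token_idx = (flat == expert).nonzero(as_tuple=True)[0].to(torch.int32)
        if token_idx.numel() == 0:
            continue
        # Round this expert's token count up to a multiple of block_size.
        padded_len = ((token_idx.numel() + block_size - 1) // block_size) * block_size
        pad = torch.full((padded_len - token_idx.numel(),), pad_value, dtype=torch.int32)
        sorted_chunks.append(torch.cat([token_idx, pad]))
        expert_chunks.append(torch.full((padded_len // block_size,), expert, dtype=torch.int32))
    sorted_ids = torch.cat(sorted_chunks)
    expert_ids = torch.cat(expert_chunks)
    num_tokens_post_pad = torch.tensor([sorted_ids.numel()], dtype=torch.int32)
    return sorted_ids, expert_ids, num_tokens_post_pad


# Mirrors the docstring example (experts renumbered to start at 0).
topk_ids = torch.tensor([[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]])
sorted_ids, expert_ids, n_padded = moe_align_block_size_reference(topk_ids, block_size=4, num_experts=4)
print(sorted_ids.tolist())  # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
print(expert_ids.tolist(), n_padded.item())  # [0, 1, 2, 3] 16
```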
@@ -304,12 +304,12 @@ class TopK(CustomOp):
  global_num_experts = router_logits.shape[-1]

  # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
- if global_num_experts == 256 and self.topk_config.renormalize is True:
+ if global_num_experts == 256:

  routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
  router_logits = router_logits.to(torch.float32)

- return torch_npu.npu_moe_gating_top_k(
+ topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
  router_logits,
  k=self.topk_config.top_k,
  bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
  routed_scaling_factor=routed_scaling_factor,
  eps=float(1e-20),
  )
+
+ if self.topk_config.renormalize:
+ topk_weights_sum = (
+ topk_weights.sum(dim=-1, keepdim=True)
+ if self.topk_config.num_fused_shared_experts == 0
+ else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+ )
+ topk_weights = topk_weights / topk_weights_sum
+
+ return StandardTopKOutput(topk_weights, topk_ids, _)
  else:
  self.topk_config.torch_native = True
  return select_experts(
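The renormalization added in this hunk divides the routed weights by their sum; when a shared expert is fused in as the last column, that column is excluded from the sum so the routed weights renormalize among themselves. A tiny sketch of the two cases with an illustrative tensor:

```python
import torch

topk_weights = torch.tensor([[0.5, 0.3, 0.2]])

# No fused shared expert: normalize over all selected experts.
w_all = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# Shared expert fused in as the last column: divide by the sum of the routed
# columns only, so the routed weights sum to 1 relative to each other.
w_shared = topk_weights / topk_weights[:, :-1].sum(dim=-1, keepdim=True)
print(w_all)     # tensor([[0.5000, 0.3000, 0.2000]])
print(w_shared)  # tensor([[0.6250, 0.3750, 0.2500]])
```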
@@ -347,17 +357,28 @@ def fused_topk_torch_native(
  gating_output: torch.Tensor,
  topk: int,
  renormalize: bool,
+ correction_bias: torch.Tensor = None,
  ):
- assert (
- hidden_states.shape[0] == gating_output.shape[0]
- ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
- M, _ = hidden_states.shape
- topk_weights = torch.empty(
- M, topk, dtype=torch.float32, device=hidden_states.device
- )
- topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
- topk_weights = F.softmax(gating_output.float(), dim=-1)
- topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+ if correction_bias is not None:
+ n_routed_experts = gating_output.shape[-1]
+ scores = gating_output.softmax(dim=-1)
+ scores_for_choice = scores.view(
+ -1, n_routed_experts
+ ) + correction_bias.unsqueeze(0)
+ topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+ topk_weights = scores.gather(1, topk_ids)
+ else:
+ assert (
+ hidden_states.shape[0] == gating_output.shape[0]
+ ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+ M, _ = hidden_states.shape
+ topk_weights = torch.empty(
+ M, topk, dtype=torch.float32, device=hidden_states.device
+ )
+ topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+ topk_weights = F.softmax(gating_output.float(), dim=-1)
+ topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+
  if renormalize:
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  return topk_weights, topk_ids
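The new `correction_bias` branch selects experts by bias-adjusted scores but keeps the unbiased softmax probabilities as routing weights, so the bias steers which experts are chosen without changing how much weight they receive. A self-contained sketch with made-up numbers:

```python
import torch

gating_output = torch.tensor([[1.0, 2.0, 0.5, 0.1]])   # router logits for one token
correction_bias = torch.tensor([0.0, -5.0, 0.0, 0.0])  # e.g. a load-balancing bias (illustrative)
topk = 2

scores = gating_output.softmax(dim=-1)
scores_for_choice = scores + correction_bias.unsqueeze(0)

# Experts are chosen by the biased scores...
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
# ...but weighted by the original, unbiased probabilities.
topk_weights = scores.gather(1, topk_ids)
print(topk_ids, topk_weights)
```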
@@ -370,6 +391,7 @@ def fused_topk_cpu(
  renormalize: bool,
  num_token_non_padded: Optional[torch.Tensor] = None,
  expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+ correction_bias: torch.Tensor = None,
  ):
  topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
  hidden_states=hidden_states,
@@ -815,6 +837,7 @@ def select_experts(
  gating_output=router_logits,
  topk=top_k,
  renormalize=renormalize,
+ correction_bias=correction_bias,
  )
  elif custom_routing_function is None:
  assert not apply_routed_scaling_factor_on_output, "Not implemented"
@@ -93,7 +93,7 @@ def _maybe_compile_deep_gemm_one_type_all(
  if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
  logger.warning(
  "Entering DeepGEMM JIT Pre-Compile session. "
- "It may takes a long time (typically 10-20 mins) "
+ "It may take a long time (typically 10-20 mins) "
  "if you have not run `sglang.compile_deep_gemm`. "
  "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
  " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
  kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
  )

+ old_compile_mode = deep_gemm.get_compile_mode()
+ deep_gemm.set_compile_mode(1)
  # TODO can use multi thread
  for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
  executor.execute(m=m)
+ deep_gemm.set_compile_mode(old_compile_mode)
+
+ # clean up input buffers
+ torch.cuda.current_stream().synchronize()
+ del executor
+ torch.cuda.empty_cache()


  class _BaseWarmupExecutor:
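The warmup now saves the DeepGEMM compile mode, forces it for the loop, and restores it afterwards. If that pattern spreads, a context manager built on the same `get_compile_mode`/`set_compile_mode` calls would make the restore exception-safe; a hedged sketch (the wrapper name is hypothetical):

```python
from contextlib import contextmanager


@contextmanager
def temporary_compile_mode(deep_gemm_module, mode: int):
    """Temporarily switch the DeepGEMM compile mode, restoring it even on error."""
    old_mode = deep_gemm_module.get_compile_mode()
    deep_gemm_module.set_compile_mode(mode)
    try:
        yield
    finally:
        deep_gemm_module.set_compile_mode(old_mode)

# Usage (hypothetical): with temporary_compile_mode(deep_gemm, 1): run_warmup()
```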
@@ -11,9 +11,6 @@ def _compute_enable_deep_gemm():
  sm_version = get_device_sm()
  if sm_version < 90:
  return False
- # TODO fix deepgemm cu129 fp8 issue
- if torch.version.cuda == "12.9":
- return False

  try:
  import deep_gemm
@@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig):
  regex_str = pattern.replace(".", r"\.").replace("*", r".*")
  if re.fullmatch(regex_str, prefix):
  return True
+
+ # Check if the last part of the excluded pattern is contained in the last part of the prefix
+ # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+ pattern_last_part = pattern.split(".")[-1]
+ prefix_last_part = prefix.split(".")[-1]
+ if pattern_last_part in prefix_last_part:
+ return True
  return False

  def get_quant_method(
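The added fallback catches fused modules whose name embeds an excluded sub-module name, which the wildcard regex alone would not match. A standalone illustration of the two checks (the helper name and module paths are mine):

```python
import re


def matches_exclusion(pattern: str, prefix: str) -> bool:
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    if re.fullmatch(regex_str, prefix):
        return True
    # Fallback: the last component of the pattern appears inside the last
    # component of the prefix, e.g. "kv_a_proj_with_mqa" inside "fused_qkv_a_proj_with_mqa".
    return pattern.split(".")[-1] in prefix.split(".")[-1]


print(matches_exclusion(
    "model.layers.*.kv_a_proj_with_mqa",
    "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa",
))  # True, via the fallback; the regex alone fails on the fused name
```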
@@ -66,10 +66,15 @@ _is_hip = is_hip()

  if _is_hip:
  # import aiter
- from aiter import ActivationType, QuantType, dtypes
- from aiter.fused_moe import fused_moe
- from aiter.ops.triton.quant import dynamic_mxfp4_quant
- from aiter.utility.fp4_utils import e8m0_shuffle
+ try:
+ from aiter import ActivationType, QuantType, dtypes
+ from aiter.fused_moe import fused_moe
+ from aiter.ops.triton.quant import dynamic_mxfp4_quant
+ from aiter.utility.fp4_utils import e8m0_shuffle
+ except ImportError as err:
+ ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
+ e8m0_shuffle
+ ) = err


  def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -77,6 +77,19 @@ def is_layer_skipped(
  )
  else:
  is_skipped = prefix in ignored_layers
+ if "gate_up_proj" in prefix:
+ prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
+ prefix_up = prefix.replace("gate_up_proj", "up_proj")
+ if prefix_gate in ignored_layers and prefix_up in ignored_layers:
+ is_skipped = True
+ elif "experts" in prefix:
+ is_skipped = any(
+ [
+ prefix in layer_name
+ for layer_name in ignored_layers
+ if "experts" in layer_name
+ ]
+ )

  assert is_skipped is not None
  return is_skipped
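The new branches let a fused `gate_up_proj` module be skipped when both of its unfused halves are listed, and let expert sub-layers match by substring. A quick standalone illustration (the layer names are hypothetical):

```python
ignored_layers = {
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.up_proj",
}

prefix = "model.layers.0.mlp.gate_up_proj"
is_skipped = prefix in ignored_layers  # False: the fused name itself is not listed
if "gate_up_proj" in prefix:
    prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
    prefix_up = prefix.replace("gate_up_proj", "up_proj")
    if prefix_gate in ignored_layers and prefix_up in ignored_layers:
        is_skipped = True  # both unfused halves are ignored, so skip the fused module
print(is_skipped)  # True
```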
@@ -1,12 +1,14 @@
  from __future__ import annotations

  import logging
- from typing import TYPE_CHECKING, Any, Dict, List, Optional
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

  import torch
  from torch.nn import Module
  from torch.nn.parameter import Parameter

+ from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+ from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
  from sglang.srt.layers.quantization.base_config import (
  FusedMoEMethodBase,
  QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
  from sglang.srt.layers.linear import LinearBase
  from sglang.srt.layers.moe.ep_moe.layer import EPMoE
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.managers.schedule_batch import global_server_args_dict

  if isinstance(layer, LinearBase):
  if is_layer_skipped(prefix, self.ignored_layers):
  return UnquantizedLinearMethod()
  return Fp8LinearMethod(self)
- elif isinstance(layer, EPMoE):
+ elif isinstance(layer, FusedMoE):
  return W4AFp8MoEMethod(self)
  return None

@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
  return []


- class W4AFp8MoEMethod(FusedMoEMethodBase):
+ def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
+ """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+ s_shape = scales.shape
+ # Reshape to separate groups of 4
+ alignment = 4 if s_shape[2] % 4 == 0 else 1
+ scales_interleaved = scales.reshape(
+ s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
+ )
+ # Permute dimensions to interleave
+ scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+ # Reshape back to original dimensions but with interleaved values
+ scales_interleaved = scales_interleaved.reshape(
+ s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
+ )
+ return scales_interleaved.contiguous()
+

+ class W4AFp8MoEMethod(FusedMoEMethodBase):
  def __init__(self, quant_config: W4AFp8Config):
  self.quant_config = quant_config

@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

  return

- def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
- """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
- s_shape = scales.shape
- # Reshape to separate groups of 4
- scales_interleaved = scales.reshape(
- s_shape[0], s_shape[1], (s_shape[2] // 4), 4
- )
- # Permute dimensions to interleave
- scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
- # Reshape back to original dimensions but with interleaved values
- scales_interleaved = scales_interleaved.reshape(
- s_shape[0], s_shape[2] // 4, s_shape[1] * 4
- )
- return scales_interleaved.contiguous()
-
  def process_weights_after_loading(self, layer: Module) -> None:
  dtype = torch.bfloat16
  device = layer.w2_weight.device

  # Interleave w13_weight_scale (gate_up_proj)
  w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
- w13_weight_scale = self._interleave_scales(w13_weight_scale)
+ w13_weight_scale = interleave_scales(w13_weight_scale)
  layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

  # Interleave w2_weight_scale (down_proj)
  w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
- w2_weight_scale = self._interleave_scales(w2_weight_scale)
+ w2_weight_scale = interleave_scales(w2_weight_scale)
  layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

  # Process input scales
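The interleaving helper is now module-level and gains an `alignment` fallback for scale tensors whose last dimension is not a multiple of 4. A shape-level demo of what the reshape/permute does (the shapes are illustrative only):

```python
import torch

scales = torch.arange(48.0).reshape(2, 3, 8)  # (num_experts, N, num_groups), made-up sizes

alignment = 4 if scales.shape[2] % 4 == 0 else 1
out = (
    scales.reshape(2, 3, 8 // alignment, alignment)  # split groups into blocks of 4
    .permute(0, 2, 1, 3)                             # bring the block index forward
    .reshape(2, 8 // alignment, 3 * alignment)       # rows now carry interleaved scale blocks
    .contiguous()
)
print(out.shape)  # torch.Size([2, 2, 12])
```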
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

  topk_weights, topk_ids, _ = topk_output
  local_topk_ids = topk_ids
- local_topk_ids = torch.where(
- topk_ids == -1,
- layer.num_experts,
- topk_ids,
- )
+ if get_moe_expert_parallel_world_size() > 1:
+ local_topk_ids = torch.where(
+ topk_ids == -1,
+ layer.num_experts,
+ topk_ids,
+ )

  output = cutlass_w4a8_moe(
  layer.start_expert_id,
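With more than one expert-parallel rank, experts routed to other ranks arrive as `-1` in `topk_ids` and are redirected to the sentinel slot `num_experts`; with a single rank the remap is now skipped. A minimal illustration of the remap itself:

```python
import torch

num_experts = 4                              # experts hosted on this rank (illustrative)
topk_ids = torch.tensor([[0, 3], [-1, 2]])   # -1 marks an expert owned by another rank

local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
print(local_topk_ids)  # tensor([[0, 3], [4, 2]]) -- -1 redirected to the padding slot
```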
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
  def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
  params_dict = {}
  params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
- params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+ params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
  return params_dict

  @staticmethod
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
  if original_dtype != torch.int8:
  x = torch_npu.npu_quantize(
  x,
- layer.aclnn_input_scale,
+ layer.aclnn_input_scale_reciprocal,
  layer.aclnn_input_offset,
  torch.qint8,
  -1,
- True,
+ False,
  )
  # Only fuse bias add into GEMM for rank 0 (this ensures that
  # bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
  layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
  requires_grad=False,
  )
+ layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+ layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+ requires_grad=False,
+ )
  layer.aclnn_input_offset = torch.nn.Parameter(
  layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
  requires_grad=False,
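A reciprocal of the input scale is now precomputed, presumably so the quantize call can multiply rather than divide (the flipped boolean argument appears to select that mode). The underlying arithmetic, shown in plain PyTorch with made-up values:

```python
import torch

x = torch.tensor([0.10, -0.37, 1.25])
input_scale = torch.tensor([0.03125])  # power-of-two scale keeps both paths bit-identical

# Dividing by the scale and multiplying by its precomputed reciprocal produce
# the same int8 codes; the reciprocal trades a divide for a cheaper multiply.
q_div = torch.round(x / input_scale).clamp(-128, 127).to(torch.int8)
q_mul = torch.round(x * (1.0 / input_scale)).clamp(-128, 127).to(torch.int8)
print(q_div.tolist(), q_mul.tolist())  # [3, -12, 40] [3, -12, 40]
```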
@@ -1876,7 +1876,7 @@ def rotate_half(x):
  return torch.cat((-x2, x1), dim=-1)


- def apply_rotary_pos_emb(
+ def apply_rotary_pos_emb_native(
  q: torch.Tensor,
  k: torch.Tensor,
  cos: torch.Tensor,
@@ -1899,6 +1899,33 @@ def apply_rotary_pos_emb(
  return q_embed, k_embed


+ def apply_rotary_pos_emb_npu(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ unsqueeze_dim=1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if q.shape[1] != 128:
+ return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+ cos = cos.unsqueeze(unsqueeze_dim)
+ cos = torch.transpose(cos, 1, 2)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ sin = torch.transpose(sin, 1, 2)
+ q = torch.transpose(q, 1, 2)
+ k = torch.transpose(k, 1, 2)
+ q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+ q_embed = torch.transpose(q_embed, 1, 2)
+ k_embed = torch.transpose(k_embed, 1, 2)
+ return q_embed, k_embed
+
+
+ if _is_npu:
+ apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+ else:
+ apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
  def get_rope_cpu(
  head_size: int,
  rotary_dim: int,
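For reference, the native path that the NPU wrapper falls back to is the usual rotate-half formulation built on the `rotate_half` helper shown above. A self-contained sketch of that formulation (shapes follow the `unsqueeze_dim=1` convention; this is an illustration, not the file's exact code):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_ref(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin: [batch, seq, head_dim]; unsqueeze to broadcast over the head dim.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


q = torch.randn(2, 8, 16, 64)  # [batch, heads, seq, head_dim], illustrative sizes
k = torch.randn(2, 8, 16, 64)
cos = torch.randn(2, 16, 64)
sin = torch.randn(2, 16, 64)
print(apply_rotary_pos_emb_ref(q, k, cos, sin)[0].shape)  # torch.Size([2, 8, 16, 64])
```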
@@ -27,6 +27,7 @@ if is_cuda():
  logger = logging.getLogger(__name__)

  SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


  class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
  batch_next_token_ids = torch.argmax(logits, -1)
  if return_logprob:
  logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
  else:
+ # Post process original logits. if temperatures are all 1.0, no need to rescale
+ if return_logprob and RETURN_ORIGINAL_LOGPROB:
+ logprobs = torch.softmax(logits, dim=-1)
+
  # Post process logits
  logits.div_(sampling_info.temperatures)
  logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):

  if return_logprob:
  # clamp to avoid -inf
- logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
+ if RETURN_ORIGINAL_LOGPROB:
+ logprobs = torch.log(logprobs).clamp(
+ min=torch.finfo(logprobs.dtype).min
+ )
+ else:
+ logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

  # Attach logprobs to logits_output (in-place modification)
  if return_logprob:
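The effect of `RETURN_ORIGINAL_LOGPROB` is to report log-probabilities of the unscaled distribution even though sampling itself uses temperature-scaled probabilities. A compact illustration of the two quantities (values are arbitrary):

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 0.7

original_logprobs = torch.log_softmax(logits, dim=-1)         # reported when RETURN_ORIGINAL_LOGPROB is set
sampling_probs = torch.softmax(logits / temperature, dim=-1)  # distribution actually sampled from
scaled_logprobs = torch.log(sampling_probs).clamp(min=torch.finfo(sampling_probs.dtype).min)

print(original_logprobs)
print(scaled_logprobs)  # temperature < 1 concentrates mass, so these differ from the originals
```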
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
  return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


- def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+ def get_top_logprobs(
+ logprobs: torch.Tensor,
+ top_logprobs_nums: List[int],
+ ):
  max_k = max(top_logprobs_nums)
  ret = logprobs.topk(max_k, dim=1)
  values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
  for i, k in enumerate(top_logprobs_nums):
  output_top_logprobs_val.append(values[i][:k])
  output_top_logprobs_idx.append(indices[i][:k])
- return output_top_logprobs_val, output_top_logprobs_idx
+
+ return (
+ output_top_logprobs_val,
+ output_top_logprobs_idx,
+ )


- def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
+ def get_token_ids_logprobs(
+ logprobs: torch.Tensor,
+ token_ids_logprobs: List[List[int]],
+ ):
  output_token_ids_logprobs_val = []
  output_token_ids_logprobs_idx = []
  for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
  output_token_ids_logprobs_val.append([])
  output_token_ids_logprobs_idx.append([])

- return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+ return (
+ output_token_ids_logprobs_val,
+ output_token_ids_logprobs_idx,
+ )


  def apply_custom_logit_processor(