sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/configs/model_config.py +24 -1
  2. sglang/srt/conversation.py +21 -2
  3. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  4. sglang/srt/disaggregation/ascend/conn.py +44 -0
  5. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  6. sglang/srt/disaggregation/mooncake/conn.py +15 -14
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  8. sglang/srt/disaggregation/utils.py +25 -3
  9. sglang/srt/entrypoints/engine.py +1 -1
  10. sglang/srt/entrypoints/http_server.py +1 -0
  11. sglang/srt/entrypoints/openai/protocol.py +11 -0
  12. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/kimik2_detector.py +220 -0
  15. sglang/srt/hf_transformers_utils.py +18 -0
  16. sglang/srt/jinja_template_utils.py +8 -0
  17. sglang/srt/layers/communicator.py +17 -4
  18. sglang/srt/layers/linear.py +12 -2
  19. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  20. sglang/srt/layers/moe/ep_moe/layer.py +2 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
  22. sglang/srt/layers/moe/topk.py +8 -2
  23. sglang/srt/layers/parameter.py +19 -3
  24. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  25. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  26. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  27. sglang/srt/managers/io_struct.py +27 -2
  28. sglang/srt/managers/mm_utils.py +55 -94
  29. sglang/srt/managers/schedule_batch.py +16 -5
  30. sglang/srt/managers/scheduler.py +21 -1
  31. sglang/srt/managers/tokenizer_manager.py +16 -0
  32. sglang/srt/mem_cache/memory_pool.py +65 -40
  33. sglang/srt/model_executor/forward_batch_info.py +13 -1
  34. sglang/srt/model_loader/loader.py +23 -12
  35. sglang/srt/models/deepseek_janus_pro.py +1 -1
  36. sglang/srt/models/deepseek_v2.py +62 -17
  37. sglang/srt/models/deepseek_vl2.py +1 -1
  38. sglang/srt/models/gemma3_mm.py +1 -1
  39. sglang/srt/models/gemma3n_mm.py +6 -3
  40. sglang/srt/models/internvl.py +8 -2
  41. sglang/srt/models/kimi_vl.py +8 -2
  42. sglang/srt/models/llama.py +2 -0
  43. sglang/srt/models/llava.py +3 -1
  44. sglang/srt/models/llavavid.py +1 -1
  45. sglang/srt/models/minicpmo.py +1 -2
  46. sglang/srt/models/minicpmv.py +1 -1
  47. sglang/srt/models/mixtral_quant.py +4 -0
  48. sglang/srt/models/mllama4.py +13 -4
  49. sglang/srt/models/phi4mm.py +8 -2
  50. sglang/srt/models/phimoe.py +553 -0
  51. sglang/srt/models/qwen2.py +2 -0
  52. sglang/srt/models/qwen2_5_vl.py +10 -7
  53. sglang/srt/models/qwen2_vl.py +12 -1
  54. sglang/srt/models/vila.py +8 -2
  55. sglang/srt/multimodal/processors/base_processor.py +197 -137
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  57. sglang/srt/multimodal/processors/gemma3.py +4 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  59. sglang/srt/multimodal/processors/internvl.py +1 -1
  60. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  61. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  62. sglang/srt/multimodal/processors/minicpm.py +4 -3
  63. sglang/srt/multimodal/processors/mllama4.py +1 -1
  64. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  65. sglang/srt/multimodal/processors/pixtral.py +1 -1
  66. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  67. sglang/srt/multimodal/processors/vila.py +1 -1
  68. sglang/srt/server_args.py +11 -4
  69. sglang/srt/utils.py +154 -31
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
  72. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
  73. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/kimik2_detector.py
@@ -0,0 +1,220 @@
+import json
+import logging
+import re
+from typing import List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.function_call.utils import _is_complete_json
+
+logger = logging.getLogger(__name__)
+
+
+class KimiK2Detector(BaseFormatDetector):
+
+    def __init__(self):
+        super().__init__()
+        self._buffer = ""
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: list[str] = (
+            []
+        )  # map what has been streamed for each tool so far to a list
+
+        self.bot_token: str = "<|tool_calls_section_begin|>"
+        self.eot_token: str = "<|tool_calls_section_end|>"
+
+        self.tool_call_start_token: str = "<|tool_call_begin|>"
+        self.tool_call_end_token: str = "<|tool_call_end|>"
+
+        self.tool_call_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
+        )
+
+        self.stream_tool_call_portion_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
+        )
+
+        self._last_arguments = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a KimiK2 format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=text, calls=[])
+        try:
+            # there are two possible captures - between tags, or between a
+            # tag and end-of-string so the result of
+            # findall is an array of tuples where one is a function call and
+            # the other is None
+            function_call_tuples = self.tool_call_regex.findall(text)
+
+            logger.debug("function_call_tuples: %s", function_call_tuples)
+
+            tool_calls = []
+            for match in function_call_tuples:
+                function_id, function_args = match
+                function_name = function_id.split(".")[1].split(":")[0]
+                function_idx = int(function_id.split(".")[1].split(":")[1])
+
+                logger.info(f"function_name {function_name}")
+
+                tool_calls.append(
+                    ToolCallItem(
+                        tool_index=function_idx,  # Use the call index in the response, not tool position
+                        name=function_name,
+                        parameters=function_args,
+                    )
+                )
+
+            content = text[: text.find(self.bot_token)]
+            return StreamingParseResult(normal_text=content, calls=tool_calls)
+
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for KimiK2 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        # Check if we have a tool call (either the start token or individual tool call)
+        has_tool_call = (
+            self.bot_token in current_text or self.tool_call_start_token in current_text
+        )
+
+        if not has_tool_call:
+            self._buffer = ""
+            for e_token in [self.eot_token, self.tool_call_end_token]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            match = self.stream_tool_call_portion_regex.search(current_text)
+            if match:
+                function_id = match.group("tool_call_id")
+                function_args = match.group("function_arguments")
+
+                function_name = function_id.split(".")[1].split(":")[0]
+
+                # Initialize state if this is the first tool call
+                if self.current_tool_id == -1:
+                    self.current_tool_id = 0
+                    self.prev_tool_call_arr = []
+                    self.streamed_args_for_tool = [""]
+
+                # Ensure we have enough entries in our tracking arrays
+                while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                    self.prev_tool_call_arr.append({})
+                while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                    self.streamed_args_for_tool.append("")
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=function_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                    # Store the tool call info for adapter.py
+                    self.prev_tool_call_arr[self.current_tool_id] = {
+                        "name": function_name,
+                        "arguments": {},
+                    }
+                else:
+                    argument_diff = (
+                        function_args[len(self._last_arguments) :]
+                        if function_args.startswith(self._last_arguments)
+                        else function_args
+                    )
+
+                    parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]
+
+                    if parsed_args_diff:
+
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=None,
+                                parameters=parsed_args_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+                        self.streamed_args_for_tool[
+                            self.current_tool_id
+                        ] += parsed_args_diff
+
+                    parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
+                    if _is_complete_json(parsed_args):
+                        try:
+                            parsed_args = json.loads(parsed_args)
+                            self.prev_tool_call_arr[self.current_tool_id][
+                                "arguments"
+                            ] = parsed_args
+                        except json.JSONDecodeError:
+                            pass
+
+                        # Find the end of the current tool call and remove only that part from buffer
+                        tool_call_end_pattern = (
+                            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
+                        )
+                        match = re.search(
+                            tool_call_end_pattern, current_text, re.DOTALL
+                        )
+                        if match:
+                            # Remove the completed tool call from buffer, keep any remaining content
+                            self._buffer = current_text[match.end() :]
+                        else:
+                            self._buffer = ""
+
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self.current_tool_id += 1
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError()
+
+    def build_ebnf(self, tools: List[Tool]):
+        raise NotImplementedError()
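Note: a minimal sketch of exercising the new detector's one-shot path. The tool-call string below is illustrative; as written, `detect_and_parse` extracts the call id and arguments with `tool_call_regex` and does not consult the `tools` list.

```python
from sglang.srt.function_call.kimik2_detector import KimiK2Detector

text = (
    "It is sunny today.<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0"
    '<|tool_call_argument_begin|>{"city": "Tokyo"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)

detector = KimiK2Detector()
result = detector.detect_and_parse(text, tools=[])
# result.normal_text == "It is sunny today."
# result.calls[0].name == "get_weather"  (from "functions.get_weather:0")
# result.calls[0].parameters == '{"city": "Tokyo"}'
```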
sglang/srt/hf_transformers_utils.py
@@ -14,6 +14,7 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
+import logging
 import os
 import warnings
 from pathlib import Path
@@ -25,6 +26,7 @@ from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    GenerationConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
@@ -153,6 +155,22 @@ def get_config(
     return config


+@lru_cache_frozenset(maxsize=32)
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    **kwargs,
+):
+    try:
+        return GenerationConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+    except OSError as e:
+        logging.info("model doesn't have generation_config.json")
+        return None
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
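Note: hypothetical usage of the new cached helper; it returns `None` rather than raising when the checkpoint ships no `generation_config.json`. The model id below is only an example.

```python
from sglang.srt.hf_transformers_utils import get_generation_config

gen_cfg = get_generation_config("Qwen/Qwen2-7B-Instruct", trust_remote_code=False)
if gen_cfg is not None:
    # e.g. pick up the model author's sampling defaults
    print(gen_cfg.temperature, gen_cfg.top_p)
```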
sglang/srt/jinja_template_utils.py
@@ -110,6 +110,7 @@ def process_content_for_template_format(
     msg_dict: dict,
     content_format: str,
     image_data: list,
+    video_data: list,
     audio_data: list,
     modalities: list,
 ) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
         msg_dict: Message dictionary with content
         content_format: 'string' or 'openai' (detected via AST analysis)
         image_data: List to append extracted image URLs
+        video_data: List to append extracted video URLs
         audio_data: List to append extracted audio URLs
         modalities: List to append modalities

@@ -143,6 +145,12 @@ def process_content_for_template_format(
                 modalities.append(chunk.get("modalities"))
             # Normalize to simple 'image' type for template compatibility
             processed_content_parts.append({"type": "image"})
+        elif chunk_type == "video_url":
+            video_data.append(chunk["video_url"]["url"])
+            if chunk.get("modalities"):
+                modalities.append(chunk.get("modalities"))
+            # Normalize to simple 'video' type for template compatibility
+            processed_content_parts.append({"type": "video"})
         elif chunk_type == "audio_url":
             audio_data.append(chunk["audio_url"]["url"])
             # Normalize to simple 'audio' type
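Note: an illustrative OpenAI-style content list that exercises the new branch (placeholder URL). The handler appends the URL to `video_data` and rewrites the part to `{"type": "video"}` so chat templates see a plain modality tag.

```python
content = [
    {"type": "text", "text": "Describe this clip."},
    {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
]
```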
sglang/srt/layers/communicator.py
@@ -187,11 +187,24 @@ class LayerCommunicator:
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+            if (
+                residual is not None
+                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
+                and hidden_states._sglang_needs_allreduce_fusion
+            ):
+                hidden_states, residual = (
+                    self.input_layernorm.forward_with_allreduce_fusion(
+                        hidden_states, residual
+                    )
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                if residual is None:
+                    residual = hidden_states
+                    hidden_states = self.input_layernorm(hidden_states)
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )

         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
sglang/srt/layers/linear.py
@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
+    is_npu,
     set_weight_attrs,
     use_intel_amx_backend,
 )
@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [

 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()


 def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

+        # The per-tensor quant-scale must be 1 dimension
+        if _is_npu:
+            if param.size() != loaded_weight.size() and param.size(0) == 1:
+                if torch.allclose(loaded_weight, loaded_weight[0]):
+                    loaded_weight = loaded_weight[:1]
+                else:
+                    raise ValueError(f"{loaded_weight} are not all equal")
+
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
@@ -1357,7 +1367,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_):
+    def forward(self, input_, can_fuse_mlp_allreduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1372,7 +1382,7 @@ class RowParallelLinear(LinearBase):
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel
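Note: read together with the `LayerCommunicator` hunk above, these two changes form a handshake: a caller that knows the next norm can absorb the all-reduce skips it here and tags the tensor. A sketch under assumptions — the call site and layer names are illustrative; only the flag name and `forward_with_allreduce_fusion` appear in this diff.

```python
# Producer: skip the TP all-reduce in the down projection ...
hidden_states = down_proj.forward(x, can_fuse_mlp_allreduce=True)
# ... and mark the tensor so LayerCommunicator notices it.
hidden_states._sglang_needs_allreduce_fusion = True

# Consumer (communicator hunk): the tagged tensor is routed to
#   input_layernorm.forward_with_allreduce_fusion(hidden_states, residual)
# which presumably performs the deferred all-reduce fused with the
# residual add + norm.
```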
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -6,6 +6,7 @@ import triton

 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
+from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)

@@ -1058,7 +1059,7 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024  # block size of quantization
+    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
sglang/srt/layers/moe/ep_moe/layer.py
@@ -12,7 +12,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -65,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if not _is_npu:
     from sgl_kernel import silu_and_mul

+    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant

sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -518,6 +518,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None

+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.local_num_experts,
@@ -661,7 +662,11 @@ class FusedMoE(torch.nn.Module):
         ):
             raise ValueError("expert_data and loaded_weight must be torch.Tensor")

-        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+        if (
+            self.quant_config is not None
+            and "modelopt" in self.quant_config.get_name()
+            and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
+        ):
             raise ValueError(
                 f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
             )
@@ -850,7 +855,7 @@ class FusedMoE(torch.nn.Module):
             return

         # Case weight scales and zero_points
-        if "scale" in weight_name or "zero" in weight_name:
+        if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
             # load the weight scales and zp based on the quantization scheme
             # supported weight scales/zp can be found in
             # FusedMoeWeightScaleSupported
sglang/srt/layers/moe/topk.py
@@ -83,13 +83,18 @@ def fused_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    return torch.ops.sgl_kernel.topk_softmax_cpu(
+    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
         gating_output=gating_output,
         topk=topk,
         renormalize=renormalize,
     )
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def fused_topk(
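Note: for orientation, a pure-PyTorch sketch of what the fused CPU kernel computes, ignoring the newly threaded expert-location remap and padding mask; a reading aid, not the kernel itself.

```python
import torch

def fused_topk_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # softmax over experts, then keep the top-k routing weights per token
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```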
@@ -303,7 +308,7 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = True,
+    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu
     fused_topk_native = fused_topk_cpu
+    fused_topk = fused_topk_cpu
 else:
     biased_grouped_topk = biased_grouped_topk_gpu
     grouped_topk = grouped_topk_gpu
sglang/srt/layers/parameter.py
@@ -187,10 +187,26 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, shard_id * shard_size, shard_size
+
+        if _is_cpu:
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
+            )
+
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_id * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, shard_id * shard_size, shard_size
+                )

         assert (
             param_data.shape == loaded_weight.shape
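Note: `Tensor.narrow(dim, start, length)` returns a view of `length` indices starting at `start`; the new CPU path swaps it for `narrow_padded_param_and_loaded_weight`, which additionally accounts for weight padding. A toy example of the plain call:

```python
import torch

w = torch.arange(12).reshape(4, 3)
shard = w.narrow(0, 2, 2)  # rows 2..3 -> shape (2, 3); a view, not a copy
```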
sglang/srt/layers/quantization/fp8_kernel.py
@@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor(
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
-    y_ptr += g_id * group_size
-    y_q_ptr += g_id * group_size
+    y_ptr += g_id.to(tl.int64) * group_size
+    y_q_ptr += g_id.to(tl.int64) * group_size

     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
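Note: the cast matters because `tl.program_id(0)` is a 32-bit index, so the element offset `g_id * group_size` can exceed `2**31 - 1` on large inputs and wrap negative. Illustrative arithmetic with assumed sizes:

```python
group_size = 128
g_id = 20_000_000            # group count for a very large batch (assumed)
offset = g_id * group_size   # 2_560_000_000 > 2**31 - 1 = 2_147_483_647
# In int32 this wraps to a negative pointer offset; promoting via
# g_id.to(tl.int64) keeps the address computation exact.
```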
sglang/srt/layers/quantization/moe_wna16.py
@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
-        if can_convert and user_quant == "moe_wna16":
+        if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
             return cls.get_name()
         return None