sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_chat.py

@@ -82,6 +82,7 @@ class OpenAIServingChat(OpenAIServingBase):
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
+            video_data=processed_messages.video_data,
             audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
@@ -143,6 +144,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt_ids = []
         openai_compatible_messages = []
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

@@ -158,6 +160,7 @@ class OpenAIServingChat(OpenAIServingBase):
                 msg_dict,
                 template_content_format,
                 image_data,
+                video_data,
                 audio_data,
                 modalities,
             )
@@ -214,11 +217,13 @@ class OpenAIServingChat(OpenAIServingBase):
         stop = request.stop
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
+        video_data = video_data if video_data else None
         modalities = modalities if modalities else []
         return MessageProcessingResult(
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
@@ -260,6 +265,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt = conv.get_prompt()

         image_data = conv.image_data if conv.image_data else None
+        video_data = conv.video_data if conv.video_data else None
         audio_data = conv.audio_data if conv.audio_data else None
         modalities = conv.modalities if conv.modalities else []
         stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])
@@ -277,6 +283,7 @@ class OpenAIServingChat(OpenAIServingBase):
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
sglang/srt/function_call/function_call_parser.py

@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
+from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
@@ -33,6 +34,7 @@ class FunctionCallParser:
         "mistral": MistralDetector,
         "deepseekv3": DeepSeekV3Detector,
         "pythonic": PythonicDetector,
+        "kimi_k2": KimiK2Detector,
     }

     def __init__(self, tools: List[Tool], tool_call_parser: str):
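Note: the "kimi_k2" key registered above is the name used to select this detector (typically wired to the server's tool-call parser setting). A minimal construction sketch, assuming an empty tool list is acceptable for illustration:

    from sglang.srt.function_call.function_call_parser import FunctionCallParser

    # "kimi_k2" resolves to KimiK2Detector through the registry shown above.
    parser = FunctionCallParser(tools=[], tool_call_parser="kimi_k2")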
sglang/srt/function_call/kimik2_detector.py (new file)

@@ -0,0 +1,220 @@
+import json
+import logging
+import re
+from typing import List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.function_call.utils import _is_complete_json
+
+logger = logging.getLogger(__name__)
+
+
+class KimiK2Detector(BaseFormatDetector):
+
+    def __init__(self):
+        super().__init__()
+        self._buffer = ""
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: list[str] = (
+            []
+        )  # map what has been streamed for each tool so far to a list
+
+        self.bot_token: str = "<|tool_calls_section_begin|>"
+        self.eot_token: str = "<|tool_calls_section_end|>"
+
+        self.tool_call_start_token: str = "<|tool_call_begin|>"
+        self.tool_call_end_token: str = "<|tool_call_end|>"
+
+        self.tool_call_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
+        )
+
+        self.stream_tool_call_portion_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
+        )
+
+        self._last_arguments = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a KimiK2 format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=text, calls=[])
+        try:
+            # there are two possible captures - between tags, or between a
+            # tag and end-of-string so the result of
+            # findall is an array of tuples where one is a function call and
+            # the other is None
+            function_call_tuples = self.tool_call_regex.findall(text)
+
+            logger.debug("function_call_tuples: %s", function_call_tuples)
+
+            tool_calls = []
+            for match in function_call_tuples:
+                function_id, function_args = match
+                function_name = function_id.split(".")[1].split(":")[0]
+                function_idx = int(function_id.split(".")[1].split(":")[1])
+
+                logger.info(f"function_name {function_name}")
+
+                tool_calls.append(
+                    ToolCallItem(
+                        tool_index=function_idx,  # Use the call index in the response, not tool position
+                        name=function_name,
+                        parameters=function_args,
+                    )
+                )
+
+            content = text[: text.find(self.bot_token)]
+            return StreamingParseResult(normal_text=content, calls=tool_calls)
+
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for KimiK2 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        # Check if we have a tool call (either the start token or individual tool call)
+        has_tool_call = (
+            self.bot_token in current_text or self.tool_call_start_token in current_text
+        )
+
+        if not has_tool_call:
+            self._buffer = ""
+            for e_token in [self.eot_token, self.tool_call_end_token]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            match = self.stream_tool_call_portion_regex.search(current_text)
+            if match:
+                function_id = match.group("tool_call_id")
+                function_args = match.group("function_arguments")
+
+                function_name = function_id.split(".")[1].split(":")[0]
+
+                # Initialize state if this is the first tool call
+                if self.current_tool_id == -1:
+                    self.current_tool_id = 0
+                    self.prev_tool_call_arr = []
+                    self.streamed_args_for_tool = [""]
+
+                # Ensure we have enough entries in our tracking arrays
+                while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                    self.prev_tool_call_arr.append({})
+                while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                    self.streamed_args_for_tool.append("")
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=function_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                    # Store the tool call info for adapter.py
+                    self.prev_tool_call_arr[self.current_tool_id] = {
+                        "name": function_name,
+                        "arguments": {},
+                    }
+                else:
+                    argument_diff = (
+                        function_args[len(self._last_arguments) :]
+                        if function_args.startswith(self._last_arguments)
+                        else function_args
+                    )
+
+                    parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]
+
+                    if parsed_args_diff:
+
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=None,
+                                parameters=parsed_args_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+                        self.streamed_args_for_tool[
+                            self.current_tool_id
+                        ] += parsed_args_diff
+
+                    parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
+                    if _is_complete_json(parsed_args):
+                        try:
+                            parsed_args = json.loads(parsed_args)
+                            self.prev_tool_call_arr[self.current_tool_id][
+                                "arguments"
+                            ] = parsed_args
+                        except json.JSONDecodeError:
+                            pass
+
+                        # Find the end of the current tool call and remove only that part from buffer
+                        tool_call_end_pattern = (
+                            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
+                        )
+                        match = re.search(
+                            tool_call_end_pattern, current_text, re.DOTALL
+                        )
+                        if match:
+                            # Remove the completed tool call from buffer, keep any remaining content
+                            self._buffer = current_text[match.end() :]
+                        else:
+                            self._buffer = ""
+
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self.current_tool_id += 1
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError()
+
+    def build_ebnf(self, tools: List[Tool]):
+        raise NotImplementedError()
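Note: to illustrate the token layout the regexes above expect, here is a hedged sketch; the "functions.<name>:<index>" id format is inferred from the split logic in detect_and_parse, and the sample text is invented:

    from sglang.srt.function_call.kimik2_detector import KimiK2Detector

    detector = KimiK2Detector()
    sample = (
        "Let me check the weather."
        "<|tool_calls_section_begin|>"
        "<|tool_call_begin|>functions.get_weather:0"
        '<|tool_call_argument_begin|>{"city": "Beijing"}<|tool_call_end|>'
        "<|tool_calls_section_end|>"
    )
    result = detector.detect_and_parse(sample, tools=[])
    # result.normal_text == "Let me check the weather."
    # result.calls[0].name == "get_weather"
    # result.calls[0].parameters == '{"city": "Beijing"}'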
sglang/srt/hf_transformers_utils.py

@@ -14,6 +14,7 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
+import logging
 import os
 import warnings
 from pathlib import Path
@@ -25,6 +26,7 @@ from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    GenerationConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
@@ -153,6 +155,22 @@ def get_config(
     return config


+@lru_cache_frozenset(maxsize=32)
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    **kwargs,
+):
+    try:
+        return GenerationConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+    except OSError as e:
+        logging.info("model doesn't have generation_config.json")
+        return None
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
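Note: a hypothetical call sketch for the new helper; the model id is a placeholder, and the function returns None when the repository ships no generation_config.json:

    from sglang.srt.hf_transformers_utils import get_generation_config

    gen_cfg = get_generation_config("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=False)
    if gen_cfg is not None:
        # Fields such as temperature/top_p come from the repo's generation_config.json.
        print(gen_cfg.temperature, gen_cfg.top_p)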
sglang/srt/jinja_template_utils.py

@@ -110,6 +110,7 @@ def process_content_for_template_format(
     msg_dict: dict,
     content_format: str,
     image_data: list,
+    video_data: list,
     audio_data: list,
     modalities: list,
 ) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
         msg_dict: Message dictionary with content
         content_format: 'string' or 'openai' (detected via AST analysis)
         image_data: List to append extracted image URLs
+        video_data: List to append extracted video URLs
         audio_data: List to append extracted audio URLs
         modalities: List to append modalities

@@ -143,6 +145,12 @@ def process_content_for_template_format(
                 modalities.append(chunk.get("modalities"))
             # Normalize to simple 'image' type for template compatibility
             processed_content_parts.append({"type": "image"})
+        elif chunk_type == "video_url":
+            video_data.append(chunk["video_url"]["url"])
+            if chunk.get("modalities"):
+                modalities.append(chunk.get("modalities"))
+            # Normalize to simple 'video' type for template compatibility
+            processed_content_parts.append({"type": "video"})
         elif chunk_type == "audio_url":
             audio_data.append(chunk["audio_url"]["url"])
             # Normalize to simple 'audio' type
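Note: the new branch consumes OpenAI-style "video_url" content parts; a hypothetical message fragment (the URL is a placeholder):

    message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this clip."},
            {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
        ],
    }
    # The url is appended to video_data and the part is normalized to {"type": "video"}.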
sglang/srt/layers/communicator.py

@@ -187,11 +187,24 @@ class LayerCommunicator:
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+            if (
+                residual is not None
+                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
+                and hidden_states._sglang_needs_allreduce_fusion
+            ):
+                hidden_states, residual = (
+                    self.input_layernorm.forward_with_allreduce_fusion(
+                        hidden_states, residual
+                    )
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                if residual is None:
+                    residual = hidden_states
+                    hidden_states = self.input_layernorm(hidden_states)
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )

         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
@@ -402,12 +415,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if hidden_states.shape[0] != 0:
             hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <= 1024
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
sglang/srt/layers/flashinfer_comm_fusion.py

@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
     return _workspace_manager.initialized


-def flashinfer_allreduce_add_rmsnorm(
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 1024,
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
sglang/srt/layers/layernorm.py

@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-                flashinfer_allreduce_add_rmsnorm,
+                flashinfer_allreduce_residual_rmsnorm,
             )

             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_add_rmsnorm(
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
sglang/srt/layers/linear.py

@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
+    is_npu,
     set_weight_attrs,
     use_intel_amx_backend,
 )
@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [

 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()


 def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

+        # The per-tensor quant-scale must be 1 dimension
+        if _is_npu:
+            if param.size() != loaded_weight.size() and param.size(0) == 1:
+                if torch.allclose(loaded_weight, loaded_weight[0]):
+                    loaded_weight = loaded_weight[:1]
+                else:
+                    raise ValueError(f"{loaded_weight} are not all equal")
+
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)

@@ -1357,7 +1367,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_):
+    def forward(self, input_, can_fuse_mlp_allreduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1372,7 +1382,7 @@ class RowParallelLinear(LinearBase):
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel