sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
  ErrorResponse,
  FunctionResponse,
  LogProbs,
+ MessageProcessingResult,
  ToolCall,
  TopLogprob,
  )
@@ -62,120 +63,81 @@ class OpenAIServingChat(OpenAIServingBase):
  is_multimodal = self.tokenizer_manager.model_config.is_multimodal

  # Process messages and apply chat template
- (
- prompt,
- prompt_ids,
- image_data,
- audio_data,
- modalities,
- stop,
- tool_call_constraint,
- ) = self._process_messages(request, is_multimodal)
+ processed_messages = self._process_messages(request, is_multimodal)

  # Build sampling parameters
  sampling_params = self._build_sampling_params(
- request, stop, tool_call_constraint
+ request, processed_messages.stop, processed_messages.tool_call_constraint
  )

  # Handle single vs multiple requests
  if is_multimodal:
- prompt_kwargs = {"text": prompt}
+ prompt_kwargs = {"text": processed_messages.prompt}
  else:
- if isinstance(prompt_ids, str):
- prompt_kwargs = {"text": prompt_ids}
+ if isinstance(processed_messages.prompt_ids, str):
+ prompt_kwargs = {"text": processed_messages.prompt_ids}
  else:
- prompt_kwargs = {"input_ids": prompt_ids}
+ prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

  adapted_request = GenerateReqInput(
  **prompt_kwargs,
- image_data=image_data,
- audio_data=audio_data,
+ image_data=processed_messages.image_data,
+ audio_data=processed_messages.audio_data,
  sampling_params=sampling_params,
  return_logprob=request.logprobs,
  logprob_start_len=-1,
  top_logprobs_num=request.top_logprobs or 0,
  stream=request.stream,
  return_text_in_logprobs=True,
- modalities=modalities,
+ modalities=processed_messages.modalities,
  lora_path=request.lora_path,
  bootstrap_host=request.bootstrap_host,
  bootstrap_port=request.bootstrap_port,
  bootstrap_room=request.bootstrap_room,
  return_hidden_states=request.return_hidden_states,
+ rid=request.rid,
  )

  return adapted_request, request

  def _process_messages(
  self, request: ChatCompletionRequest, is_multimodal: bool
- ) -> tuple[
- str,
- Union[str, List[int]],
- Optional[Any],
- Optional[Any],
- List[str],
- List[str],
- Optional[Any],
- ]:
+ ) -> MessageProcessingResult:
  """Process chat messages and apply chat template"""
  tool_call_constraint = None
- prompt = ""
- prompt_ids = []

- if not isinstance(request.messages, str):
- # Apply chat template and its stop strings
- tools = None
- if request.tools and request.tool_choice != "none":
- request.skip_special_tokens = False
- if not isinstance(request.tool_choice, str):
- tools = [
- item.function.model_dump()
- for item in request.tools
- if item.function.name == request.tool_choice.function.name
- ]
- else:
- tools = [item.function.model_dump() for item in request.tools]
+ # Apply chat template and its stop strings
+ tools = None
+ if request.tools and request.tool_choice != "none":
+ request.skip_special_tokens = False
+ if not isinstance(request.tool_choice, str):
+ tools = [
+ item.function.model_dump()
+ for item in request.tools
+ if item.function.name == request.tool_choice.function.name
+ ]
+ else:
+ tools = [item.function.model_dump() for item in request.tools]

- tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
- parser = FunctionCallParser(request.tools, tool_call_parser)
- tool_call_constraint = parser.get_structure_constraint(
- request.tool_choice
- )
+ tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+ parser = FunctionCallParser(request.tools, tool_call_parser)
+ tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

- # Use chat template
- if self.template_manager.chat_template_name is None:
- prompt, prompt_ids, image_data, audio_data, modalities, stop = (
- self._apply_jinja_template(request, tools, is_multimodal)
- )
- else:
- prompt, prompt_ids, image_data, audio_data, modalities, stop = (
- self._apply_conversation_template(request, is_multimodal)
- )
+ # Use chat template
+ if self.template_manager.chat_template_name is None:
+ result = self._apply_jinja_template(request, tools, is_multimodal)
  else:
- # Use raw prompt
- prompt_ids = request.messages
- stop = request.stop or []
- image_data = None
- audio_data = None
- modalities = []
- prompt = request.messages
-
- return (
- prompt,
- prompt_ids,
- image_data,
- audio_data,
- modalities,
- stop,
- tool_call_constraint,
- )
+ result = self._apply_conversation_template(request, is_multimodal)
+
+ result.tool_call_constraint = tool_call_constraint
+ return result

  def _apply_jinja_template(
  self,
  request: ChatCompletionRequest,
  tools: Optional[List[Dict]],
  is_multimodal: bool,
- ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
+ ) -> MessageProcessingResult:
  """Apply Jinja chat template"""
  prompt = ""
  prompt_ids = []
@@ -253,13 +215,20 @@ class OpenAIServingChat(OpenAIServingBase):
  image_data = image_data if image_data else None
  audio_data = audio_data if audio_data else None
  modalities = modalities if modalities else []
- return prompt, prompt_ids, image_data, audio_data, modalities, stop
+ return MessageProcessingResult(
+ prompt=prompt,
+ prompt_ids=prompt_ids,
+ image_data=image_data,
+ audio_data=audio_data,
+ modalities=modalities,
+ stop=stop,
+ )

  def _apply_conversation_template(
  self,
  request: ChatCompletionRequest,
  is_multimodal: bool,
- ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
+ ) -> MessageProcessingResult:
  """Apply conversation template"""
  prompt = ""
  prompt_ids = []
@@ -304,7 +273,14 @@ class OpenAIServingChat(OpenAIServingBase):
  if not is_multimodal:
  prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)

- return prompt, prompt_ids, image_data, audio_data, modalities, stop
+ return MessageProcessingResult(
+ prompt=prompt,
+ prompt_ids=prompt_ids,
+ image_data=image_data,
+ audio_data=audio_data,
+ modalities=modalities,
+ stop=stop,
+ )

  def _build_sampling_params(
  self,
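
For context, the serving_chat refactor above replaces a seven-element tuple return with a single MessageProcessingResult object, so callers read named attributes (processed_messages.prompt, processed_messages.stop, ...) instead of unpacking positions. A minimal sketch of such a container, with field names taken from the keyword arguments in the hunk; the dataclass form, types, and defaults are assumptions, and the real definition lives in sglang/srt/entrypoints/openai/protocol.py:

from dataclasses import dataclass, field
from typing import Any, List, Optional, Union

@dataclass
class MessageProcessingResult:
    # Field names mirror the hunk above; types and defaults are illustrative.
    prompt: str = ""
    prompt_ids: Union[str, List[int]] = field(default_factory=list)
    image_data: Optional[Any] = None
    audio_data: Optional[Any] = None
    modalities: List[str] = field(default_factory=list)
    stop: List[str] = field(default_factory=list)
    tool_call_constraint: Optional[Any] = None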
@@ -87,6 +87,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
  bootstrap_port=request.bootstrap_port,
  bootstrap_room=request.bootstrap_room,
  return_hidden_states=request.return_hidden_states,
+ rid=request.rid,
  )

  return adapted_request, request
@@ -119,6 +119,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):

  adapted_request = EmbeddingReqInput(
  **prompt_kwargs,
+ rid=request.rid,
  )

  return adapted_request, request
@@ -3,7 +3,7 @@ from typing import Optional

  import torch

- from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec
+ from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec


  class EplbAlgorithm(Enum):
@@ -4,10 +4,8 @@ from typing import TYPE_CHECKING, List

  import torch.cuda

- from sglang.srt.managers.expert_distribution import (
- get_global_expert_distribution_recorder,
- )
- from sglang.srt.managers.expert_location import ExpertLocationMetadata
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ExpertLocationMetadata

  if TYPE_CHECKING:
  from sglang.srt.model_executor.model_runner import ModelRunner
@@ -4,7 +4,7 @@ from pathlib import Path
  import torch
  from tqdm import tqdm

- from sglang.srt.managers.expert_distribution import (
+ from sglang.srt.eplb.expert_distribution import (
  _convert_global_physical_count_to_logical_count,
  )

@@ -24,7 +24,7 @@ import einops
  import torch
  import torch.distributed

- from sglang.srt.managers.expert_location import ExpertLocationMetadata
+ from sglang.srt.eplb.expert_location import ExpertLocationMetadata
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.server_args import ServerArgs
@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
  def with_debug_name(self, debug_name):
  yield

+ @contextmanager
+ def disable_this_region(self):
+ yield
+
  @contextmanager
  def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
  yield
@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
  self._expert_location_metadata = expert_location_metadata

  self._recording = False
+ self._disable_all = False
  self._current_forward_pass_id = Withable()
  self._current_layer_idx = Withable()
  self._current_debug_name = Withable()
@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
  finally:
  self._on_forward_pass_end(forward_pass_id)

+ @contextmanager
+ def disable_this_region(self):
+ """Context manager to temporarily disable recording."""
+ previous_disable_all = self._disable_all
+ self._disable_all = True
+ try:
+ yield
+ finally:
+ self._disable_all = previous_disable_all
+
  def _on_forward_pass_start(self, forward_batch: ForwardBatch):
  if not self._recording:
  return
@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
  )

  def _on_hook(self, hook_name: str, **kwargs):
+ if self._disable_all:
+ return
  if not (self._recording or torch.cuda.is_current_stream_capturing()):
  return
  gatherer = self._single_pass_gatherers[
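
The new disable_this_region() context manager suppresses expert-distribution recording for whatever runs inside it (the early return added to _on_hook above). A rough usage sketch, assuming the global accessor get_global_expert_distribution_recorder shown elsewhere in this diff; run_capture_pass is a hypothetical caller:

from sglang.srt.eplb.expert_distribution import (
    get_global_expert_distribution_recorder,
)

def run_capture_pass(model_runner, forward_batch):
    # Hypothetical helper: keep a warmup / capture pass out of the
    # recorded expert-distribution statistics.
    recorder = get_global_expert_distribution_recorder()
    with recorder.disable_this_region():
        return model_runner.forward(forward_batch)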
@@ -23,7 +23,7 @@ import torch.distributed
  import torch.nn.functional as F

  from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.managers import eplb_algorithms
+ from sglang.srt.eplb import eplb_algorithms
  from sglang.srt.model_loader import get_model_architecture
  from sglang.srt.server_args import ServerArgs

@@ -17,7 +17,7 @@ from typing import Literal, Optional

  import torch

- from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+ from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
  from sglang.srt.managers.schedule_batch import global_server_args_dict


@@ -20,7 +20,7 @@ import torch
  import torch.distributed
  from torch.distributed import P2POp

- from sglang.srt.managers.expert_location import (
+ from sglang.srt.eplb.expert_location import (
  ExpertLocationMetadata,
  get_global_expert_location_metadata,
  )
@@ -30,6 +30,9 @@ from sglang.srt.utils import get_bool_env_var
  logger = logging.getLogger(__name__)


+ _LOG_INPUT = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT")
+
+
  class ExpertLocationUpdater:
  def __init__(self):
  self._first_execution = True
@@ -175,6 +178,19 @@ def update_expert_weights_single_layer(
  assert isinstance(old_physical_to_logical_map, list)
  assert isinstance(new_physical_to_logical_map, list)

+ if _LOG_INPUT:
+ logger.info(
+ "update_expert_weights_single_layer "
+ f"{[x.shape for x in routed_experts_weights]=} "
+ f"{[x.shape for x in temp_buffers]=} "
+ f"{old_physical_to_logical_map=} "
+ f"{new_physical_to_logical_map=} "
+ f"{num_local_physical_experts=} "
+ f"{num_gpu_per_node=} "
+ f"{rank=} "
+ f"{world_size=} "
+ )
+
  output_logs = [] if debug else None

  num_physical_experts = len(old_physical_to_logical_map)
@@ -42,7 +42,7 @@ from sglang.srt.configs (
  )
  from sglang.srt.configs.internvl import InternVLChatConfig
  from sglang.srt.connector import create_remote_connector
- from sglang.srt.utils import is_remote_url
+ from sglang.srt.utils import is_remote_url, lru_cache_frozenset

  _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
  ChatGLMConfig.model_type: ChatGLMConfig,
@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
  return config


+ @lru_cache_frozenset(maxsize=32)
  def get_config(
  model: str,
  trust_remote_code: bool,
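
get_config() is now memoized via lru_cache_frozenset from sglang.srt.utils. The name suggests an LRU cache that first converts unhashable keyword arguments into frozen, hashable keys; below is a sketch of that idea under this assumption (the real helper may differ, so the sketch uses its own name):

import functools

def lru_cache_frozenset_sketch(maxsize=32):
    """Illustrative: freeze keyword arguments so functools.lru_cache can hash the call."""
    def decorator(fn):
        @functools.lru_cache(maxsize=maxsize)
        def cached(args_key, kwargs_key):
            return fn(*args_key, **dict(kwargs_key))

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # dicts are unhashable; a frozenset of (key, value) pairs is not.
            return cached(args, frozenset(kwargs.items()))

        return wrapper
    return decorator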
@@ -46,6 +46,9 @@ _is_cpu = is_cpu()
  if _is_cuda:
  from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

+ if is_npu():
+ import torch_npu
+
  logger = logging.getLogger(__name__)


@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
  else:
  return self.forward_native(x)

+ def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+ out = torch_npu.npu_swiglu(x)
+ return out
+

  class GeluAndMul(CustomOp):
  def __init__(self, approximate="tanh"):
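
forward_npu() slots into the CustomOp dispatch pattern visible in the hunk (forward_native as the fallback, backend-specific overrides for other hardware). A simplified sketch of that dispatch idea; the actual selection logic lives in sglang/srt/custom_op.py (also changed in this release) and may differ:

import torch
from torch import nn

class CustomOpSketch(nn.Module):
    # Illustrative only: bind the backend-specific forward once, at construction.
    def __init__(self):
        super().__init__()
        if torch.cuda.is_available():
            self._forward_impl = self.forward_cuda
        else:
            # An NPU branch analogous to forward_npu() would be chosen here
            # when an Ascend device is present (assumption).
            self._forward_impl = self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_impl(*args, **kwargs)

    def forward_native(self, x):
        raise NotImplementedError

    def forward_cuda(self, x):
        # Typically backed by fused kernels (e.g. sgl_kernel.silu_and_mul).
        return self.forward_native(x)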
@@ -0,0 +1,86 @@
+ import logging
+
+ import torch
+
+ from sglang.srt.utils import cpu_has_amx_support
+
+ logger = logging.getLogger(__name__)
+
+
+ def amx_process_weight_after_loading(weight):
+ if weight.device != torch.device("cpu"):
+ return weight
+ if not cpu_has_amx_support():
+ return weight
+
+ return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+ # TODO: currently gemm kernel has the below requirements:
+ # OC % TILE_N == 0, where TILE_N = 16
+ # IC % TILE_K == 0, where TILE_K = 32
+ def dim_is_supported(weight):
+ TILE_N = 16
+ TILE_K = 32
+ ndim = weight.ndim
+ OC = weight.size(1) if ndim == 3 else weight.size(0)
+ IC = weight.size(2) if ndim == 3 else weight.size(1)
+ return OC % TILE_N == 0 and IC % TILE_K == 0
+
+
+ def _amx_process_weight_after_loading(
+ module, weight_names, transpose_dims=None
+ ) -> None:
+ # Pack weight for get better performance on CPU
+ devices = {getattr(module, weight_name).device for weight_name in weight_names}
+ assert len(devices) == 1, f"Expects all weights to be on the same device"
+ device = devices.pop()
+
+ if transpose_dims:
+ assert len(weight_names) == len(
+ transpose_dims
+ ), "len(weight_names) should be equal to len(transpose_dims)"
+
+ for i, weight_name in enumerate(weight_names):
+ weight_tensor = getattr(module, weight_name)
+
+ if transpose_dims and transpose_dims[i]:
+ weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+ # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
+ if not dim_is_supported(weight_tensor):
+ logger.warning(
+ f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
+ f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
+ )
+ module.use_intel_amx_backend = False
+ return
+
+ packed_weight = torch.nn.Parameter(
+ amx_process_weight_after_loading(weight_tensor),
+ requires_grad=False,
+ )
+ packed_weight.__dict__ = weight_tensor.__dict__
+ setattr(module, weight_name, packed_weight)
+
+ module.use_intel_amx_backend = (
+ device == torch.device("cpu") and cpu_has_amx_support()
+ )
+
+ if (
+ module.use_intel_amx_backend
+ and hasattr(module, "bias")
+ and module.bias is not None
+ ):
+ module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+ class PackWeightMethod:
+ def __init__(self, weight_names, transpose_dims=None):
+ self.weight_names = weight_names
+ self.transpose_dims = transpose_dims
+
+ def process_weights_after_loading(self, module) -> None:
+ _amx_process_weight_after_loading(
+ module, self.weight_names, self.transpose_dims
+ )
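
The new amx_utils module packs CPU weights into the layout expected by the Intel AMX GEMM kernels and flags the module via use_intel_amx_backend. A hedged usage sketch of PackWeightMethod on a plain linear layer; attaching it as quant_method is an illustrative convention here, not something this file mandates:

import torch
from sglang.srt.layers.amx_utils import PackWeightMethod

# OC=4096 and IC=4096 satisfy the divisibility requirements (16, 32) noted above.
linear = torch.nn.Linear(4096, 4096, bias=True)

# Attach the packing hook and run it after checkpoint weights are loaded.
linear.quant_method = PackWeightMethod(weight_names=["weight"])
linear.quant_method.process_weights_after_loading(linear)

# True only on an AMX-capable CPU where packing succeeded.
print(getattr(linear, "use_intel_amx_backend", False))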