sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -13,7 +13,6 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
-import collections
 import datetime
 import gc
 import inspect
@@ -32,6 +31,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -51,6 +51,18 @@ from sglang.srt.layers.quantization.deep_gemm import (
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.eplb_manager import EPLBManager
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import (
+    ExpertLocationMetadata,
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
@@ -60,6 +72,7 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
+from sglang.srt.model_executor import expert_location_updater
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
@@ -93,6 +106,8 @@ from sglang.srt.utils import (
     set_cuda_arch,
 )
 
+_is_hip = is_hip()
+
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 
@@ -102,6 +117,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
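For orientation (this note and the snippet are not part of the package diff): the new RankZeroFilter above suppresses INFO-level records on non-zero tensor-parallel ranks while letting warnings and errors through everywhere, replacing the per-call `self.should_log` checks that the rest of this diff removes. A minimal, self-contained sketch of the same behavior, with an illustrative logger name and a pretend rank:

    import logging

    class RankZeroFilter(logging.Filter):
        def __init__(self, is_rank_zero):
            super().__init__()
            self.is_rank_zero = is_rank_zero

        def filter(self, record):
            # Drop INFO records unless this process is rank 0; keep everything else.
            if record.levelno == logging.INFO:
                return self.is_rank_zero
            return True

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("demo")                    # illustrative name
    log.addFilter(RankZeroFilter(is_rank_zero=False))  # pretend this is rank 1
    log.info("hidden on non-zero ranks")               # filtered out
    log.warning("still visible on every rank")         # passes through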
 
@@ -125,6 +153,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -134,7 +166,9 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.should_log = tp_rank == 0
+        self.is_multimodal_chunked_prefill_supported = (
+            model_config.is_multimodal_chunked_prefill_supported
+        )
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -144,6 +178,8 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
+        self.forward_pass_id = 0
+
         # Model-specific adjustment
         self.model_specific_adjustment()
 
@@ -162,10 +198,13 @@ class ModelRunner:
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_dp_lm_head": server_args.enable_dp_lm_head,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
+                "deepep_config": server_args.deepep_config,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
+                "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "torchao_config": server_args.torchao_config,
@@ -174,6 +213,7 @@ class ModelRunner:
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
                 "mm_attention_backend": server_args.mm_attention_backend,
+                "ep_num_redundant_experts": server_args.ep_num_redundant_experts,
             }
         )
 
@@ -201,6 +241,31 @@ class ModelRunner:
             enable=self.server_args.enable_memory_saver
         )
 
+        if not self.is_draft_worker:
+            set_global_expert_location_metadata(
+                compute_initial_expert_location_metadata(server_args, self.model_config)
+            )
+            if self.tp_rank == 0 and get_bool_env_var(
+                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
+            ):
+                logger.info(
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                )
+
+        set_global_expert_distribution_recorder(
+            ExpertDistributionRecorder.init_new(
+                server_args,
+                get_global_expert_location_metadata(),
+                rank=self.tp_rank,
+            )
+        )
+
+        self.eplb_manager = (
+            EPLBManager(self)
+            if self.server_args.enable_eplb and (not self.is_draft_worker)
+            else None
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
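Aside (not part of the package diff): the expert-location plumbing above is exposed through module-level set/get accessors. Purely as an illustration of that accessor pattern, with hypothetical names and no claim to match the actual implementation in sglang.srt.managers.expert_location:

    from typing import Any, Optional

    _metadata_singleton: Optional[Any] = None   # hypothetical module-level slot

    def set_metadata(value: Any) -> None:
        # Set once during worker initialization.
        global _metadata_singleton
        _metadata_singleton = value

    def get_metadata() -> Optional[Any]:
        # Read from anywhere in the same process afterwards.
        return _metadata_singleton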
@@ -269,6 +334,8 @@ class ModelRunner:
                     and is_fa3_default_architecture(self.model_config.hf_config)
                 ):
                     server_args.attention_backend = "fa3"
+                elif _is_hip:
+                    server_args.attention_backend = "aiter"
                 else:
                     server_args.attention_backend = (
                         "flashinfer" if is_flashinfer_available() else "triton"
@@ -279,10 +346,9 @@ class ModelRunner:
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
-            if self.should_log:
-                logger.info(
-                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-                )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -292,10 +358,9 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    if self.should_log:
-                        logger.info(
-                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                        )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -314,10 +379,9 @@ class ModelRunner:
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            if self.should_log:
-                logger.info(
-                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-                )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -328,26 +392,25 @@ class ModelRunner:
 
         if self.is_multimodal:
            self.mem_fraction_static *= 0.90
-            if self.should_log:
-                logger.info(
-                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                    f"because this is a multimodal model."
-                )
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+            if not self.is_multimodal_chunked_prefill_supported:
+                server_args.chunked_prefill_size = -1
                 logger.info(
-                    "Automatically turn off --chunked-prefill-size for multimodal model."
+                    f"Automatically turn of --chunked-prefill-size as it is not supported for "
+                    f"{self.model_config.hf_config.model_type}"
                 )
-            server_args.chunked_prefill_size = -1
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            if self.should_log:
-                logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            if self.should_log:
-                logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -400,11 +463,15 @@
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -440,10 +507,9 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                if self.should_log:
-                    logger.info(
-                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                    )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -479,11 +545,10 @@ class ModelRunner:
                     self.model.load_kv_cache_scales(
                         self.server_args.quantization_param_path
                     )
-                    if self.should_log:
-                        logger.info(
-                            "Loaded KV cache scaling factors from %s",
-                            self.server_args.quantization_param_path,
-                        )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
                 else:
                     raise RuntimeError(
                         "Using FP8 KV cache and scaling factors provided but "
@@ -526,6 +591,16 @@ class ModelRunner:
                 f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
             ) from None
 
+    def update_expert_location(
+        self, new_expert_location_metadata: ExpertLocationMetadata
+    ):
+        expert_location_updater.update_expert_location(
+            self.model.routed_experts_weights_of_layer,
+            new_expert_location_metadata,
+            nnodes=self.server_args.nnodes,
+            rank=self.tp_rank,
+        )
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
@@ -547,13 +622,7 @@ class ModelRunner:
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
-                DefaultModelLoader.Source(
-                    config.model_path,
-                    revision=config.revision,
-                    fall_back_to_pt=getattr(
-                        self.model, "fall_back_to_pt_during_load", True
-                    ),
-                )
+                DefaultModelLoader.Source.init_new(config, self.model)
             )
             return iter
 
@@ -626,7 +695,6 @@ class ModelRunner:
                 rank=rank,
                 group_name=group_name,
             )
-            dist.barrier(group=self._model_update_group, device_ids=[rank])
             return True, "Succeeded to initialize custom process group."
         except Exception as e:
             message = f"Failed to initialize custom process group: {e}."
@@ -716,14 +784,20 @@ class ModelRunner:
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
-        if self.use_mla_backend:
-            num_layers = (
-                self.model_config.num_hidden_layers
-                if not self.is_draft_worker
-                else self.model_config.hf_config.num_nextn_predict_layers
+        if self.is_draft_worker:
+            num_layers = getattr(
+                self.model_config.hf_config,
+                "num_nextn_predict_layers",
+                self.num_effective_layers,
             )
+        else:
+            num_layers = self.num_effective_layers
+        if self.use_mla_backend:
             # FIXME: pipeline parallelism is not compatible with mla backend
             assert self.pp_size == 1
             cell_size = (
@@ -735,7 +809,7 @@ class ModelRunner:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                * self.num_effective_layers
+                * num_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
            )
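As a worked example of the non-MLA cell-size formula above (all model numbers are hypothetical, not taken from the diff): with 8 KV heads after tensor-parallel sharding, head_dim 128, 32 effective layers, and an fp16 KV cache (2-byte elements), one token costs 8 * 128 * 32 * 2 * 2 = 131072 bytes, i.e. 128 KiB of KV cache:

    # hypothetical sizes, for illustration only
    num_kv_heads, head_dim, num_layers = 8, 128, 32
    element_size = 2                       # bytes per fp16 value
    cell_size = num_kv_heads * head_dim * num_layers * 2 * element_size  # K and V
    print(cell_size)                       # 131072 bytes per token (128 KiB)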
@@ -754,7 +828,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if is_hip():  # Using natively supported format
+            if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
@@ -932,6 +1006,10 @@ class ModelRunner:
             )
 
             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+            self.attn_backend = AiterAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -1012,7 +1090,7 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return
 
-        tic = time.time()
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
@@ -1020,13 +1098,12 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
             f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     def apply_torch_tp(self):
-        if self.should_log:
-            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
@@ -1085,32 +1162,54 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        self.forward_pass_id += 1
+
+        with get_global_expert_distribution_recorder().with_forward_pass(
+            self.forward_pass_id,
+            forward_batch,
+        ):
+            output = self._forward_raw(
+                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+            )
+
+        if self.eplb_manager is not None:
+            self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
+
+        return output
+
+    def _forward_raw(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors],
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            return self.cuda_graph_runner.replay(
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-        if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
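Aside (not part of the package diff): with this refactor, ModelRunner.forward returns a (output, can_run_cuda_graph) tuple rather than the bare output, so existing call sites have to unpack two values. A sketch of the caller-side difference, with illustrative variable names:

    # 0.4.6.post3-style call site
    logits_output = model_runner.forward(forward_batch)

    # 0.4.6.post5-style call site
    logits_output, can_run_cuda_graph = model_runner.forward(forward_batch)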
@@ -1145,9 +1244,7 @@ class ModelRunner:
                 [self.sample(values, forward_batch) for values in logits_output],
                 axis=-1,
             )
-        sampling_info = forward_batch.sampling_info
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
+
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
 
         # Sample the next tokens
@@ -1158,15 +1255,13 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids
 
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
         is_mrope_enabled = "mrope_section" in rope_scaling
sglang/srt/model_loader/loader.py CHANGED
@@ -197,6 +197,15 @@ class DefaultModelLoader(BaseModelLoader):
         fall_back_to_pt: bool = True
         """Whether .pt weights can be used."""
 
+        @classmethod
+        def init_new(cls, model_config: ModelConfig, model):
+            return cls(
+                model_config.model_path,
+                model_config.revision,
+                prefix="",
+                fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
+            )
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
@@ -341,12 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
         model: nn.Module,
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
 
-        primary_weights = DefaultModelLoader.Source(
-            model_config.model_path,
-            model_config.revision,
-            prefix="",
-            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
-        )
+        primary_weights = DefaultModelLoader.Source.init_new(model_config, model)
         yield from self._get_weights_iterator(primary_weights)
 
         secondary_weights = cast(
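Aside (not part of the package diff): the new Source.init_new classmethod centralizes the construction that the weight iterator above and ModelRunner.update_weights_from_disk previously spelled out by hand, so the two call sites stay in sync. The equivalence, using whatever model_config and model the caller already holds:

    # before: built field by field at each call site
    source = DefaultModelLoader.Source(
        model_config.model_path,
        model_config.revision,
        prefix="",
        fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
    )

    # after: one shared constructor helper
    source = DefaultModelLoader.Source.init_new(model_config, model)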
sglang/srt/models/clip.py CHANGED
@@ -168,7 +168,7 @@ class CLIPEncoderLayer(nn.Module):
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = CLIPMLP(
             config,
@@ -395,6 +395,10 @@ class CLIPVisionModel(nn.Module):
             config, quant_config, prefix=add_prefix("vision_model", prefix)
         )
 
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
     def forward(self, pixel_values: torch.Tensor):
         return self.vision_model(pixel_values)
 
sglang/srt/models/deepseek_janus_pro.py CHANGED
@@ -188,7 +188,7 @@ def trunc_normal_tf_(
     best when :math:`a \\leq \text{mean} \\leq b`.
     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
-    and the result is subsquently scaled and shifted by the mean and std args.
+    and the result is subsequently scaled and shifted by the mean and std args.
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
@@ -735,7 +735,7 @@ class VisionTransformer(nn.Module):
             img_size: Input image size.
             patch_size: Patch size.
             in_chans: Number of image input channels.
-            num_classes: Mumber of classes for classification head.
+            num_classes: Number of classes for classification head.
             global_pool: Type of global pooling for final sequence (default: 'token').
             embed_dim: Transformer embedding dimension.
             depth: Depth of transformer.