sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
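Before stepping through the per-file hunks below, it can help to confirm which of the two wheels is actually installed. A minimal sketch using only the standard library; the distribution name `sglang` is taken from the wheel filename above.

```python
# Minimal sketch: check the installed sglang distribution against the two
# versions compared in this diff (0.4.6.post2 vs 0.4.6.post4).
from importlib.metadata import PackageNotFoundError, version

try:
    print("installed sglang:", version("sglang"))
except PackageNotFoundError:
    print("sglang is not installed in this environment")
```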
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -82,12 +82,12 @@ class EAGLEDraftCudaGraphRunner:
82
82
  self.capture()
83
83
  except RuntimeError as e:
84
84
  raise Exception(
85
- f"Capture cuda graph failed: {e}\n"
85
+ f"Capture CUDA graph failed: {e}\n"
86
86
  "Possible solutions:\n"
87
87
  "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
88
88
  "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
89
89
  "3. disable torch compile by not using --enable-torch-compile\n"
90
- "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
90
+ "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
91
91
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
92
92
  )
93
93
 
@@ -149,7 +149,7 @@ class EAGLEDraftCudaGraphRunner:
149
149
 
150
150
  # Run and capture
151
151
  def run_once():
152
- # Backup two fileds, which will be modified in-place in `draft_forward`.
152
+ # Backup two fields, which will be modified in-place in `draft_forward`.
153
153
  output_cache_loc_backup = forward_batch.out_cache_loc
154
154
  hidden_states_backup = forward_batch.spec_info.hidden_states
155
155
 
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -167,12 +167,12 @@ class EagleVerifyOutput:
167
167
  draft_input: EagleDraftInput
168
168
  # Logit outputs from target worker
169
169
  logits_output: LogitsProcessorOutput
170
- # Accepeted token ids including the bonus token
170
+ # Accepted token ids including the bonus token
171
171
  verified_id: torch.Tensor
172
- # Accepeted token length per sequence in a batch in CPU.
172
+ # Accepted token length per sequence in a batch in CPU.
173
173
  accept_length_per_req_cpu: List[int]
174
- # Accepeted indices from logits_output.next_token_logits
175
- accepeted_indices: torch.Tensor
174
+ # Accepted indices from logits_output.next_token_logits
175
+ accepted_indices: torch.Tensor
176
176
 
177
177
 
178
178
  @dataclass
@@ -316,7 +316,7 @@ class EagleVerifyInput:
316
316
 
317
317
  This API updates values inside logits_output based on the accepted
318
318
  tokens. I.e., logits_output.next_token_logits only contains
319
- accepeted token logits.
319
+ accepted token logits.
320
320
  """
321
321
  bs = self.retrive_index.shape[0]
322
322
  candidates = self.draft_token.reshape(bs, self.draft_token_num)
@@ -493,7 +493,7 @@ class EagleVerifyInput:
493
493
  logits_output=logits_output,
494
494
  verified_id=verified_id,
495
495
  accept_length_per_req_cpu=accept_length_cpu,
496
- accepeted_indices=accept_index,
496
+ accepted_indices=accept_index,
497
497
  )
498
498
  else:
499
499
  assign_req_to_token_pool[(bs,)](
@@ -539,7 +539,7 @@ class EagleVerifyInput:
539
539
  logits_output=logits_output,
540
540
  verified_id=verified_id,
541
541
  accept_length_per_req_cpu=accept_length_cpu,
542
- accepeted_indices=accept_index,
542
+ accepted_indices=accept_index,
543
543
  )
544
544
 
545
545
 
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -201,7 +201,7 @@ class EAGLEWorker(TpModelWorker):
201
201
  self.has_prefill_wrapper_verify = False
202
202
  else:
203
203
  raise ValueError(
204
- f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
204
+ f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
205
205
  )
206
206
 
207
207
  self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -245,14 +245,14 @@ class EAGLEWorker(TpModelWorker):
245
245
  Args:
246
246
  batch: The batch to run forward. The state of the batch is modified as it runs.
247
247
  Returns:
248
- A tuple of the final logit output of the target model, next tokens accepeted,
249
- the batch id (used for overlap schedule), and number of accepeted tokens.
248
+ A tuple of the final logit output of the target model, next tokens accepted,
249
+ the batch id (used for overlap schedule), and number of accepted tokens.
250
250
  """
251
251
  if batch.forward_mode.is_decode():
252
252
  with self.draft_tp_context(self.draft_model_runner.tp_group):
253
253
  spec_info = self.draft(batch)
254
- logits_output, verify_output, model_worker_batch = self.verify(
255
- batch, spec_info
254
+ logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
255
+ self.verify(batch, spec_info)
256
256
  )
257
257
 
258
258
  # If it is None, it means all requests are finished
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
264
264
  verify_output.verified_id,
265
265
  model_worker_batch.bid,
266
266
  sum(verify_output.accept_length_per_req_cpu),
267
+ can_run_cuda_graph,
267
268
  )
268
269
  elif batch.forward_mode.is_idle():
269
270
  model_worker_batch = batch.get_model_worker_batch()
270
- logits_output, next_token_ids = self.target_worker.forward_batch_generation(
271
- model_worker_batch
271
+ logits_output, next_token_ids, _ = (
272
+ self.target_worker.forward_batch_generation(model_worker_batch)
272
273
  )
273
274
 
274
- return logits_output, next_token_ids, model_worker_batch.bid, 0
275
+ return logits_output, next_token_ids, model_worker_batch.bid, 0, False
275
276
  else:
276
277
  logits_output, next_token_ids, bid = self.forward_target_extend(batch)
277
278
  with self.draft_tp_context(self.draft_model_runner.tp_group):
278
279
  self.forward_draft_extend(
279
280
  batch, logits_output.hidden_states, next_token_ids
280
281
  )
281
- return logits_output, next_token_ids, bid, 0
282
+ return logits_output, next_token_ids, bid, 0, False
282
283
 
283
284
  def forward_target_extend(
284
285
  self, batch: ScheduleBatch
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
297
298
  # We need the full hidden states to prefill the KV cache of the draft model.
298
299
  model_worker_batch = batch.get_model_worker_batch()
299
300
  model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
300
- logits_output, next_token_ids = self.target_worker.forward_batch_generation(
301
+ logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
301
302
  model_worker_batch
302
303
  )
303
304
  return logits_output, next_token_ids, model_worker_batch.bid
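Across the hunks above, `target_worker.forward_batch_generation` now returns a third value, consumed as `can_run_cuda_graph` in the verify hunk just below and ignored elsewhere. A hedged sketch of the caller-side unpacking implied by this diff; the worker and batch objects are placeholders, not sglang types.

```python
# Hedged sketch of the new unpacking pattern implied by this diff: callers of
# forward_batch_generation() receive a third value and discard it with `_`
# when they do not need the CUDA-graph flag.
def run_target_extend(target_worker, model_worker_batch):
    logits_output, next_token_ids, _can_run_cuda_graph = (
        target_worker.forward_batch_generation(model_worker_batch)
    )
    return logits_output, next_token_ids, model_worker_batch.bid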
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
478
479
  batch.forward_mode = ForwardMode.TARGET_VERIFY
479
480
  batch.spec_info = spec_info
480
481
  model_worker_batch = batch.get_model_worker_batch()
481
- logits_output, _ = self.target_worker.forward_batch_generation(
482
- model_worker_batch, skip_sample=True
482
+ logits_output, _, can_run_cuda_graph = (
483
+ self.target_worker.forward_batch_generation(
484
+ model_worker_batch, skip_sample=True
485
+ )
483
486
  )
484
487
  self._detect_nan_if_needed(logits_output)
485
488
  spec_info.hidden_states = logits_output.hidden_states
@@ -491,11 +494,11 @@ class EAGLEWorker(TpModelWorker):
491
494
  )
492
495
 
493
496
  # Post process based on verified outputs.
494
- # Pick indices that we care (accepeted)
497
+ # Pick indices that we care (accepted)
495
498
  logits_output.next_token_logits = logits_output.next_token_logits[
496
- res.accepeted_indices
499
+ res.accepted_indices
497
500
  ]
498
- logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
501
+ logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
499
502
 
500
503
  # Prepare the batch for the next draft forwards.
501
504
  batch.forward_mode = ForwardMode.DECODE
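The post-processing above keeps only the logits and hidden states of accepted tokens by indexing with `res.accepted_indices` (renamed from the misspelled `accepeted_indices`). A small self-contained illustration of that indexing pattern, with made-up shapes.

```python
# Illustrative only: index per-draft-token logits down to the rows that
# survived verification, mirroring
# logits_output.next_token_logits[res.accepted_indices] above.
import torch

next_token_logits = torch.randn(8, 32000)      # 8 draft tokens, vocab 32000
hidden_states = torch.randn(8, 4096)
accepted_indices = torch.tensor([0, 2, 3, 7])  # rows accepted by the verifier

next_token_logits = next_token_logits[accepted_indices]  # -> (4, 32000)
hidden_states = hidden_states[accepted_indices]          # -> (4, 4096)
```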
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
504
507
  if batch.return_logprob:
505
508
  self.add_logprob_values(batch, res, logits_output)
506
509
 
507
- return logits_output, res, model_worker_batch
510
+ return logits_output, res, model_worker_batch, can_run_cuda_graph
508
511
 
509
512
  def add_logprob_values(
510
513
  self,
@@ -590,14 +593,14 @@ class EAGLEWorker(TpModelWorker):
590
593
  model_worker_batch, self.draft_model_runner
591
594
  )
592
595
  forward_batch.return_logprob = False
593
- logits_output = self.draft_model_runner.forward(forward_batch)
596
+ logits_output, _ = self.draft_model_runner.forward(forward_batch)
594
597
  self._detect_nan_if_needed(logits_output)
595
598
  assert isinstance(forward_batch.spec_info, EagleDraftInput)
596
599
  assert forward_batch.spec_info is batch.spec_info
597
600
  self.capture_for_decode(logits_output, forward_batch.spec_info)
598
601
 
599
602
  def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
600
- # Backup fileds that will be modified in-place
603
+ # Backup fields that will be modified in-place
601
604
  seq_lens_backup = batch.seq_lens.clone()
602
605
  req_pool_indices_backup = batch.req_pool_indices
603
606
  accept_length_backup = batch.spec_info.accept_length
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
617
620
  )
618
621
 
619
622
  # Run
620
- logits_output = self.draft_model_runner.forward(forward_batch)
623
+ logits_output, _ = self.draft_model_runner.forward(forward_batch)
621
624
 
622
625
  self._detect_nan_if_needed(logits_output)
623
626
  self.capture_for_decode(logits_output, forward_batch.spec_info)
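Two things change in the draft-extend path above: `draft_model_runner.forward(...)` now returns a tuple whose second element is discarded, and fields that the forward pass mutates in place are backed up first and restored afterwards. A hedged sketch of the backup/restore idea with stand-in objects, not sglang's `ScheduleBatch`.

```python
# Hedged sketch (stand-in objects): clone fields a forward pass mutates in
# place, run it, then restore the originals afterwards.
from types import SimpleNamespace

import torch

batch = SimpleNamespace(seq_lens=torch.tensor([5, 9]), req_pool_indices=torch.tensor([0, 1]))

seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices

batch.seq_lens += 1  # stand-in for the in-place updates done by the forward pass

batch.seq_lens = seq_lens_backup  # restore, as forward_draft_extend_after_decode does
batch.req_pool_indices = req_pool_indices_backup
```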
sglang/srt/utils.py CHANGED
@@ -145,6 +145,10 @@ def is_xpu() -> bool:
145
145
  return hasattr(torch, "xpu") and torch.xpu.is_available()
146
146
 
147
147
 
148
+ def is_npu() -> bool:
149
+ return hasattr(torch, "npu") and torch.npu.is_available()
150
+
151
+
148
152
  def is_flashinfer_available():
149
153
  """
150
154
  Check whether flashinfer is available.
@@ -278,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
278
282
  return wrapper
279
283
 
280
284
 
281
- def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
285
+ def get_available_gpu_memory(
286
+ device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
287
+ ):
282
288
  """
283
289
  Get available memory for cuda:gpu_id device.
284
290
  When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -328,12 +334,22 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
328
334
  elif device == "cpu":
329
335
  # TODO: rename the variables in the current function to be not GPU specific
330
336
  free_gpu_memory = psutil.virtual_memory().available
337
+ elif device == "npu":
338
+ num_gpus = torch.npu.device_count()
339
+ assert gpu_id < num_gpus
340
+
341
+ if torch.npu.current_device() != gpu_id:
342
+ print(
343
+ f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
344
+ "which may cause useless memory allocation for torch NPU context.",
345
+ )
346
+ free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
331
347
 
332
348
  if distributed:
333
- tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
334
- torch.device(device, gpu_id)
349
+ tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
350
+ torch.distributed.all_reduce(
351
+ tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
335
352
  )
336
- torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
337
353
  free_gpu_memory = tensor.item()
338
354
 
339
355
  return free_gpu_memory / (1 << 30)
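The distributed branch above now all-reduces the per-rank free memory with `ReduceOp.MIN` over an optional `cpu_group` instead of first moving the value onto the device. A hedged sketch of that reduction in isolation; the group argument mirrors the new `cpu_group` parameter.

```python
# Hedged sketch of the MIN-reduction above: every rank contributes its free
# memory in bytes and all ranks receive the smallest value, reported in GiB.
import torch
import torch.distributed as dist

def min_free_memory_gib(free_bytes: float, cpu_group=None) -> float:
    tensor = torch.tensor(free_bytes, dtype=torch.float32)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)
    return tensor.item() / (1 << 30)
```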
@@ -897,7 +913,10 @@ def broadcast_pyobj(
897
913
  src: int = 0,
898
914
  force_cpu_device: bool = True,
899
915
  ):
900
- """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
916
+ """Broadcast inputs from src rank to all other ranks with torch.dist backend.
917
+ The `rank` here refer to the source rank on global process group (regardless
918
+ of dist_group argument).
919
+ """
901
920
  device = torch.device(
902
921
  "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
903
922
  )
@@ -1345,6 +1364,9 @@ def get_device_name(device_id: int = 0) -> str:
1345
1364
  if hasattr(torch, "hpu") and torch.hpu.is_available():
1346
1365
  return torch.hpu.get_device_name(device_id)
1347
1366
 
1367
+ if hasattr(torch, "npu") and torch.npu.is_available():
1368
+ return torch.npu.get_device_name(device_id)
1369
+
1348
1370
 
1349
1371
  @lru_cache(maxsize=1)
1350
1372
  def is_habana_available() -> bool:
@@ -1441,6 +1463,13 @@ def get_compiler_backend() -> str:
1441
1463
  if hasattr(torch, "hpu") and torch.hpu.is_available():
1442
1464
  return "hpu_backend"
1443
1465
 
1466
+ if hasattr(torch, "npu") and torch.npu.is_available():
1467
+ import torchair
1468
+
1469
+ config = torchair.CompilerConfig()
1470
+ npu_backend = torchair.get_npu_backend(compiler_config=config)
1471
+ return npu_backend
1472
+
1444
1473
  return "inductor"
1445
1474
 
1446
1475
 
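The backend selection above now covers NPU via torchair in addition to HPU and the default inductor. A hedged usage sketch: `torch.compile` accepts either a backend name or a backend callable, so the return value of `get_compiler_backend()` can be passed through unchanged.

```python
# Hedged usage sketch: torch.compile accepts both a backend name (e.g.
# "inductor", "hpu_backend") and a backend callable (the torchair path above).
import torch

from sglang.srt.utils import get_compiler_backend

def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
    return torch.compile(module, backend=get_compiler_backend())
```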
@@ -2049,7 +2078,6 @@ def is_fa3_default_architecture(hf_config):
2049
2078
  "Llama4ForConditionalGeneration",
2050
2079
  "LlamaForCausalLM",
2051
2080
  "MistralForCausalLM",
2052
- "MixtralForCausalLM",
2053
2081
  "Gemma2ForCausalLM",
2054
2082
  "Gemma3ForConditionalGeneration",
2055
2083
  "Qwen3ForCausalLM",
@@ -2069,3 +2097,10 @@ class BumpAllocator:
2069
2097
  output = self._buffer[self._pointer : self._pointer + size]
2070
2098
  self._pointer += size
2071
2099
  return output
2100
+
2101
+
2102
+ def log_info_on_rank0(logger, msg):
2103
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
2104
+
2105
+ if get_tensor_model_parallel_rank() == 0:
2106
+ logger.info(msg)
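A short usage note for the new `log_info_on_rank0` helper: it assumes sglang's tensor-parallel group has already been initialized, and only TP rank 0 emits the message, which keeps multi-GPU logs readable.

```python
# Hedged usage sketch: assumes the tensor-parallel group is initialized;
# only TP rank 0 actually logs the message.
import logging

from sglang.srt.utils import log_info_on_rank0

logger = logging.getLogger(__name__)
log_info_on_rank0(logger, "weights loaded")
```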
sglang/test/few_shot_gsm8k.py CHANGED
@@ -90,7 +90,7 @@ def run_eval(args):
90
90
  #####################################
91
91
 
92
92
  # Run requests
93
- tic = time.time()
93
+ tic = time.perf_counter()
94
94
  states = few_shot_gsm8k.run_batch(
95
95
  arguments,
96
96
  temperature=args.temperature if hasattr(args, "temperature") else 0,
@@ -99,7 +99,7 @@ def run_eval(args):
99
99
  return_logprob=getattr(args, "return_logprob", None),
100
100
  logprob_start_len=getattr(args, "logprob_start_len", None),
101
101
  )
102
- latency = time.time() - tic
102
+ latency = time.perf_counter() - tic
103
103
 
104
104
  preds = []
105
105
  for i in range(len(states)):
sglang/test/few_shot_gsm8k_engine.py CHANGED
@@ -89,7 +89,7 @@ def run_eval(args):
89
89
  }
90
90
 
91
91
  # Run requests
92
- tic = time.time()
92
+ tic = time.perf_counter()
93
93
 
94
94
  loop = asyncio.get_event_loop()
95
95
 
@@ -98,7 +98,7 @@ def run_eval(args):
98
98
  )
99
99
 
100
100
  # End requests
101
- latency = time.time() - tic
101
+ latency = time.perf_counter() - tic
102
102
 
103
103
  # Shutdown the engine
104
104
  engine.shutdown()
sglang/test/run_eval.py CHANGED
@@ -71,9 +71,9 @@ def run_eval(args):
71
71
  )
72
72
 
73
73
  # Run eval
74
- tic = time.time()
74
+ tic = time.perf_counter()
75
75
  result = eval_obj(sampler)
76
- latency = time.time() - tic
76
+ latency = time.perf_counter() - tic
77
77
 
78
78
  # Dump reports
79
79
  metrics = result.metrics | {"score": result.score}
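The benchmark and eval scripts above switch from `time.time()` to `time.perf_counter()`, which is monotonic and high-resolution, so wall-clock adjustments (NTP, daylight saving) cannot distort the measured latency. The pattern in isolation:

```python
# time.perf_counter() is monotonic and high-resolution, so the measured
# latency cannot go negative or jump when the system clock is adjusted.
import time

tic = time.perf_counter()
time.sleep(0.1)  # stand-in for run_batch(...) / eval_obj(sampler)
latency = time.perf_counter() - tic
print(f"latency: {latency:.3f} s")
```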
sglang/test/runners.py CHANGED
@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
19
19
 
20
20
  import torch
21
21
  import torch.nn.functional as F
22
+ import transformers
22
23
  from transformers import (
24
+ AutoConfig,
23
25
  AutoModel,
24
26
  AutoModelForCausalLM,
25
27
  AutoModelForVision2Seq,
@@ -211,7 +213,12 @@ class HFRunner:
211
213
 
212
214
  # Load the model and tokenizer
213
215
  if self.model_type == "generation":
214
- self.base_model = AutoModelForCausalLM.from_pretrained(
216
+ config = AutoConfig.from_pretrained(model_path)
217
+ if model_archs := getattr(config, "architectures"):
218
+ model_cls = getattr(transformers, model_archs[0])
219
+ else:
220
+ model_cls = AutoModelForCausalLM
221
+ self.base_model = model_cls.from_pretrained(
215
222
  model_path,
216
223
  torch_dtype=torch_dtype,
217
224
  trust_remote_code=self.trust_remote_code,
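The HFRunner change above resolves the concrete `transformers` class from `config.architectures` instead of always using `AutoModelForCausalLM`, which matters for models whose preferred class is not the causal-LM auto class. A hedged sketch of that lookup, with an extra fallback added for safety; `model_path` is a placeholder.

```python
# Hedged sketch of the class-resolution pattern above: prefer the class named
# in config.architectures, fall back to AutoModelForCausalLM otherwise.
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

def resolve_model_cls(model_path: str):
    config = AutoConfig.from_pretrained(model_path)
    if archs := getattr(config, "architectures", None):
        return getattr(transformers, archs[0], AutoModelForCausalLM)
    return AutoModelForCausalLM
```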
sglang/test/send_one.py CHANGED
@@ -27,6 +27,7 @@ class BenchArgs:
27
27
  "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
28
28
  )
29
29
  image: bool = False
30
+ many_images: bool = False
30
31
  stream: bool = False
31
32
 
32
33
  @staticmethod
@@ -48,6 +49,7 @@ class BenchArgs:
48
49
  parser.add_argument("--return-logprob", action="store_true")
49
50
  parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
50
51
  parser.add_argument("--image", action="store_true")
52
+ parser.add_argument("--many-images", action="store_true")
51
53
  parser.add_argument("--stream", action="store_true")
52
54
 
53
55
  @classmethod
@@ -62,6 +64,17 @@ def send_one_prompt(args):
62
64
  "Human: Describe this image in a very short sentence.\n\nAssistant:"
63
65
  )
64
66
  image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
67
+ elif args.many_images:
68
+ args.prompt = (
69
+ "Human: I have one reference image and many images."
70
+ "Describe their relationship in a very short sentence.\n\nAssistant:"
71
+ )
72
+ image_data = [
73
+ "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
74
+ "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
75
+ "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
76
+ "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
77
+ ]
65
78
  else:
66
79
  image_data = None
67
80
 
@@ -74,9 +87,6 @@ def send_one_prompt(args):
74
87
  "Write in a format of json.\nAssistant:"
75
88
  )
76
89
  json_schema = "$$ANY$$"
77
- json_schema = (
78
- '{"type": "object", "properties": {"population": {"type": "integer"}}}'
79
- )
80
90
  else:
81
91
  json_schema = None
82
92
 
sglang/test/simple_eval_common.py CHANGED
@@ -140,7 +140,7 @@ class ChatCompletionSampler(SamplerBase):
140
140
  max_tokens=self.max_tokens,
141
141
  )
142
142
  return response.choices[0].message.content
143
- # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
143
+ # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
144
144
  except openai.BadRequestError as e:
145
145
  print("Bad Request Error", e)
146
146
  return ""
sglang/test/simple_eval_humaneval.py CHANGED
@@ -121,7 +121,7 @@ class HumanEval(Eval):
121
121
  convo=convo,
122
122
  metrics={
123
123
  f"pass@{k}": estimate_pass_at_k([total], [correct], k)
124
- # this will be aggrated so no need of .mean()
124
+ # this will be aggregated so no need of .mean()
125
125
  for k in self._ks_passes
126
126
  if total >= k
127
127
  },
sglang/test/test_block_fp8.py CHANGED
@@ -7,9 +7,9 @@ import torch
7
7
  from sglang.srt.layers.activation import SiluAndMul
8
8
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
9
9
  from sglang.srt.layers.quantization.fp8_kernel import (
10
- per_tensor_quant_mla_deep_gemm_masked_fp8,
11
10
  per_tensor_quant_mla_fp8,
12
11
  per_token_group_quant_fp8,
12
+ per_token_group_quant_mla_deep_gemm_masked_fp8,
13
13
  static_quant_fp8,
14
14
  w8a8_block_fp8_matmul,
15
15
  )
@@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
236
236
 
237
237
  with torch.inference_mode():
238
238
  ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
239
- out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
239
+ out, scale, _, _, _ = per_token_group_quant_mla_deep_gemm_masked_fp8(
240
240
  x, group_size
241
241
  )
242
242
  out = out[:, :num_tokens, :]
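For context on the renamed kernel above, a hedged sketch of the reference behaviour it is tested against: per-token-group FP8 quantization scales each group of 128 values by its absolute maximum so the result fits the float8_e4m3 range (max 448). This mirrors `native_per_token_group_quant_fp8` only in spirit; the exact reference lives in the test file.

```python
# Hedged reference sketch of per-token-group FP8 quantization (group size 128):
# scale each group by amax/448 and cast to float8_e4m3fn. Requires PyTorch
# with float8 dtypes (>= 2.1).
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128, eps: float = 1e-12):
    m, n = x.shape
    groups = x.view(m, n // group_size, group_size).float()
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = amax / 448.0
    q = (groups / scale).to(torch.float8_e4m3fn).view(m, n)
    return q, scale.squeeze(-1)
```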
sglang/test/test_deepep_utils.py ADDED
@@ -0,0 +1,219 @@
1
+ # Copy from deepseek-ai/DeepEP/tests/test_utils.py
2
+
3
+ import os
4
+ import sys
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.distributed as dist
10
+
11
+
12
+ def init_dist(local_rank: int, num_local_ranks: int):
13
+ # NOTES: you may rewrite this function with your own cluster settings
14
+ ip = os.getenv("MASTER_ADDR", "127.0.0.1")
15
+ port = int(os.getenv("MASTER_PORT", "8361"))
16
+ num_nodes = int(os.getenv("WORLD_SIZE", 1))
17
+ node_rank = int(os.getenv("RANK", 0))
18
+ assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
19
+
20
+ dist.init_process_group(
21
+ backend="nccl",
22
+ init_method=f"tcp://{ip}:{port}",
23
+ world_size=num_nodes * num_local_ranks,
24
+ rank=node_rank * num_local_ranks + local_rank,
25
+ )
26
+ torch.set_default_dtype(torch.bfloat16)
27
+ torch.set_default_device("cuda")
28
+ torch.cuda.set_device(local_rank)
29
+
30
+ return (
31
+ dist.get_rank(),
32
+ dist.get_world_size(),
33
+ dist.new_group(list(range(num_local_ranks * num_nodes))),
34
+ )
35
+
36
+
37
+ def calc_diff(x: torch.Tensor, y: torch.Tensor):
38
+ x, y = x.double() + 1, y.double() + 1
39
+ denominator = (x * x + y * y).sum()
40
+ sim = 2 * (x * y).sum() / denominator
41
+ return (1 - sim).item()
42
+
43
+
44
+ def per_token_cast_to_fp8(x: torch.Tensor):
45
+ assert x.dim() == 2 and x.size(1) % 128 == 0
46
+ m, n = x.shape
47
+ x_view = x.view(m, -1, 128)
48
+ x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
49
+ return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
50
+ m, n
51
+ ), (x_amax / 448.0).view(m, -1)
52
+
53
+
54
+ def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
55
+ x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
56
+ x_scales = x_scales.view(x_fp8.size(0), -1, 1)
57
+ return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
58
+
59
+
60
+ def inplace_unique(x: torch.Tensor, num_slots: int):
61
+ assert x.dim() == 2
62
+ mask = x < 0
63
+ x_padded = x.masked_fill(mask, num_slots)
64
+ bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
65
+ bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
66
+ bin_count = bin_count[:, :num_slots]
67
+ sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
68
+ sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
69
+ sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
70
+ x[:, :].fill_(-1)
71
+ valid_len = min(num_slots, x.size(1))
72
+ x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
73
+
74
+
75
+ def create_grouped_scores(
76
+ scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
77
+ ):
78
+ num_tokens, num_experts = scores.shape
79
+ scores = scores.view(num_tokens, num_groups, -1)
80
+ mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
81
+ mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
82
+ return (scores * mask).view(num_tokens, num_experts)
83
+
84
+
85
+ def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
86
+ # Flush L2 cache with 256 MB data
87
+ torch.cuda.synchronize()
88
+ cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
89
+
90
+ # Warmup
91
+ for _ in range(num_warmups):
92
+ fn()
93
+
94
+ # Flush L2
95
+ cache.zero_()
96
+
97
+ # Testing
98
+ start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
99
+ end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
100
+ for i in range(num_tests):
101
+ # Record
102
+ start_events[i].record()
103
+ fn()
104
+ end_events[i].record()
105
+ if post_fn is not None:
106
+ post_fn()
107
+ torch.cuda.synchronize()
108
+
109
+ times = np.array(
110
+ [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
111
+ )[1:]
112
+ return np.average(times), np.min(times), np.max(times)
113
+
114
+
115
+ class empty_suppress:
116
+ def __enter__(self):
117
+ return self
118
+
119
+ def __exit__(self, *_):
120
+ pass
121
+
122
+
123
+ class suppress_stdout_stderr:
124
+ def __enter__(self):
125
+ self.outnull_file = open(os.devnull, "w")
126
+ self.errnull_file = open(os.devnull, "w")
127
+
128
+ self.old_stdout_fileno_undup = sys.stdout.fileno()
129
+ self.old_stderr_fileno_undup = sys.stderr.fileno()
130
+
131
+ self.old_stdout_fileno = os.dup(sys.stdout.fileno())
132
+ self.old_stderr_fileno = os.dup(sys.stderr.fileno())
133
+
134
+ self.old_stdout = sys.stdout
135
+ self.old_stderr = sys.stderr
136
+
137
+ os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
138
+ os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
139
+
140
+ sys.stdout = self.outnull_file
141
+ sys.stderr = self.errnull_file
142
+ return self
143
+
144
+ def __exit__(self, *_):
145
+ sys.stdout = self.old_stdout
146
+ sys.stderr = self.old_stderr
147
+
148
+ os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
149
+ os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
150
+
151
+ os.close(self.old_stdout_fileno)
152
+ os.close(self.old_stderr_fileno)
153
+
154
+ self.outnull_file.close()
155
+ self.errnull_file.close()
156
+
157
+
158
+ def bench_kineto(
159
+ fn,
160
+ kernel_names,
161
+ num_tests: int = 30,
162
+ suppress_kineto_output: bool = False,
163
+ trace_path: Optional[str] = None,
164
+ barrier_comm_profiling: bool = False,
165
+ ):
166
+ # Profile
167
+ suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
168
+ with suppress():
169
+ schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
170
+ with torch.profiler.profile(
171
+ activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
172
+ ) as prof:
173
+ for i in range(2):
174
+ # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
175
+ if barrier_comm_profiling:
176
+ lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
177
+ rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
178
+ lhs @ rhs
179
+ dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
180
+ for _ in range(num_tests):
181
+ fn()
182
+ prof.step()
183
+
184
+ # Parse the profiling table
185
+ assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
186
+ is_tupled = isinstance(kernel_names, tuple)
187
+ prof_lines = (
188
+ prof.key_averages()
189
+ .table(sort_by="cuda_time_total", max_name_column_width=100)
190
+ .split("\n")
191
+ )
192
+ kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
193
+ assert all([isinstance(name, str) for name in kernel_names])
194
+ for name in kernel_names:
195
+ assert (
196
+ sum([name in line for line in prof_lines]) == 1
197
+ ), f"Errors of the kernel {name} in the profiling table"
198
+
199
+ # Save chrome traces
200
+ if trace_path is not None:
201
+ prof.export_chrome_trace(trace_path)
202
+
203
+ # Return average kernel times
204
+ units = {"ms": 1e3, "us": 1e6}
205
+ kernel_times = []
206
+ for name in kernel_names:
207
+ for line in prof_lines:
208
+ if name in line:
209
+ time_str = line.split()[-2]
210
+ for unit, scale in units.items():
211
+ if unit in time_str:
212
+ kernel_times.append(float(time_str.replace(unit, "")) / scale)
213
+ break
214
+ break
215
+ return tuple(kernel_times) if is_tupled else kernel_times[0]
216
+
217
+
218
+ def hash_tensor(t: torch.Tensor):
219
+ return t.view(torch.int64).sum().item()
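A hedged usage sketch for the helpers in this new file: quantize a tensor per token to FP8, cast it back, and score the round trip with `calc_diff`. The import path follows the file's location in this wheel (`sglang/test/test_deepep_utils.py`); float8 dtypes require a recent PyTorch.

```python
# Hedged usage sketch: per-token FP8 round trip scored with calc_diff.
# Assumes PyTorch with float8_e4m3fn support.
import torch

from sglang.test.test_deepep_utils import calc_diff, per_token_cast_back, per_token_cast_to_fp8

x = torch.randn(4, 256, dtype=torch.bfloat16)
x_fp8, scales = per_token_cast_to_fp8(x)
x_roundtrip = per_token_cast_back(x_fp8, scales)
print("relative diff:", calc_diff(x, x_roundtrip))
```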