sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
 import torch
 from huggingface_hub import snapshot_download
 
-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
-from sglang.srt.layers.dp_attention import disable_dp_size
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import (
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
 def draft_tp_context(tp_group: GroupCoordinator):
     # Draft model doesn't use dp and has its own tp group.
     # We disable mscclpp now because it doesn't support 2 comm groups.
-    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+    with patch_tensor_parallel_group(tp_group):
         yield
 
 
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -166,6 +171,10 @@
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -213,7 +222,6 @@
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -229,7 +237,6 @@
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -240,8 +247,6 @@
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -302,17 +307,27 @@
         A tuple of the final logit output of the target model, next tokens accepted,
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
-        if batch.forward_mode.is_decode():
+        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
+                )
+            return logits_output, next_token_ids, bid, 0, False
+        else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
            )
 
-            # If it is None, it means all requests are finished
-            if batch.spec_info.verified_id is not None:
+            if self.check_forward_draft_extend_after_decode(batch):
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(batch)
+                    self.forward_draft_extend_after_decode(
+                        batch,
+                    )
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -320,22 +335,27 @@
                 sum(verify_output.accept_length_per_req_cpu),
                 can_run_cuda_graph,
             )
-        elif batch.forward_mode.is_idle():
-            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids, _ = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
 
-            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
-        else:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
-            )
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
-                self.forward_draft_extend(
-                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
-                )
-            return logits_output, next_token_ids, bid, 0, False
+    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        local_need_forward = (
+            batch.spec_info.verified_id is not None
+            and batch.spec_info.verified_id.shape[0] > 0
+        )
+        if not self.server_args.enable_dp_attention:
+            return local_need_forward
+
+        global_need_forward = torch.tensor(
+            [
+                (local_need_forward),
+            ],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            global_need_forward, group=get_tp_group().cpu_group
+        )
+        global_need_forward_cnt = global_need_forward[0].item()
+        need_forward = global_need_forward_cnt > 0
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -354,6 +374,7 @@
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -364,7 +385,7 @@
             model_worker_batch.seq_lens_cpu,
         )
 
-    def draft(self, batch: ScheduleBatch):
+    def _draft_preprocess_decode(self, batch: ScheduleBatch):
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -466,10 +487,33 @@
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
+    def _draft_preprocess_idle(self, batch: ScheduleBatch):
+        batch.spec_info = EagleDraftInput.create_idle_input(
+            device=self.device,
+            hidden_size=self.model_config.hidden_size,
+            dtype=self.model_config.dtype,
+            topk=self.topk,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
+        )
+
+    def draft(self, batch: ScheduleBatch):
+        # Parse args
+        if batch.forward_mode.is_idle():
+            self._draft_preprocess_idle(batch)
+        else:
+            self._draft_preprocess_decode(batch)
+
+        spec_info = batch.spec_info
+
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
 
         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.topk
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -481,12 +525,18 @@
                 forward_batch
             )
         else:
-            # Initialize attention backend
-            self.draft_attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                # Initialize attention backend
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
-        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+        if batch.forward_mode.is_idle():
+            return EagleVerifyInput.create_idle_input(
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+            )
 
         (
             tree_mask,
@@ -504,7 +554,7 @@
             batch.seq_lens_sum,
             self.topk,
             self.speculative_num_steps,
-            self.server_args.speculative_num_draft_tokens,
+            self.speculative_num_draft_tokens,
         )
 
         return EagleVerifyInput(
@@ -584,11 +634,16 @@
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        batch.forward_mode = (
+            ForwardMode.TARGET_VERIFY
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
@@ -646,7 +701,9 @@
             self.add_logprob_values(batch, res, logits_output)
 
         # Prepare the batch for the next draft forwards.
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
+        )
         batch.spec_info = res.draft_input
 
         return logits_output, res, model_worker_batch, can_run_cuda_graph
@@ -743,6 +800,7 @@
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -759,13 +817,33 @@
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
-
-        # Prepare metadata
-        batch.spec_info.prepare_extend_after_decode(
-            batch,
-            self.speculative_num_steps,
-        )
+        input_is_idle = batch.forward_mode.is_idle()
+        if not input_is_idle:
+            # Prepare metadata
+            if batch.spec_info.verified_id is not None:
+                batch.spec_info.prepare_extend_after_decode(
+                    batch,
+                    self.speculative_num_steps,
+                )
+        else:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -789,7 +867,10 @@
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
-            self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_model_runner.attn_backend.init_forward_metadata(
+                    forward_batch
+                )
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
@@ -799,7 +880,9 @@
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
+        )
         batch.seq_lens = seq_lens_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
sglang/srt/torch_memory_saver_adapter.py CHANGED
@@ -1,11 +1,13 @@
 import logging
+import threading
+import time
 from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 
 try:
     import torch_memory_saver
 
-    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+    _memory_saver = torch_memory_saver.torch_memory_saver
     import_error = None
 except ImportError as e:
     import_error = e
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError
 
-    def region(self):
+    def region(self, tag: str):
         raise NotImplementedError
 
-    def pause(self):
+    def pause(self, tag: str):
         raise NotImplementedError
 
-    def resume(self):
+    def resume(self, tag: str):
         raise NotImplementedError
 
     @property
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
 
 
 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    """Adapter for TorchMemorySaver with tag-based control"""
+
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()
 
-    def region(self):
-        return _primary_memory_saver.region()
+    def region(self, tag: str):
+        return _memory_saver.region(tag=tag)
 
-    def pause(self):
-        return _primary_memory_saver.pause()
+    def pause(self, tag: str):
+        return _memory_saver.pause(tag=tag)
 
-    def resume(self):
-        return _primary_memory_saver.resume()
+    def resume(self, tag: str):
+        return _memory_saver.resume(tag=tag)
 
     @property
     def enabled(self):
-        return _primary_memory_saver.enabled
+        return _memory_saver is not None and _memory_saver.enabled
 
 
 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield
 
     @contextmanager
-    def region(self):
+    def region(self, tag: str):
         yield
 
-    def pause(self):
+    def pause(self, tag: str):
         pass
 
-    def resume(self):
+    def resume(self, tag: str):
         pass
 
     @property
sglang/srt/utils.py CHANGED
@@ -160,7 +160,7 @@ def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()
 
 
-def is_cpu() -> bool:
+def is_host_cpu_x86() -> bool:
     machine = platform.machine().lower()
     return (
         machine in ("x86_64", "amd64", "i386", "i686")
@@ -169,6 +169,10 @@ def is_cpu() -> bool:
     )
 
 
+def is_cpu() -> bool:
+    return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
@@ -1291,6 +1295,15 @@ def get_hpu_memory_capacity():
     )
 
 
+def get_npu_memory_capacity():
+    try:
+        import torch_npu
+
+        return torch.npu.mem_get_info()[1] // 1024 // 1024  # unit: MB
+    except ImportError as e:
+        raise ImportError("torch_npu is required when run on npu device.")
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
@@ -1298,6 +1311,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_amdgpu_memory_capacity()
     elif device == "hpu":
         gpu_mem = get_hpu_memory_capacity()
+    elif device == "npu":
+        gpu_mem = get_npu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1423,6 +1438,11 @@ def get_device(device_id: Optional[int] = None) -> str:
             return "xpu"
         return "xpu:{}".format(device_id)
 
+    if hasattr(torch, "npu") and torch.npu.is_available():
+        if device_id == None:
+            return "npu"
+        return "npu:{}".format(device_id)
+
     if is_habana_available():
         try:
             import habana_frameworks.torch.hpu
@@ -1436,6 +1456,15 @@ def get_device(device_id: Optional[int] = None) -> str:
                 "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
             )
 
+    if is_cpu():
+        if cpu_has_amx_support():
+            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
+        else:
+            logger.warning(
+                "CPU device enabled, using torch native backend, low performance expected."
+            )
+        return "cpu"
+
     raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
 
 
@@ -1497,15 +1526,35 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor
 
 
+def get_npu_compiler_config():
+    config = {
+        "frozen_parameter": True,
+        "tiling_schedule_optimize": True,
+        "topology_sorting_strategy": "StableRDFS",
+    }
+    return config
+
+
 def get_compiler_backend() -> str:
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         return "hpu_backend"
 
     if hasattr(torch, "npu") and torch.npu.is_available():
-        import torchair
+        try:
+            import torchair
+            import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
+            from torchair.configs.compiler_config import CompilerConfig
+        except ImportError as e:
+            raise ImportError(
+                "NPU detected, but torchair package is not installed. "
+                "Please install torchair for torch.compile support on NPU."
+            )
+        compiler_config = CompilerConfig()
+        predefined_config = get_npu_compiler_config()
+        for k, v in predefined_config.items():
+            setattr(compiler_config.experimental_config, k, v)
 
-        config = torchair.CompilerConfig()
-        npu_backend = torchair.get_npu_backend(compiler_config=config)
+        npu_backend = torchair.get_npu_backend(compiler_config=compiler_config)
         return npu_backend
 
     return "inductor"
@@ -1868,13 +1917,6 @@ def configure_ipv6(dist_init_addr):
     return port, host
 
 
-def rank0_log(msg: str):
-    from sglang.srt.distributed import get_tensor_model_parallel_rank
-
-    if get_tensor_model_parallel_rank() == 0:
-        logger.info(msg)
-
-
 def rank0_print(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
@@ -1882,6 +1924,9 @@ def rank0_print(msg: str):
         print(msg, flush=True)
 
 
+rank0_log = rank0_print
+
+
 def get_cuda_version():
     if torch.version.cuda:
         return tuple(map(int, torch.version.cuda.split(".")))
@@ -2105,6 +2150,44 @@ def get_free_port():
         return s.getsockname()[1]
 
 
+def get_local_ip_auto() -> str:
+    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
+    return (
+        get_local_ip_by_nic(interface)
+        if interface is not None
+        else get_local_ip_by_remote()
+    )
+
+
+def get_local_ip_by_nic(interface: str) -> str:
+    try:
+        import netifaces
+    except ImportError as e:
+        raise ImportError(
+            "Environment variable SGLANG_LOCAL_IP_NIC requires package netifaces, please install it through 'pip install netifaces'"
+        ) from e
+
+    try:
+        addresses = netifaces.ifaddresses(interface)
+        if netifaces.AF_INET in addresses:
+            for addr_info in addresses[netifaces.AF_INET]:
+                ip = addr_info.get("addr")
+                if ip and ip != "127.0.0.1" and ip != "0.0.0.0":
+                    return ip
+        if netifaces.AF_INET6 in addresses:
+            for addr_info in addresses[netifaces.AF_INET6]:
+                ip = addr_info.get("addr")
+                if ip and not ip.startswith("fe80::") and ip != "::1":
+                    return ip.split("%")[0]
+    except (ValueError, OSError) as e:
+        raise ValueError(
+            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        )
+
+    # Fallback
+    return get_local_ip_by_remote()
+
+
 def get_local_ip_by_remote() -> str:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -2216,6 +2299,51 @@ class Withable(Generic[T]):
         self._value = None
 
 
+def require_mlp_tp_gather(server_args):
+    """
+    Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
+    """
+    if server_args.enable_dp_attention:
+        assert server_args.dp_size > 1, "dp_size must be greater than 1"
+        if (
+            server_args.moe_dense_tp_size is None
+        ):  # TODO(ch-wan): some MoE models do not have dense layers
+            return True
+        elif not server_args.enable_dp_lm_head:
+            return True
+        elif not server_args.enable_deepep_moe:
+            return True
+        else:
+            return (
+                server_args.moe_dense_tp_size
+                > server_args.tp_size // server_args.dp_size
+            )
+    else:
+        return False
+
+
+def require_attn_tp_gather(server_args):
+    """
+    Check if the input of attention is scattered.
+    """
+    assert server_args.moe_dense_tp_size in [1, None]
+    if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
+        if server_args.enable_dp_attention:
+            return server_args.dp_size < server_args.tp_size
+        else:
+            return True
+    else:
+        return False
+
+
+def require_gathered_buffer(server_args):
+    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)
+
+
+def require_mlp_sync(server_args):
+    return server_args.enable_dp_attention or require_gathered_buffer(server_args)
+
+
 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
     rhs: Optional[torch.Tensor],
@@ -2340,3 +2468,41 @@ class LazyValue:
         self._value = self._creator()
         self._creator = None
         return self._value
+
+
+def dynamic_import(func_path: str):
+    parts = func_path.split(".")
+    if len(parts) < 2:
+        raise ValueError(
+            "func_path should contain both module name and func name (such as 'module.func')"
+        )
+    module_path = ".".join(parts[:-1])
+    func_name = parts[-1]
+    module = importlib.import_module(module_path)
+    func = getattr(module, func_name)
+    return func
+
+
+def configure_gc_logger():
+    logger.info("Enable GC Logger")
+
+    import gc
+
+    gc_start_time = {}
+
+    def gc_callback(phase, info):
+        gen = info.get("generation", "?")
+        if phase == "start":
+            gc_start_time[gen] = time.time()
+            logger.info(f"GC start: Time {time.time()} | Generation {gen}")
+        elif phase == "stop":
+            duration = time.time() - gc_start_time.get(gen, time.time())
+            collected = info.get("collected", "?")
+            uncollectable = info.get("uncollectable", "?")
+            logger.info(
+                f"GC end: Time {time.time()} | Generation {gen} | "
+                f"Duration: {duration:.4f}s | Collected: {collected} | Uncollectable: {uncollectable} "
+                f'{"(LONG GC)" if duration > 0.1 else ""}'
+            )
+
+    gc.callbacks.append(gc_callback)
sglang/test/test_block_fp8_ep.py CHANGED
@@ -182,6 +182,7 @@ def ep_moe(
         end_expert_id,
         top_k,
         hidden_states.size(1),
+        0,
         BLOCK_SIZE=512,
     )
     return output
sglang/test/test_utils.py CHANGED
@@ -37,6 +37,7 @@ from sglang.utils import get_exception_traceback
 # General test models
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.7.post1"
+__version__ = "0.4.8"