sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
@@ -46,6 +46,10 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
 )
 
 logger = logging.getLogger(__name__)
@@ -207,8 +211,10 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
@@ -242,13 +248,13 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        if global_server_args_dict["attention_backend"] == "flashmla":
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        else:
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
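With this change the runner no longer special-cases flashmla: every attention backend's init_cuda_graph_state receives both the maximum batch size and the maximum token count and can size its own buffers. A minimal sketch of a backend conforming to the new two-argument call; the class and buffer names are illustrative, not the actual sglang backends:

import torch

class SketchAttnBackend:
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # Per-request metadata is sized by batch slots, per-token buffers by tokens.
        self.graph_seq_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        self.graph_cache_loc = torch.zeros(
            max_num_tokens, dtype=torch.int64, device="cuda"
        )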
@@ -299,18 +305,30 @@
         else:
             self.encoder_lens = None
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
+                    self.max_num_token,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+
+            self.custom_mask = torch.ones(
+                (
+                    (self.seq_lens.sum().item() + self.max_num_token)
+                    * self.num_tokens_per_bs
+                ),
+                dtype=torch.bool,
+                device="cuda",
+            )
 
         # Capture
         try:
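The gathered buffer is now sized by max_num_token instead of max_bs * dp_size * num_tokens_per_bs, and a single boolean custom_mask is allocated once and shared by later captures (the @@ -704 hunk further down drops the per-capture torch.ones allocation). A rough size calculation with assumed numbers, including the assumption that self.seq_lens is pre-filled with a seq_len fill value of 1:

# Illustrative sizing of the shared custom mask (all numbers are assumptions).
max_bs = 8                  # largest captured batch size
num_tokens_per_bs = 4       # e.g. EAGLE drafting 4 tokens per request
max_num_token = max_bs * num_tokens_per_bs        # 32
seq_lens_sum = max_bs * 1   # seq_lens pre-filled with the fill value (assume 1)
mask_elems = (seq_lens_sum + max_num_token) * num_tokens_per_bs  # (8 + 32) * 4 = 160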
@@ -322,20 +340,23 @@
         )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
-
-            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_global_tokens in self.graphs
-                if self.disable_padding
-                else total_global_tokens <= self.max_bs
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
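can_run now derives a single cuda_graph_bs and checks it against the captured sizes, dividing the gathered token count back down by num_tokens_per_bs when EAGLE is active. A worked example with assumed numbers:

# Two DP ranks report their token counts; EAGLE drafts 4 tokens per request.
global_num_tokens_cpu = [8, 12]
num_tokens_per_bs = 4
cuda_graph_bs = sum(global_num_tokens_cpu) // num_tokens_per_bs  # 20 // 4 = 5
# Without speculative decoding, cuda_graph_bs would simply be sum(...) = 20.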
@@ -400,7 +421,7 @@ class CudaGraphRunner:
                 empty_cache=False,
             )
             capture_range.set_description(
-                f"Capturing batches ({avail_mem=:.2f} GB)"
+                f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
             )
 
             with patch_model(
@@ -456,11 +477,11 @@
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [
-                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                        for i in range(self.dp_size)
                    ],
                    dtype=torch.int32,
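During capture, the per-rank token split now takes its remainder from num_tokens instead of bs, so the shares always sum to the number of tokens being captured. A small check with assumed numbers:

num_tokens, dp_size = 10, 4
split = [num_tokens // dp_size + (i < (num_tokens % dp_size)) for i in range(dp_size)]
assert split == [3, 3, 2, 2] and sum(split) == num_tokens
# The previous `i < bs % dp_size` term could break this invariant whenever bs and
# num_tokens had different remainders modulo dp_size, e.g. with several draft
# tokens per request.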
@@ -469,6 +490,16 @@ class CudaGraphRunner:
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -604,15 +635,18 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            index = bisect.bisect_left(
-                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()
 
         # Common inputs
@@ -624,7 +658,7 @@ class CudaGraphRunner:
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if pp_proxy_tensors:
@@ -636,27 +670,28 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=forward_batch.forward_mode,
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
-
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
-            self.encoder_lens,
-            forward_batch.forward_mode,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
+            self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
 
         # Store fields
@@ -704,11 +739,7 @@ class CudaGraphRunner:
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=torch.ones(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                 positions=None,
                 retrive_index=None,
                 retrive_next_token=None,
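Reusing the shared mask removes a sizeable per-capture allocation; previously every captured graph materialised its own num_tokens * context_len boolean tensor. Illustrative arithmetic with assumed numbers:

num_tokens, context_len = 32, 8192
old_mask_elems = num_tokens * context_len  # 262_144 bools allocated per captured graph
# With this change, all captures share the single self.custom_mask created in __init__.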
sglang/srt/model_executor/forward_batch_info.py
@@ -320,17 +320,30 @@ class ForwardBatch:
 
         # For DP attention
         if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
            ).to(device, non_blocking=True)
 
-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
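ForwardBatch now scales the per-rank token counts by the number of draft tokens before sizing the DP-attention gathered buffer, so speculative batches reserve enough rows. A worked example with assumed numbers:

spec_num_draft_tokens = 4
batch_global_num_tokens = [8, 6]  # requests reported by two DP ranks
global_num_tokens = [x * spec_num_draft_tokens for x in batch_global_num_tokens]
sum_len = sum(global_num_tokens)  # [32, 24] -> gathered_buffer gets 56 rows, not 14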
sglang/srt/model_executor/model_runner.py
@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,
@@ -70,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    PagedTokenToKVPoolAllocator,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -93,6 +97,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
+    dynamic_import,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
 
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -149,7 +155,7 @@ class ModelRunner:
         server_args: ServerArgs,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -162,6 +168,7 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -195,6 +202,7 @@ class ModelRunner:
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )
 
@@ -218,6 +226,7 @@ class ModelRunner:
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
+
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -230,7 +239,7 @@ class ModelRunner:
             "SGLANG_LOG_EXPERT_LOCATION_METADATA"
         ):
             logger.info(
-                f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
             )
 
         set_global_expert_distribution_recorder(
@@ -272,6 +281,10 @@ class ModelRunner:
             self.apply_torch_tp()
 
         # Init lora
+        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
+        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
+        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
+        # load LoRA adapters dynamically later.
         if server_args.lora_paths is not None:
             self.init_lora_manager()
 
@@ -299,7 +312,7 @@ class ModelRunner:
         if (
             server_args.attention_backend == "intel_amx"
             and server_args.device == "cpu"
-            and not cpu_has_amx_support()
+            and not _is_cpu_amx_available
         ):
             logger.info(
                 "The current platform does not support Intel AMX, will fallback to torch_native backend."
@@ -534,6 +547,7 @@ class ModelRunner:
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
+            model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
@@ -543,7 +557,7 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -761,6 +775,9 @@ class ModelRunner:
         ]
         if load_format == "direct":
             _model_load_weights_direct(self.model, named_tensors)
+        elif load_format in self.server_args.custom_weight_loader:
+            custom_loader = dynamic_import(load_format)
+            custom_loader(self.model, named_tensors)
         elif load_format is None:
             self.model.load_weights(named_tensors)
         else:
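The tensor weight-update path can now dispatch to a user-supplied loader when load_format matches one of the dotted paths configured in server_args.custom_weight_loader and resolved with dynamic_import. A minimal sketch of such a loader; the module and function names are hypothetical, and the exact path format accepted by dynamic_import is an assumption:

# my_loaders.py -- hypothetical module registered as a custom weight loader.
import torch

def load_as_bf16(model: torch.nn.Module, named_tensors):
    # named_tensors is the iterable of (name, tensor) pairs handed over by ModelRunner.
    model.load_weights(
        (name, tensor.to(torch.bfloat16)) for name, tensor in named_tensors
    )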
@@ -787,7 +804,6 @@ class ModelRunner:
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
-            lora_paths=self.server_args.lora_paths,
             base_hf_config=self.model_config.hf_config,
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
@@ -796,6 +812,7 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
+        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
         logger.info("LoRA manager ready.")
 
     def profile_max_num_token(self, total_gpu_memory: int):
@@ -849,7 +866,9 @@ class ModelRunner:
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
-            if is_cuda():
+            if _is_hip:  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e4m3fnuz
+            else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
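In short, the fp8_e4m3 setting of the kv_cache_dtype server arg now maps to the ROCm-native fnuz variant on HIP and to e4m3fn everywhere else. A condensed sketch of that selection, not the full method:

import torch

def fp8_e4m3_kv_dtype(is_hip: bool) -> torch.dtype:
    # ROCm exposes float8_e4m3fnuz natively; CUDA and other platforms use float8_e4m3fn.
    return torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn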
sglang/srt/model_loader/loader.py
@@ -2,6 +2,7 @@
 
 # ruff: noqa: SIM117
 import collections
+import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -11,14 +12,17 @@ import math
 import os
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import huggingface_hub
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
+from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
+    _BAR_FORMAT,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
     get_quant_config,
     gguf_quant_weights_iterator,
     initialize_dummy_weights,
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator,
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    # default number of thread when enable multithread weight loading
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
 
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
             raise ValueError(
-                f"Model loader extra config is not supported for "
-                f"load format {load_config.load_format}"
+                f"Unexpected extra config keys for load format "
+                f"{load_config.load_format}: "
+                f"{unexpected_keys}"
             )
 
     def _maybe_download_from_modelscope(
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt
         )
@@ -337,9 +353,35 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+            weight_loader_disable_mmap = global_server_args_dict.get(
+                "weight_loader_disable_mmap"
+            )
+
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_safetensors_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                    disable_mmap=weight_loader_disable_mmap,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files, disable_mmap=weight_loader_disable_mmap
+                )
+
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(hf_weights_files)
 
         # Apply the prefix.
         return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
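The weights iterator now honours two keys in model_loader_extra_config, enable_multithread_load and num_threads (validated in __init__ above, defaulting to DEFAULT_NUM_THREADS = 8 workers). A minimal sketch of passing that config through LoadConfig; exposing the same dict as a JSON string on a --model-loader-extra-config server flag is assumed to follow the usual server-args conversion:

from sglang.srt.configs.load_config import LoadConfig

# Opt in to multithreaded safetensors / pt loading with 16 worker threads.
load_config = LoadConfig(
    model_loader_extra_config={"enable_multithread_load": True, "num_threads": 16},
)
# Any other key now raises the "Unexpected extra config keys ..." ValueError above.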
@@ -378,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-            self.load_weights_and_postprocess(
-                model, self._get_all_weights(model_config, model), target_device
-            )
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
         return model.eval()
 