sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -46,6 +46,10 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
 )

 logger = logging.getLogger(__name__)
@@ -207,8 +211,10 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.…
-        self.…
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
@@ -242,13 +248,13 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        …
-        self.…
-        …
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
@@ -299,18 +305,30 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None

-        if self.…
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.…
+                    self.max_num_token,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.…
-        …
-        …
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )

         # Capture
         try:
@@ -322,20 +340,23 @@ class CudaGraphRunner:
         )

     def can_run(self, forward_batch: ForwardBatch):
-        if self.…
-        …
-        …
-        …
-        …
-            if self.disable_padding
-            else total_global_tokens <= self.max_bs
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
         else:
-            …
-            …
-            …
-            …
-            …
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
@@ -400,7 +421,7 @@ class CudaGraphRunner:
                 empty_cache=False,
             )
             capture_range.set_description(
-                f"Capturing batches ({avail_mem=:.2f} GB)"
+                f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
             )

             with patch_model(
@@ -456,11 +477,11 @@ class CudaGraphRunner:
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )

-        if self.…
+        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
-                        num_tokens // self.dp_size + (i < …
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                         for i in range(self.dp_size)
                     ],
                     dtype=torch.int32,
@@ -469,6 +490,16 @@
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -604,15 +635,18 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.…
-        …
-        …
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(…
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()

         # Common inputs
@@ -624,7 +658,7 @@

         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(…
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if pp_proxy_tensors:
@@ -636,27 +670,28 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.…
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=…
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
-
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
-            self.encoder_lens,
-            …
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
+            self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )

         # Store fields
@@ -704,11 +739,7 @@ class CudaGraphRunner:
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=…
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                 positions=None,
                 retrive_index=None,
                 retrive_next_token=None,
sglang/srt/model_executor/forward_batch_info.py

@@ -320,17 +320,30 @@ class ForwardBatch:

         # For DP attention
         if batch.global_num_tokens is not None:
-            …
+
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                …
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)

-            ret.global_num_tokens_for_logprob_cpu = …
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                …
+                global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)

-            sum_len = sum(…
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
sglang/srt/model_executor/model_runner.py

@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,
@@ -70,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    PagedTokenToKVPoolAllocator,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -93,6 +97,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
+    dynamic_import,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()

 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -149,7 +155,7 @@ class ModelRunner:
         server_args: ServerArgs,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[…
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -162,6 +168,7 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -195,6 +202,7 @@ class ModelRunner:
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )

@@ -218,6 +226,7 @@ class ModelRunner:

     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
+
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -230,7 +239,7 @@ class ModelRunner:
             "SGLANG_LOG_EXPERT_LOCATION_METADATA"
         ):
             logger.info(
-                f"Initial expert_location_metadata: {get_global_expert_location_metadata()…
+                f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
             )

         set_global_expert_distribution_recorder(
@@ -272,6 +281,10 @@ class ModelRunner:
             self.apply_torch_tp()

         # Init lora
+        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
+        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
+        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
+        # load LoRA adapters dynamically later.
         if server_args.lora_paths is not None:
             self.init_lora_manager()

@@ -299,7 +312,7 @@ class ModelRunner:
         if (
             server_args.attention_backend == "intel_amx"
             and server_args.device == "cpu"
-            and not …
+            and not _is_cpu_amx_available
         ):
             logger.info(
                 "The current platform does not support Intel AMX, will fallback to torch_native backend."
@@ -534,6 +547,7 @@ class ModelRunner:
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
+            model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
@@ -543,7 +557,7 @@ class ModelRunner:
             monkey_patch_vllm_parallel_state()
             monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -761,6 +775,9 @@ class ModelRunner:
         ]
         if load_format == "direct":
             _model_load_weights_direct(self.model, named_tensors)
+        elif load_format in self.server_args.custom_weight_loader:
+            custom_loader = dynamic_import(load_format)
+            custom_loader(self.model, named_tensors)
         elif load_format is None:
             self.model.load_weights(named_tensors)
         else:
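The new `custom_weight_loader` path above resolves `load_format` to a callable via `dynamic_import` and invokes it as `loader(model, named_tensors)`. A minimal sketch of what such a callable could look like; the function name and copy strategy are illustrative, only the `(model, named_tensors)` call shape comes from the diff:

```python
from typing import Iterable, Tuple

import torch


def my_custom_weight_loader(
    model: torch.nn.Module, named_tensors: Iterable[Tuple[str, torch.Tensor]]
) -> None:
    # Copy each incoming tensor into the matching parameter of the live model.
    params = dict(model.named_parameters())
    for name, tensor in named_tensors:
        params[name].data.copy_(tensor)
```

It would then be referenced by its dotted import path (e.g. `mypkg.loaders.my_custom_weight_loader`) in the matching `custom_weight_loader` server argument added in this release.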
@@ -787,7 +804,6 @@ class ModelRunner:
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
-            lora_paths=self.server_args.lora_paths,
             base_hf_config=self.model_config.hf_config,
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
@@ -796,6 +812,7 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
+        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
         logger.info("LoRA manager ready.")

     def profile_max_num_token(self, total_gpu_memory: int):
@@ -849,7 +866,9 @@ class ModelRunner:
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
-            if …
+            if _is_hip:  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e4m3fnuz
+            else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
sglang/srt/model_loader/loader.py

@@ -2,6 +2,7 @@

 # ruff: noqa: SIM117
 import collections
+import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -11,14 +12,17 @@ import math
 import os
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast

 import huggingface_hub
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
+from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
+    _BAR_FORMAT,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
     get_quant_config,
     gguf_quant_weights_iterator,
     initialize_dummy_weights,
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator,
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""

+    # default number of thread when enable multithread weight loading
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):

     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        …
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
             raise ValueError(
-                f"…
-                f"…
+                f"Unexpected extra config keys for load format "
+                f"{load_config.load_format}: "
+                f"{unexpected_keys}"
             )

     def _maybe_download_from_modelscope(
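Combined with the ModelRunner change above that forwards `server_args.model_loader_extra_config` into `LoadConfig`, this validation means only the two keys below are accepted. A minimal sketch, assuming `LoadConfig` accepts the already-parsed dict:

```python
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.model_loader.loader import DefaultModelLoader

# Opt in to multithreaded weight loading; any key outside
# {"enable_multithread_load", "num_threads"} raises ValueError in __init__ above.
load_config = LoadConfig(
    model_loader_extra_config={"enable_multithread_load": True, "num_threads": 16},
)
loader = DefaultModelLoader(load_config)
```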
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt
         )
@@ -337,9 +353,35 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            …
+            from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+            weight_loader_disable_mmap = global_server_args_dict.get(
+                "weight_loader_disable_mmap"
+            )
+
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_safetensors_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                    disable_mmap=weight_loader_disable_mmap,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files, disable_mmap=weight_loader_disable_mmap
+                )
+
         else:
-            …
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(hf_weights_files)

         # Apply the prefix.
         return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
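The `multi_thread_*_weights_iterator` helpers selected above live in `weight_utils.py` (part of its +100 lines in this release). As a rough mental model only, not the actual implementation, the safetensors variant amounts to loading shards in a thread pool and yielding tensors as each shard finishes:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Generator, List, Tuple

import safetensors.torch
import torch


def threaded_safetensors_iterator(
    files: List[str], max_workers: int = 8
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    # Load whole .safetensors shards concurrently; yield (name, tensor) pairs
    # in shard order as each completes.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for state_dict in pool.map(safetensors.torch.load_file, files):
            yield from state_dict.items()
```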
@@ -378,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )

-            …
-            …
-            …
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )

         return model.eval()
