sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -132,6 +132,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )
 
+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
@@ -441,7 +444,13 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
         if model_runner.model_is_mrope:
-            ret._compute_mrope_positions(model_runner, batch)
+            if (
+                ret.spec_info is not None
+                and getattr(ret.spec_info, "positions", None) is not None
+            ):
+                ret._compute_spec_mrope_positions(model_runner, batch)
+            else:
+                ret._compute_mrope_positions(model_runner, batch)
 
         # Init lora information
         if model_runner.server_args.enable_lora:
@@ -507,6 +516,52 @@ class ForwardBatch:
             or self.contains_image_inputs()
         )
 
+    def _compute_spec_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        # TODO support batched deltas
+        batch_size = self.seq_lens.shape[0]
+        device = model_runner.device
+        mm_inputs = batch.multimodal_inputs
+
+        if batch.forward_mode.is_draft_extend():  # draft_extend_after_decode
+            mrope_deltas = []
+            extend_lens = []
+            for batch_idx in range(batch_size):
+                extend_seq_len = batch.extend_seq_lens[batch_idx]
+                extend_lens.append(extend_seq_len)
+                mrope_delta = (
+                    torch.zeros(1, dtype=torch.int64)
+                    if mm_inputs[batch_idx] is None
+                    else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
+                )
+                mrope_deltas.append(mrope_delta.to(device=device))
+            position_chunks = torch.split(batch.spec_info.positions, extend_lens)
+            mrope_positions_list = [
+                pos_chunk + delta
+                for pos_chunk, delta in zip(position_chunks, mrope_deltas)
+            ]
+            next_input_positions = (
+                torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
+            )
+
+        else:  # target_verify or draft_decode
+            seq_positions = batch.spec_info.positions.view(batch_size, -1)
+            mrope_deltas = [
+                (
+                    torch.tensor([0], dtype=torch.int64)
+                    if mm_inputs[i] is None
+                    else mm_inputs[i].mrope_position_delta.squeeze(0)
+                )
+                for i in range(batch_size)
+            ]
+            mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
+            next_input_positions = (
+                (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
+            )
+
+        self.mrope_positions = next_input_positions
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -516,24 +571,23 @@ class ForwardBatch:
         for batch_idx in range(batch_size):
             mm_input = batch.multimodal_inputs[batch_idx]
             if self.forward_mode.is_decode():
-                mrope_position_deltas = (
-                    [0]
-                    if mm_input is None
-                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
-                )
-                next_input_positions = []
-                for mrope_position_delta in mrope_position_deltas:
-                    # batched deltas needs to be processed separately
-                    # Convert list of lists to tensor with shape [3, seq_len]
-                    next_input_positions += [
-                        MRotaryEmbedding.get_next_input_positions(
-                            mrope_position_delta,
-                            int(self.seq_lens[batch_idx]) - 1,
-                            int(self.seq_lens[batch_idx]),
-                        )
-                    ]
                 # 3 * N
-
+                if mm_input is None:
+                    mrope_positions_list[batch_idx] = torch.full(
+                        (3, 1),
+                        self.seq_lens[batch_idx] - 1,
+                        dtype=torch.int64,
+                        device=model_runner.device,
+                    )
+                else:
+                    mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
+                        model_runner.device, non_blocking=True
+                    )
+                    mrope_positions_list[batch_idx] = (
+                        (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
+                        .unsqueeze(0)
+                        .repeat(3, 1)
+                    )
             elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
                     batch.extend_seq_lens[batch_idx],
sglang/srt/model_executor/model_runner.py

@@ -20,6 +20,7 @@ import json
 import logging
 import os
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -32,6 +33,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tp_group,
     get_world_group,
     init_distributed_environment,
@@ -83,11 +85,14 @@ from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
@@ -300,6 +305,26 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True
 
+        if self.is_hybrid_gdn:
+            logger.warning("Hybrid GDN model detected, disable radix cache")
+            self.server_args.disable_radix_cache = True
+            self.server_args.attention_backend = "hybrid_linear_attn"
+            if self.server_args.max_mamba_cache_size is None:
+                if self.server_args.max_running_requests is not None:
+                    self.server_args.max_mamba_cache_size = (
+                        self.server_args.max_running_requests
+                    )
+                else:
+                    self.server_args.max_mamba_cache_size = 512
+            self.server_args.max_mamba_cache_size = (
+                self.server_args.max_mamba_cache_size
+                // (
+                    self.server_args.dp_size
+                    if self.server_args.enable_dp_attention
+                    else 1
+                )
+            )
+
         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
         # determine the number of layers.
@@ -307,7 +332,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
@@ -338,6 +366,14 @@ class ModelRunner:
         if server_args.enable_lora:
             self.init_lora_manager()
 
+        # Init Double Sparsity
+        if server_args.enable_double_sparsity:
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -348,12 +384,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.
+            self.graph_mem_usage = 0
             self.init_attention_backend()
 
         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -503,11 +539,6 @@ class ModelRunner:
             )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
-            if server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
             if not self.is_multimodal_chunked_prefill_supported:
@@ -519,6 +550,17 @@ class ModelRunner:
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
+        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
+        # For more details, see: https://github.com/sgl-project/sglang/issues/8616
+        elif (
+            self.dp_size > 1
+            and is_sm100_supported()
+            and server_args.attention_backend != "triton"
+        ):
+            logger.info(
+                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
+            )
+            server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
@@ -590,6 +632,11 @@ class ModelRunner:
             # Set local size to hint SGLang to use shared memory based AllReduce
             os.environ["LOCAL_SIZE"] = str(self.tp_size)
             torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+            @torch.library.register_fake("sgl_kernel::shm_allgather")
+            def _(data, dim):
+                return torch.cat([data] * self.tp_size, dim=dim)
+
         else:
             logger.warning(
                 "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -622,6 +669,7 @@ class ModelRunner:
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
+        self.pp_group = get_pp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
@@ -1054,6 +1102,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
+        elif self.is_hybrid_gdn:
+            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1073,9 +1123,22 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
+        if self.is_hybrid_gdn:
+            rest_memory -= (
+                self.server_args.max_mamba_cache_size
+                * self.model_config.hf_config.mamba_cache_per_req
+                / (1 << 30)
+            )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
+    @property
+    def is_hybrid_gdn(self):
+        return self.model_config.hf_config.architectures[0] in [
+            "Qwen3NextForCausalLM",
+            "Qwen3NextForCausalLMMTP",
+        ]
+
     def set_num_token_hybrid(self):
         if (
             "Llama4ForConditionalGeneration"
@@ -1196,6 +1259,8 @@ class ModelRunner:
             ),
             4096,
         )
+        if self.is_hybrid_gdn:
+            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
 
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
@@ -1234,6 +1299,16 @@ class ModelRunner:
             // self.server_args.page_size
             * self.server_args.page_size
         )
+        # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens
+        if self.pp_size > 1:
+            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
+            torch.distributed.all_reduce(
+                tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=get_world_group().cpu_group,
+            )
+            self.max_total_num_tokens = tensor.item()
+
         # create token size for hybrid cache
         if self.is_hybrid:
             self.set_num_token_hybrid()
@@ -1264,6 +1339,28 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
+        elif self.is_hybrid_gdn:
+            config = self.model_config.hf_config
+            (
+                conv_state_shape,
+                temporal_state_shape,
+                conv_dtype,
+                ssm_dtype,
+                mamba_layers,
+            ) = config.hybrid_gdn_params
+            self.req_to_token_pool = HybridReqToTokenPool(
+                size=max_num_reqs,
+                max_context_len=self.model_config.context_len
+                + extra_max_context_len,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                conv_state_shape=conv_state_shape,
+                temporal_state_shape=temporal_state_shape,
+                conv_dtype=conv_dtype,
+                ssm_dtype=ssm_dtype,
+                mamba_layers=mamba_layers,
+                speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+            )
         else:
             self.req_to_token_pool = ReqToTokenPool(
                 size=max_num_reqs,
@@ -1346,6 +1443,23 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
+        elif self.is_hybrid_gdn:
+            self.token_to_kv_pool = HybridLinearKVPool(
+                size=self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(
+                    get_attention_tp_size()
+                ),
+                head_dim=self.model_config.head_dim,
+                # if draft worker, we only need 1 attention layer's kv pool
+                full_attention_layer_ids=(
+                    [0]
+                    if self.is_draft_worker
+                    else self.model_config.hf_config.full_attention_layer_ids
+                ),
+                enable_kvcache_transpose=False,
+                device=self.device,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1440,14 +1554,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
@@ -1581,6 +1693,24 @@ class ModelRunner:
             )
 
             return DualChunkFlashAttentionBackend(self)
+        elif backend_str == "hybrid_linear_attn":
+            assert (
+                self.is_hybrid_gdn
+            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+                HybridLinearAttnBackend,
+                MambaAttnBackend,
+            )
+
+            full_attn_backend = FlashAttentionBackend(self)
+            linear_attn_backend = MambaAttnBackend(self)
+            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
+            return HybridLinearAttnBackend(
+                full_attn_backend, linear_attn_backend, full_attn_layers
+            )
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
 
@@ -1602,38 +1732,46 @@ class ModelRunner:
         )
 
     def init_device_graphs(self):
-        """Capture
+        """Capture device graphs."""
         self.graph_runner = None
-        self.
+        self.graph_mem_usage = 0
 
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
             return
 
         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
-
-        CudaGraphRunner
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
         )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     def init_threads_binding(self):
         omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        cpu_ids_by_node = get_cpu_ids_by_node()
+        n_numa_node = len(cpu_ids_by_node)
         if omp_cpuids == "all":
-            cpu_ids_by_node = get_cpu_ids_by_node()
-            n_numa_node = len(cpu_ids_by_node)
-
             assert self.tp_size <= n_numa_node, (
                 f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                 f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
@@ -1650,11 +1788,22 @@ class ModelRunner:
             )
             self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
         else:
-
+            threads_bind_list = omp_cpuids.split("|")
+            assert self.tp_size == len(threads_bind_list), (
+                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
+                f"Please double check your settings."
+            )
+            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
+            if self.tp_size > n_numa_node:
+                logger.warning(
+                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
+                    f"in this case the available memory amount of each rank cannot be determined in prior. "
+                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
+                )
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
-        from sglang.srt.model_parallel import tensor_parallel
+        from sglang.srt.layers.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
@@ -1770,18 +1919,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-
-        forward_batch.forward_mode.
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret,
+            return ret, can_run_graph
 
         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1810,10 +1965,13 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
-        if forward_batch.global_num_tokens_cpu is not None:
+        if (
+            forward_batch.global_num_tokens_cpu is not None
+            and self.pp_group.is_last_rank
+        ):
             forward_batch.post_forward_mlp_sync_batch(ret)
 
-        return ret,
+        return ret, can_run_graph
 
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
sglang/srt/model_loader/__init__.py

@@ -1,16 +1,22 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from torch import nn
 
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import (
     get_architecture_class_name,
     get_model_architecture,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
 
 def get_model(
     *,