sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
```diff
@@ -43,6 +43,7 @@ I'm going to the park
 """
 
 import argparse
+import copy
 import dataclasses
 import itertools
 import json
@@ -84,12 +85,14 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    prompt_filename: str = ""
     result_filename: str = "result.jsonl"
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
     log_decode_step: int = 0
     profile: bool = False
+    profile_record_shapes: bool = False
     profile_filename_prefix: str = "profile"
 
     @staticmethod
@@ -104,6 +107,9 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument(
+            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
@@ -118,6 +124,11 @@ class BenchArgs:
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
+        parser.add_argument(
+            "--profile-record-shapes",
+            action="store_true",
+            help="Record tensor shapes in profiling results.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs_for_correctness_test(bench_args, tokenizer):
-    prompts = [
-        "The capital of France is",
-        "The capital of the United Kindom is",
-        "Today is a sunny day and I like",
-    ]
+def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
+    prompts = (
+        custom_prompts
+        if custom_prompts
+        else [
+            "The capital of France is",
+            "The capital of the United Kindom is",
+            "Today is a sunny day and I like",
+        ]
+    )
     input_ids = [tokenizer.encode(p) for p in prompts]
     sampling_params = SamplingParams(
         temperature=0,
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
     return reqs
 
 
-def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
-    input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(
+    batch_size, input_len, custom_inputs=None
+):
+    input_ids = (
+        custom_inputs
+        if custom_inputs
+        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+    )
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
     )
 
 
+def _read_prompts_from_file(prompt_file, rank_print):
+    """Read custom prompts from the file specified by `--prompt-filename`."""
+    if not prompt_file:
+        return []
+    if not os.path.exists(prompt_file):
+        rank_print(
+            f"Custom prompt file {prompt_file} not found. Using default inputs..."
+        )
+        return []
+    with open(prompt_file, "r") as pf:
+        return pf.readlines()
+
+
+def _save_profile_trace_results(profiler, filename):
+    parent_dir = os.path.dirname(os.path.abspath(filename))
+    os.makedirs(parent_dir, exist_ok=True)
+    profiler.export_chrome_trace(filename)
+    print(
+        profiler.key_averages(group_by_input_shape=True).table(
+            sort_by="self_cpu_time_total"
+        )
+    )
+
+
 def correctness_test(
     server_args,
     port_args,
@@ -298,7 +343,10 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    input_ids, reqs = prepare_inputs_for_correctness_test(
+        bench_args, tokenizer, custom_prompts
+    )
     rank_print(f"\n{input_ids=}\n")
 
     if bench_args.cut_len > 0:
@@ -344,6 +392,7 @@ def latency_test_run_once(
     device,
     log_decode_step,
     profile,
+    profile_record_shapes,
    profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +423,7 @@
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
+            record_shapes=profile_record_shapes,
         )
         profiler.start()
 
@@ -391,10 +441,30 @@
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput
 
+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(
+            f"torch profiler chrome trace for prefill saved to {profile_filename}"
+        )
+
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
+        if profile and i == output_len / 2:
+            profiler = None
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+
         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
@@ -407,13 +477,13 @@
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
-    if profile:
-        profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
-        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
-        os.makedirs(parent_dir, exist_ok=True)
-        profiler.export_chrome_trace(profile_filename)
-        rank_print(f"torch profiler chrome trace saved to {profile_filename}")
+        if profile and i == output_len / 2:
+            profiler.stop()
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
+            rank_print(
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+            )
 
     # Record decode timing from 2nd output
     if output_len > 1:
@@ -469,17 +539,42 @@ def latency_test(
         server_args.device,
         log_decode_step=0,
         profile=False,
+        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
     )
 
     rank_print("Benchmark ...")
 
+    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
+    custom_input_len = len(custom_inputs)
+
     # Run the sweep
     result_list = []
     for bs, il, ol in itertools.product(
         bench_args.batch_size, bench_args.input_len, bench_args.output_len
     ):
-        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        bs_aligned_inputs = []
+        if custom_inputs:
+            if custom_input_len == bs:
+                bs_aligned_inputs = custom_inputs
+            elif custom_input_len > bs:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
+                    f"Using the first {bs} prompts."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
+            else:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
+                    f"Pad to the desired batch_size with the last prompt."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs)
+                bs_aligned_inputs.extend(
+                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
+                )
+
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
         ret = latency_test_run_once(
             bench_args.run_name,
             model_runner,
@@ -491,6 +586,7 @@ def latency_test(
             server_args.device,
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
+            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
         )
         if ret is not None:
```
sglang/compile_deep_gemm.py
CHANGED
```diff
@@ -17,6 +17,7 @@ import time
 
 import requests
 
+from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -52,7 +53,9 @@ class CompileArgs:
 
 
 @warmup("compile-deep-gemm")
-async def warm_up_compile(tokenizer_manager: TokenizerManager):
+async def warm_up_compile(
+    disaggregation_mode: str, tokenizer_manager: TokenizerManager
+):
     print("\nGenerate warm up request for compiling DeepGEMM...\n")
     generate_req_input = GenerateReqInput(
         input_ids=[0, 1, 2, 3],
@@ -62,6 +65,10 @@ async def warm_up_compile(tokenizer_manager: TokenizerManager):
             "ignore_eos": True,
         },
     )
+    if disaggregation_mode != "null":
+        generate_req_input.bootstrap_room = 0
+        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
+
     await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
 
 
```
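In prefill/decode-disaggregated deployments the warm-up request now carries fake bootstrap metadata so it can be issued without a real bootstrap peer. A hedged sketch of building such a request, using only the fields that appear in the diff (the real file passes additional sampling parameters that are omitted here):

```python
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.managers.io_struct import GenerateReqInput


def build_warmup_request(disaggregation_mode: str) -> GenerateReqInput:
    req = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params={"ignore_eos": True},
    )
    if disaggregation_mode != "null":
        # Fake bootstrap info lets the warm-up run in PD-disaggregated mode.
        req.bootstrap_room = 0
        req.bootstrap_host = FAKE_BOOTSTRAP_HOST
    return req
```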
sglang/global_config.py
CHANGED
```diff
@@ -30,7 +30,11 @@ class GlobalConfig:
         self.default_new_token_ratio_decay_steps = float(
             os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
         )
-
+        self.torch_empty_cache_interval = float(
+            os.environ.get(
+                "SGLANG_EMPTY_CACHE_INTERVAL", -1
+            )  # in seconds. Set if you observe high memory accumulation over a long serving period.
+        )
         # Runtime constants: others
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
```
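The new `SGLANG_EMPTY_CACHE_INTERVAL` knob is read once at startup and defaults to -1 (disabled). A small sketch of the parsing convention, under the assumption (not shown in this diff) that a positive value enables periodic `torch.cuda.empty_cache()` calls during serving:

```python
import os

# Same parsing convention as GlobalConfig: the value is in seconds, -1 disables the feature.
interval = float(os.environ.get("SGLANG_EMPTY_CACHE_INTERVAL", -1))
if interval > 0:
    print(f"empty_cache would be triggered roughly every {interval:.0f} s")
else:
    print("periodic empty_cache is disabled (default)")
```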
sglang/srt/configs/model_config.py
CHANGED
```diff
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
     get_context_length,
     get_generation_config,
     get_hf_text_config,
+    get_sparse_attention_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
@@ -133,6 +134,11 @@ class ModelConfig:
 
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -270,6 +276,9 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()
 
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
 
@@ -297,6 +306,13 @@ class ModelConfig:
             **kwargs,
         )
 
+    def get_total_num_attention_heads(self) -> int:
+        return self.num_attention_heads
+
+    def get_num_attention_heads(self, tensor_parallel_size) -> int:
+        total_num_attention_heads = self.num_attention_heads
+        return max(1, total_num_attention_heads // tensor_parallel_size)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -401,6 +417,8 @@ class ModelConfig:
             "fbgemm_fp8",
             "w8a8_fp8",
             "petit_nvfp4",
+            "quark",
+            "mxfp4",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -482,6 +500,23 @@ class ModelConfig:
                 self.quantization,
             )
 
+    def _verify_dual_chunk_attention_config(self) -> None:
+        if hasattr(self.hf_config, "dual_chunk_attention_config"):
+            # Try loading the sparse attention config
+            sparse_attn_config = get_sparse_attention_config(self.model_path)
+            if not sparse_attn_config:
+                return
+            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+                sparse_attn_config
+            )
+            if (
+                "sparse_attention_enabled"
+                not in self.hf_config.dual_chunk_attention_config
+            ):
+                self.hf_config.dual_chunk_attention_config[
+                    "sparse_attention_enabled"
+                ] = True
+
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
```
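The two new head-count helpers follow the usual tensor-parallel split: the total attention head count is floor-divided by the TP size and never drops below one. A quick standalone check of that arithmetic (head counts are illustrative):

```python
def num_attention_heads_per_rank(total_num_attention_heads: int, tensor_parallel_size: int) -> int:
    # Mirrors ModelConfig.get_num_attention_heads: floor division, clamped to at least one head.
    return max(1, total_num_attention_heads // tensor_parallel_size)


assert num_attention_heads_per_rank(32, 8) == 4
assert num_attention_heads_per_rank(2, 8) == 1  # small models still report one head per rank
```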
sglang/srt/conversation.py
CHANGED
```diff
@@ -30,8 +30,10 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
+from typing_extensions import Literal
+
 from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
-from sglang.srt.utils import read_system_prompt_from_file
+from sglang.srt.utils import ImageData, read_system_prompt_from_file
 
 
 class SeparatorStyle(IntEnum):
@@ -91,7 +93,7 @@ class Conversation:
     video_token: str = "<video>"
     audio_token: str = "<audio>"
 
-    image_data: Optional[List[str]] = None
+    image_data: Optional[List[ImageData]] = None
     video_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
     stop_token_ids: Optional[int] = None
@@ -381,9 +383,9 @@ class Conversation:
         """Append a new message."""
         self.messages.append([role, message])
 
-    def append_image(self, image: str):
+    def append_image(self, image: str, detail: Literal["auto", "low", "high"]):
         """Append a new image."""
-        self.image_data.append(image)
+        self.image_data.append(ImageData(url=image, detail=detail))
 
     def append_video(self, video: str):
         """Append a new video."""
@@ -627,7 +629,9 @@ def generate_chat_conv(
                             real_content = image_token + real_content
                         else:
                             real_content += image_token
-                        conv.append_image(content.image_url.url)
+                        conv.append_image(
+                            content.image_url.url, content.image_url.detail
+                        )
                     elif content.type == "video_url":
                         real_content += video_token
                         conv.append_video(content.video_url.url)
@@ -954,20 +958,6 @@ register_conv_template(
     )
 )
 
-register_conv_template(
-    Conversation(
-        name="mimo-vl",
-        system_message="You are MiMo, an AI assistant developed by Xiaomi.",
-        system_template="<|im_start|>system\n{system_message}",
-        roles=("<|im_start|>user", "<|im_start|>assistant"),
-        sep="<|im_end|>\n",
-        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
-        stop_str=["<|im_end|>"],
-        image_token="<|vision_start|><|image_pad|><|vision_end|>",
-    )
-)
-
-
 register_conv_template(
     Conversation(
         name="qwen2-audio",
@@ -981,51 +971,11 @@ register_conv_template(
     )
 )
 
-register_conv_template(
-    Conversation(
-        name="llama_4_vision",
-        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
-        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
-        roles=("user", "assistant"),
-        sep_style=SeparatorStyle.LLAMA4,
-        sep="",
-        stop_str="<|eot|>",
-        image_token="<|image|>",
-    )
-)
-
-register_conv_template(
-    Conversation(
-        name="step3-vl",
-        system_message="<|begin▁of▁sentence|>You are a helpful assistant",
-        system_template="{system_message}\n",
-        roles=(
-            "<|BOT|>user\n",
-            "<|BOT|>assistant\n<think>\n",
-        ),
-        sep="<|EOT|>",
-        sep_style=SeparatorStyle.NO_COLON_SINGLE,
-        stop_str="<|EOT|>",
-        image_token="<im_patch>",
-        # add_bos=True,
-    )
-)
-
 
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl", model_path, re.IGNORECASE):
         return "internvl-2-5"
-    if re.search(r"intern.*s1", model_path, re.IGNORECASE):
-        return "interns1"
-
-
-@register_conv_template_matching_function
-def match_llama_vision(model_path: str):
-    if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
-        return "llama_3_vision"
-    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
-        return "llama_4_vision"
 
 
 @register_conv_template_matching_function
@@ -1040,22 +990,6 @@ def match_vicuna(model_path: str):
     return "vicuna_v1.1"
 
 
-@register_conv_template_matching_function
-def match_llama2_chat(model_path: str):
-    if re.search(
-        r"llama-2.*chat|codellama.*instruct",
-        model_path,
-        re.IGNORECASE,
-    ):
-        return "llama-2"
-
-
-@register_conv_template_matching_function
-def match_mistral(model_path: str):
-    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
-        return "mistral"
-
-
 @register_conv_template_matching_function
 def match_deepseek_vl(model_path: str):
     if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
@@ -1064,12 +998,6 @@ def match_deepseek_vl(model_path: str):
 
 @register_conv_template_matching_function
 def match_qwen_chat_ml(model_path: str):
-    if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
-        return "gme-qwen2-vl"
-    if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
-        return "qwen2-vl"
-    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
-        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,
@@ -1078,12 +1006,6 @@ def match_qwen_chat_ml(model_path: str):
         return "chatml-llava"
 
 
-@register_conv_template_matching_function
-def match_gemma3_instruct(model_path: str):
-    if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
-        return "gemma-it"
-
-
 @register_conv_template_matching_function
 def match_openbmb_minicpm(model_path: str):
     if re.search(r"minicpm-v", model_path, re.IGNORECASE):
@@ -1092,37 +1014,7 @@ def match_openbmb_minicpm(model_path: str):
         return "minicpmo"
 
 
-@register_conv_template_matching_function
-def match_moonshot_kimivl(model_path: str):
-    if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
-        return "kimi-vl"
-
-
-@register_conv_template_matching_function
-def match_devstral(model_path: str):
-    if re.search(r"devstral", model_path, re.IGNORECASE):
-        return "devstral"
-
-
 @register_conv_template_matching_function
 def match_phi_4_mm(model_path: str):
     if "phi-4-multimodal" in model_path.lower():
         return "phi-4-mm"
-
-
-@register_conv_template_matching_function
-def match_vila(model_path: str):
-    if re.search(r"vila", model_path, re.IGNORECASE):
-        return "chatml"
-
-
-@register_conv_template_matching_function
-def match_mimo_vl(model_path: str):
-    if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
-        return "mimo-vl"
-
-
-# @register_conv_template_matching_function
-# def match_step3(model_path: str):
-#     if re.search(r"step3", model_path, re.IGNORECASE):
-#         return "step3-vl"
```
sglang/srt/disaggregation/base/conn.py
CHANGED
```diff
@@ -25,10 +25,13 @@ class KVArgs:
     gpu_id: int
     # for different tp
     decode_tp_size: int
-    # for pp prefill
-    prefill_pp_size: int
     kv_head_num: int
     page_size: int
+    # for pp prefill
+    prefill_pp_size: int
+    pp_rank: int
+    # for system dp
+    system_dp_rank: int
 
 
 class KVPoll:
```
sglang/srt/disaggregation/decode.py
CHANGED
```diff
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -184,9 +185,13 @@ class DecodePreallocQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
 
-        attn_tp_size =
+        attn_tp_size = get_attention_tp_size()
         kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+
         kv_args.decode_tp_size = attn_tp_size
+        # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
+        kv_args.pp_rank = 0
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
```
sglang/srt/disaggregation/decode_schedule_batch_mixin.py
CHANGED
```diff
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.tensor(
+            seq_lens, dtype=torch.int32, device=self.device
+        )
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
 
@@ -88,6 +91,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
+        self.multimodal_inputs = [r.multimodal_inputs for r in reqs]
 
         # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
```