sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -43,6 +43,7 @@ I'm going to the park
 """
 
 import argparse
+import copy
 import dataclasses
 import itertools
 import json
@@ -84,12 +85,14 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    prompt_filename: str = ""
     result_filename: str = "result.jsonl"
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
     log_decode_step: int = 0
     profile: bool = False
+    profile_record_shapes: bool = False
     profile_filename_prefix: str = "profile"
 
     @staticmethod
@@ -104,6 +107,9 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument(
+            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
@@ -118,6 +124,11 @@ class BenchArgs:
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
+        parser.add_argument(
+            "--profile-record-shapes",
+            action="store_true",
+            help="Record tensor shapes in profiling results.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs_for_correctness_test(bench_args, tokenizer):
-    prompts = [
-        "The capital of France is",
-        "The capital of the United Kindom is",
-        "Today is a sunny day and I like",
-    ]
+def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
+    prompts = (
+        custom_prompts
+        if custom_prompts
+        else [
+            "The capital of France is",
+            "The capital of the United Kindom is",
+            "Today is a sunny day and I like",
+        ]
+    )
     input_ids = [tokenizer.encode(p) for p in prompts]
     sampling_params = SamplingParams(
         temperature=0,
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
     return reqs
 
 
-def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
-    input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(
+    batch_size, input_len, custom_inputs=None
+):
+    input_ids = (
+        custom_inputs
+        if custom_inputs
+        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+    )
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
     )
 
 
+def _read_prompts_from_file(prompt_file, rank_print):
+    """Read custom prompts from the file specified by `--prompt-filename`."""
+    if not prompt_file:
+        return []
+    if not os.path.exists(prompt_file):
+        rank_print(
+            f"Custom prompt file {prompt_file} not found. Using default inputs..."
+        )
+        return []
+    with open(prompt_file, "r") as pf:
+        return pf.readlines()
+
+
+def _save_profile_trace_results(profiler, filename):
+    parent_dir = os.path.dirname(os.path.abspath(filename))
+    os.makedirs(parent_dir, exist_ok=True)
+    profiler.export_chrome_trace(filename)
+    print(
+        profiler.key_averages(group_by_input_shape=True).table(
+            sort_by="self_cpu_time_total"
+        )
+    )
+
+
 def correctness_test(
     server_args,
     port_args,
@@ -298,7 +343,10 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    input_ids, reqs = prepare_inputs_for_correctness_test(
+        bench_args, tokenizer, custom_prompts
+    )
     rank_print(f"\n{input_ids=}\n")
 
     if bench_args.cut_len > 0:
@@ -344,6 +392,7 @@ def latency_test_run_once(
     device,
     log_decode_step,
     profile,
+    profile_record_shapes,
     profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +423,7 @@ def latency_test_run_once(
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
+            record_shapes=profile_record_shapes,
         )
         profiler.start()
 
@@ -391,10 +441,30 @@ def latency_test_run_once(
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput
 
+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(
+            f"torch profiler chrome trace for prefill saved to {profile_filename}"
+        )
+
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
+        if profile and i == output_len / 2:
+            profiler = None
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+
         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
@@ -407,13 +477,13 @@ def latency_test_run_once(
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
-    if profile:
-        profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
-        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
-        os.makedirs(parent_dir, exist_ok=True)
-        profiler.export_chrome_trace(profile_filename)
-        rank_print(f"torch profiler chrome trace saved to {profile_filename}")
+        if profile and i == output_len / 2:
+            profiler.stop()
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
+            rank_print(
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+            )
 
     # Record decode timing from 2nd output
     if output_len > 1:
@@ -469,17 +539,42 @@ def latency_test(
         server_args.device,
         log_decode_step=0,
         profile=False,
+        profile_record_shapes=False,
         profile_filename_prefix="",  # not used
     )
 
     rank_print("Benchmark ...")
 
+    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
+    custom_input_len = len(custom_inputs)
+
    # Run the sweep
     result_list = []
     for bs, il, ol in itertools.product(
         bench_args.batch_size, bench_args.input_len, bench_args.output_len
     ):
-        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        bs_aligned_inputs = []
+        if custom_inputs:
+            if custom_input_len == bs:
+                bs_aligned_inputs = custom_inputs
+            elif custom_input_len > bs:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
+                    f"Using the first {bs} prompts."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
+            else:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
+                    f"Pad to the desired batch_size with the last prompt."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs)
+                bs_aligned_inputs.extend(
+                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
+                )
+
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
         ret = latency_test_run_once(
             bench_args.run_name,
             model_runner,
@@ -491,6 +586,7 @@ def latency_test(
             server_args.device,
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
+            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
         )
         if ret is not None:
sglang/srt/configs/model_config.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
     get_context_length,
     get_generation_config,
     get_hf_text_config,
+    get_sparse_attention_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
@@ -133,6 +134,11 @@ class ModelConfig:
 
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -270,6 +276,9 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()
 
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
 
@@ -297,6 +306,13 @@ class ModelConfig:
             **kwargs,
         )
 
+    def get_total_num_attention_heads(self) -> int:
+        return self.num_attention_heads
+
+    def get_num_attention_heads(self, tensor_parallel_size) -> int:
+        total_num_attention_heads = self.num_attention_heads
+        return max(1, total_num_attention_heads // tensor_parallel_size)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -401,6 +417,8 @@ class ModelConfig:
             "fbgemm_fp8",
             "w8a8_fp8",
             "petit_nvfp4",
+            "quark",
+            "mxfp4",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -482,6 +500,23 @@ class ModelConfig:
                 self.quantization,
             )
 
+    def _verify_dual_chunk_attention_config(self) -> None:
+        if hasattr(self.hf_config, "dual_chunk_attention_config"):
+            # Try loading the sparse attention config
+            sparse_attn_config = get_sparse_attention_config(self.model_path)
+            if not sparse_attn_config:
+                return
+            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+                sparse_attn_config
+            )
+            if (
+                "sparse_attention_enabled"
+                not in self.hf_config.dual_chunk_attention_config
+            ):
+                self.hf_config.dual_chunk_attention_config[
+                    "sparse_attention_enabled"
+                ] = True
+
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
sglang/srt/conversation.py
CHANGED
@@ -30,8 +30,10 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
+from typing_extensions import Literal
+
 from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
-from sglang.srt.utils import read_system_prompt_from_file
+from sglang.srt.utils import ImageData, read_system_prompt_from_file
 
 
 class SeparatorStyle(IntEnum):
@@ -91,7 +93,7 @@ class Conversation:
     video_token: str = "<video>"
     audio_token: str = "<audio>"
 
-    image_data: Optional[List[str]] = None
+    image_data: Optional[List[ImageData]] = None
     video_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
     stop_token_ids: Optional[int] = None
@@ -381,9 +383,9 @@ class Conversation:
         """Append a new message."""
         self.messages.append([role, message])
 
-    def append_image(self, image: str):
+    def append_image(self, image: str, detail: Literal["auto", "low", "high"]):
         """Append a new image."""
-        self.image_data.append(image)
+        self.image_data.append(ImageData(url=image, detail=detail))
 
     def append_video(self, video: str):
         """Append a new video."""
@@ -627,7 +629,9 @@ def generate_chat_conv(
                     real_content = image_token + real_content
                 else:
                     real_content += image_token
-                conv.append_image(content.image_url.url)
+                conv.append_image(
+                    content.image_url.url, content.image_url.detail
+                )
             elif content.type == "video_url":
                 real_content += video_token
                 conv.append_video(content.video_url.url)
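With this change, each image entry in a Conversation carries the OpenAI-style detail hint ("auto", "low", or "high") instead of a bare URL string. A rough usage sketch of the new call shape; only the ImageData(url=..., detail=...) structure comes from the diff, the surrounding setup is assumed:

from sglang.srt.utils import ImageData

# What each image entry now looks like after append_image(url, detail):
entry = ImageData(url="https://example.com/cat.png", detail="high")
# Internally, Conversation.append_image does:
#     self.image_data.append(ImageData(url=image, detail=detail))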
sglang/srt/disaggregation/base/conn.py
CHANGED
@@ -25,10 +25,13 @@ class KVArgs:
     gpu_id: int
     # for different tp
     decode_tp_size: int
-    # for pp prefill
-    prefill_pp_size: int
     kv_head_num: int
     page_size: int
+    # for pp prefill
+    prefill_pp_size: int
+    pp_rank: int
+    # for system dp
+    system_dp_rank: int
 
 
 class KVPoll:
sglang/srt/disaggregation/decode.py
CHANGED
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -184,9 +185,13 @@ class DecodePreallocQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
 
-        attn_tp_size =
+        attn_tp_size = get_attention_tp_size()
         kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+
         kv_args.decode_tp_size = attn_tp_size
+        # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
+        kv_args.pp_rank = 0
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
sglang/srt/disaggregation/decode_schedule_batch_mixin.py
CHANGED
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.tensor(
+            seq_lens, dtype=torch.int32, device=self.device
+        )
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
 