sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -33,9 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip
 
-
+_is_hip = is_hip()
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     else:
         capture_bs = list(range(1, 33))
 
-    if
+    if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
 
@@ -236,7 +237,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
@@ -264,21 +265,24 @@ class CudaGraphRunner:
     def model_capture_mode(self):
         if hasattr(self.model_runner.model, "capture_mode"):
             self.model_runner.model.capture_mode = True
+        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
+            self.model_runner.token_to_kv_pool.capture_mode = True
 
         yield
 
         if hasattr(self.model_runner.model, "capture_mode"):
             self.model_runner.model.capture_mode = False
+        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
+            self.model_runner.token_to_kv_pool.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-
-
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-
+                total_global_tokens in self.graphs
                 if self.disable_padding
-                else
+                else total_global_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
@@ -300,12 +304,26 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
+            # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
-                tqdm.tqdm(self.capture_bs)
+                tqdm.tqdm(list(reversed(self.capture_bs)))
                 if get_tensor_model_parallel_rank() == 0
-                else self.capture_bs
+                else reversed(self.capture_bs)
             )
             for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -340,8 +358,18 @@ class CudaGraphRunner:
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
-
-
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
            global_num_tokens = None
            gathered_buffer = None
```
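The list comprehension in the hunk above spreads the captured token budget across data-parallel ranks, giving the first few ranks one extra token when the count does not divide evenly. A standalone sketch of that even-split arithmetic (the helper name and the asserted values are illustrative, not from sglang):

```python
def split_evenly(total: int, parts: int) -> list:
    """Distribute `total` items over `parts` ranks; the first `total % parts`
    ranks receive one extra item, so the shares always sum back to `total`."""
    base, remainder = divmod(total, parts)
    return [base + (i < remainder) for i in range(parts)]

assert split_evenly(10, 4) == [3, 3, 2, 2]
assert sum(split_evenly(10, 4)) == 10
```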
sglang/srt/model_executor/cuda_graph_runner.py (continued)

```diff
@@ -366,7 +394,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-
+            global_num_tokens_gpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -387,6 +415,9 @@ class CudaGraphRunner:
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits, logits_output.hidden_states
 
@@ -421,7 +452,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def
+    def replay_prepare(self, forward_batch: ForwardBatch):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -430,7 +461,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs,
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
```
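Padding works by looking up the smallest captured batch size that can hold the incoming batch, which is why `capture_bs` has to stay sorted for `bisect_left`. A small illustration with made-up sizes:

```python
import bisect

capture_sizes = [1, 2, 4, 8, 16, 24, 32]  # must stay sorted for bisect to work

def padded_size(requested: int) -> int:
    """Return the smallest captured size that can hold `requested` tokens."""
    index = bisect.bisect_left(capture_sizes, requested)
    if index == len(capture_sizes):
        raise ValueError(f"{requested} exceeds the largest captured size")
    return capture_sizes[index]

assert padded_size(5) == 8     # padded up to the next captured size
assert padded_size(16) == 16   # exact sizes replay without padding
```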
sglang/srt/model_executor/cuda_graph_runner.py (continued)

```diff
@@ -454,6 +485,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -470,14 +503,29 @@ class CudaGraphRunner:
             seq_lens_cpu=self.seq_lens_cpu,
         )
 
+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
         # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]
 
         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
             hidden_states=(
-                hidden_states[:raw_num_token]
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
             ),
         )
         return logits_output
```
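`replay` is now split from `replay_prepare`, so speculative decoding can refresh only the input and position buffers and skip the rest of the preparation. The sketch below is not sglang's runner; it shows the same prepare/replay split on top of PyTorch's `torch.cuda.CUDAGraph` API (the class and parameter names are made up) and needs a CUDA GPU to run.

```python
import torch

class TinyGraphRunner:
    """Toy capture/replay wrapper: inputs go into static buffers in a separate
    prepare step, so a caller that already filled them can skip it."""

    def __init__(self, model: torch.nn.Module, max_tokens: int):
        self.model = model
        self.static_input = torch.zeros(max_tokens, dtype=torch.long, device="cuda")

        # Warm up on a side stream, then record the graph against the static buffer.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            self.model(self.static_input)
        torch.cuda.current_stream().wait_stream(side)

        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.static_output = self.model(self.static_input)

    def replay_prepare(self, input_ids: torch.Tensor):
        self.num_tokens = input_ids.numel()
        self.static_input[: self.num_tokens].copy_(input_ids)

    def replay(self, input_ids: torch.Tensor, skip_prepare: bool = False):
        if not skip_prepare:
            self.replay_prepare(input_ids)
        self.graph.replay()  # re-run the recorded kernels on the static buffers
        return self.static_output[: self.num_tokens]

runner = TinyGraphRunner(torch.nn.Embedding(1000, 8).cuda(), max_tokens=32)
out = runner.replay(torch.arange(4, device="cuda"))  # rows 0..3 hold the 4 embeddings
```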
sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -43,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
-    from sglang.srt.mem_cache.memory_pool import
+    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -51,9 +51,8 @@ if TYPE_CHECKING:
 
 
 class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
+    # It is also called "prefill" in common terminology.
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
@@ -153,6 +152,12 @@ class ForwardBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None
 
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
     # Position information
     positions: torch.Tensor = None
 
@@ -189,7 +194,7 @@ class ForwardBatch:
 
     # Attention backend
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool:
+    token_to_kv_pool: KVCache = None
     attn_backend: AttentionBackend = None
 
     # For DP attention
@@ -229,7 +234,6 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu = (
                 batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
             )
-
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -259,15 +263,24 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
 
+        # For DP attention
         if batch.global_num_tokens is not None:
             ret.global_num_tokens_cpu = batch.global_num_tokens
-
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
             ret.gathered_buffer = torch.zeros(
-                (
+                (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
                 device=device,
             )
-
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             return ret
@@ -417,8 +430,8 @@ def compute_position_kernel(
     prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
-    #
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_seq_lens + i)
 
```
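The accumulator in the kernel above is widened because 32-bit integer math wraps silently once a running offset passes 2**31 - 1. A NumPy illustration of that failure mode (just the arithmetic, unrelated to Triton itself):

```python
import numpy as np

# Not sglang code: why a 64-bit accumulator matters for large cumulative offsets.
a = np.int32(2_000_000_000)
b = np.int32(1_000_000_000)

print(np.add(a, b))                  # wraps to -1294967296 (with an overflow warning)
print(np.add(a, b, dtype=np.int64))  # 3000000000, the exact sum
```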
sglang/srt/model_executor/model_runner.py

```diff
@@ -35,17 +35,13 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -57,9 +53,16 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.loader import (
+    DefaultModelLoader,
+    device_loading_context,
+    get_model_loader,
+)
+from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -77,11 +80,9 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 
@@ -118,6 +119,7 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
@@ -160,6 +162,11 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # If it is a draft model tp_group can be different.
+        self.initialize(min_per_gpu_memory)
+
+    def initialize(self, min_per_gpu_memory: float):
+        server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -299,15 +306,16 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                 )
 
         logger.info(
@@ -347,6 +355,8 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
+        monkey_patch_isinstance_for_vllm_base_layer()
+
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
@@ -354,6 +364,7 @@ class ModelRunner:
                 device_config=DeviceConfig(self.device),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
+        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
@@ -411,13 +422,6 @@ class ModelRunner:
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
-        from sglang.srt.model_loader.loader import (
-            DefaultModelLoader,
-            device_loading_context,
-            get_model_loader,
-        )
-        from sglang.srt.model_loader.utils import set_default_torch_dtype
-
         logger.info(
             f"Update engine weights online from disk begin. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -427,7 +431,7 @@ class ModelRunner:
         self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)
 
-        # Only support
+        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
@@ -701,6 +705,12 @@ class ModelRunner:
             )
             self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
 
+        self.max_total_num_tokens = (
+            self.max_total_num_tokens
+            // self.server_args.page_size
+            * self.server_args.page_size
+        )
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
```
sglang/srt/model_executor/model_runner.py (continued)

```diff
@@ -723,6 +733,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -733,6 +744,7 @@ class ModelRunner:
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -744,6 +756,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -753,12 +766,21 @@ class ModelRunner:
             )
 
         if self.token_to_kv_pool_allocator is None:
-            self.
-            self.
-
-
-
-
+            if self.page_size == 1:
+                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+            else:
+                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
         else:
             assert self.is_draft_worker
 
```
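The paged branch hands KV slots out in fixed-size pages rather than single tokens. The toy class below is not sglang's `PagedTokenToKVPoolAllocator`; it is only a sketch of the bookkeeping such an allocator has to do: a request for N tokens consumes ceil(N / page_size) pages, and freeing returns whole pages to the free list.

```python
class ToyPagedAllocator:
    def __init__(self, total_tokens: int, page_size: int):
        assert total_tokens % page_size == 0, "budget must be a whole number of pages"
        self.page_size = page_size
        self.free_pages = list(range(total_tokens // page_size))

    def alloc(self, num_tokens: int) -> list:
        """Return token-slot indices covering `num_tokens`, or raise if exhausted."""
        need = -(-num_tokens // self.page_size)  # ceil division
        if need > len(self.free_pages):
            raise MemoryError("KV cache out of pages")
        pages, self.free_pages = self.free_pages[:need], self.free_pages[need:]
        slots = [p * self.page_size + i for p in pages for i in range(self.page_size)]
        return slots[:num_tokens]

    def free(self, slots: list) -> None:
        # Whole pages only in this toy version: dedupe page ids and return them.
        self.free_pages.extend(sorted({s // self.page_size for s in slots}))

pool = ToyPagedAllocator(total_tokens=64, page_size=16)
req = pool.alloc(20)  # reserves 2 pages, returns 20 usable slots
pool.free(req)        # both pages go back to the free list
```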
sglang/srt/model_executor/model_runner.py (continued)

```diff
@@ -779,10 +801,13 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
-
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -794,12 +819,26 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
```
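This hunk pairs with the import removals at the top of the file: each backend module is now imported only inside the branch that selects it, so process startup stays cheap and optional dependencies can remain uninstalled. A generic sketch of the same lazy-import pattern using `importlib` (the mapping below covers only two of the backends for brevity):

```python
import importlib

_BACKEND_MODULES = {
    "triton": "sglang.srt.layers.attention.triton_backend",
    "torch_native": "sglang.srt.layers.attention.torch_native_backend",
}

def load_backend_module(name: str):
    """Import a backend module only when it is actually chosen."""
    try:
        path = _BACKEND_MODULES[name]
    except KeyError:
        raise ValueError(f"unknown attention backend: {name}") from None
    return importlib.import_module(path)  # deferred until this call
```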
sglang/srt/model_executor/model_runner.py (continued)

```diff
@@ -928,45 +967,6 @@ class ModelRunner:
         sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
-    def update_output_logprobs(
-        self,
-        logits_output: LogitsProcessorOutput,
-        sampling_info: SamplingBatchInfo,
-        top_logprobs_nums: List[int],
-        token_ids_logprobs: List[int],
-        next_token_ids: torch.Tensor,
-        *,
-        num_tokens_per_req: List[int],
-    ):
-        """Update the logits_output's output logprob based on next_token_ids
-
-        Args:
-            logits_output: The logits output from the model forward
-            sampling_info: Sampling info for logprob calculation
-            top_logprobs_nums: Number of logprobs per request.
-            next_token_ids: Next token ids.
-            num_tokens_per_req: The number of tokens per request.
-
-        Returns:
-            A list of next_token_ids
-        """
-        self._preprocess_logits(logits_output, sampling_info)
-        # We should repeat top_logprobs_nums to match num_tokens_per_req.
-        top_logprobs_nums_repeat_interleaved = []
-        token_ids_logprobs_repeat_interleaved = []
-        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
-            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
-        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
-            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
-        self.sampler(
-            logits_output,
-            sampling_info,
-            True,
-            top_logprobs_nums_repeat_interleaved,
-            token_ids_logprobs_repeat_interleaved,
-            batch_next_token_ids=next_token_ids,
-        )
-
     def sample(
         self,
         logits_output: LogitsProcessorOutput,
```
sglang/srt/model_loader/loader.py

```diff
@@ -48,6 +48,7 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
 )
 from sglang.srt.utils import (
+    get_bool_env_var,
     get_device_capability,
     is_pin_memory_available,
     set_weight_attrs,
@@ -197,7 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if
+        if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
```
sglang/srt/model_loader/weight_utils.py

```diff
@@ -455,7 +455,7 @@ def pt_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu")
+        state = torch.load(bin_file, map_location="cpu", weights_only=True)
         yield from state.items()
         del state
         torch.cuda.empty_cache()
```