sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +6 -25
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +104 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +58 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +117 -131
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +57 -44
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -1
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

```diff
@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -43,17 +37,19 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -69,6 +65,8 @@ logger = logging.getLogger(__name__)
 
 
 class ModelRunner:
+    """ModelRunner runs the forward passes of the models."""
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -92,13 +90,15 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "
-                "
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
 
+        # Model-specific adjustment
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
```
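The two truncated keys removed above are replaced by explicitly named backend options plus a torchao option. A minimal, self-contained sketch of this option plumbing (the dictionary keys come from the hunk; the stand-in `ServerArgsSketch` class and its default values are illustrative, not sglang's actual `ServerArgs`):

```python
from dataclasses import dataclass

# Stand-in for sglang.srt.server_args.ServerArgs: only the fields read in the
# hunk above are modeled, and the defaults here are illustrative.
@dataclass
class ServerArgsSketch:
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"
    triton_attention_reduce_in_fp32: bool = False
    enable_mla: bool = False
    torchao_config: str = ""

# Stand-in for the process-global dict imported from schedule_batch.
global_server_args_dict = {}

def publish_options(server_args: ServerArgsSketch) -> None:
    # Mirrors the global_server_args_dict.update(...) call in ModelRunner.__init__.
    global_server_args_dict.update(
        {
            "attention_backend": server_args.attention_backend,
            "sampling_backend": server_args.sampling_backend,
            "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
            "enable_mla": server_args.enable_mla,
            "torchao_config": server_args.torchao_config,
        }
    )

publish_options(ServerArgsSketch(attention_backend="triton"))
print(global_server_args_dict["attention_backend"])  # -> "triton"
```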
```diff
@@ -106,15 +106,19 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
 
+        # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.
+        self.init_attention_backend()
         self.init_cuda_graphs()
 
     def init_torch_distributed(self):
@@ -313,6 +317,17 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights"
 
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
+
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
```
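LoRA support is optional and is only wired up when `server_args.lora_paths` is set, as the two hunks above show. A rough sketch of that gating with stand-in types (the keyword names match the `init_lora_manager` call above; `LoRAManagerSketch` and the helper function are illustrative, not the real `sglang.srt.lora.lora_manager.LoRAManager`):

```python
from typing import Any, List, Optional

class LoRAManagerSketch:
    """Illustrative stand-in for sglang's LoRAManager."""
    def __init__(self, base_model: Any, lora_paths: List[str], base_hf_config: Any,
                 max_loras_per_batch: int, load_config: Any, dtype: Any):
        self.lora_paths = lora_paths
        self.max_loras_per_batch = max_loras_per_batch

def maybe_init_lora_manager(model: Any, hf_config: Any, load_config: Any, dtype: Any,
                            lora_paths: Optional[List[str]],
                            max_loras_per_batch: int) -> Optional[LoRAManagerSketch]:
    # ModelRunner only builds a manager when adapter paths are configured;
    # otherwise LoRA stays disabled and no extra state is kept.
    if lora_paths is None:
        return None
    return LoRAManagerSketch(
        base_model=model,
        lora_paths=lora_paths,
        base_hf_config=hf_config,
        max_loras_per_batch=max_loras_per_batch,
        load_config=load_config,
        dtype=dtype,
    )

print(maybe_init_lora_manager(None, None, None, None, ["my-adapter"], 8).lora_paths)
```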
```diff
@@ -343,8 +358,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -378,7 +393,7 @@ class ModelRunner:
                    ),
                    2048,
                ),
-
+                4096,
            )
 
         self.req_to_token_pool = ReqToTokenPool(
```
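The visible lines of this hunk add an upper bound to the derived request limit: the value computed from `max_total_num_tokens / context_len * 512` is clamped to at least 2048 and now to at most 4096 (the assignment itself sits above the hunk and is not shown). A worked example of that clamp, with an illustrative helper name and made-up numbers:

```python
# Illustrative helper mirroring min(max(int(tokens / context_len * 512), 2048), 4096)
# as visible in the hunk above; the function name is not from sglang.
def derive_max_num_reqs(max_total_num_tokens: int, context_len: int) -> int:
    return min(max(int(max_total_num_tokens / context_len * 512), 2048), 4096)

# A large KV budget would yield ~25000 requests; the new cap clamps it to 4096.
print(derive_max_num_reqs(200_000, 4096))   # 4096
# A small budget with a long context would yield ~125; the floor raises it to 2048.
print(derive_max_num_reqs(8_000, 32_768))   # 2048
```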
```diff
@@ -396,9 +411,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementaion, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -421,118 +433,46 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def 
-        """Init 
-        if self.server_args.
-
-
-
-
-
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
             )
+            self.attn_backend = TritonAttnBackend(self)
         else:
-
-
-                dtype=torch.uint8,
-                device="cuda",
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
             )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        self.cuda_graph_runner = None
+
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-
-
-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
-            self.cuda_graph_runner = None
+        if self.server_args.disable_cuda_graph:
             return
 
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
-
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if 
-            self.
-
-
-        ):
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch)
+
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
 
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
```
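The per-model flashinfer wrapper setup is replaced here by a single dispatch over `server_args.attention_backend`. A standalone sketch of that dispatch logic (the returned strings stand in for the real `FlashInferAttnBackend`/`TritonAttnBackend` objects; the function name is illustrative, not an sglang API):

```python
from typing import Optional

def select_attention_backend(attention_backend: str,
                             sliding_window_size: Optional[int] = None) -> str:
    """Mirror of the if/elif dispatch in ModelRunner.init_attention_backend above."""
    if attention_backend == "flashinfer":
        return "FlashInferAttnBackend"
    elif attention_backend == "triton":
        # The triton backend rejects sliding-window (window attention) models.
        assert sliding_window_size is None, (
            "Window attention is not supported in the triton attention backend. "
            "Please use `--attention-backend flashinfer`."
        )
        return "TritonAttnBackend"
    else:
        raise ValueError(f"Invalid attention backend: {attention_backend}")

print(select_attention_backend("triton"))       # TritonAttnBackend
print(select_attention_backend("flashinfer"))   # FlashInferAttnBackend
```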
```diff
@@ -540,11 +480,10 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-
-            batch,
-
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
+
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -560,11 +499,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
@@ -574,17 +509,68 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-
-
-        if self.is_multimodal_model and forward_mode
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+
+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)
 
 
 @lru_cache()
```
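The new `sample` path applies, in order, per-request logit biases, additive penalties, the multiplicative repetition penalty, and a regex vocabulary mask before handing the logits to the sampler. A toy, self-contained example of those adjustments on made-up tensors (the real code reads them from `SamplingBatchInfo`):

```python
import torch

# One request, three-token vocabulary; all values below are made up.
logits = torch.tensor([[2.0, -1.0, 0.5]])
logit_bias = torch.tensor([[0.0, 0.0, -100.0]])       # e.g. an OpenAI-style logit_bias
linear_penalties = torch.tensor([[0.0, -0.5, 0.0]])   # presence/frequency-style additive penalty
scaling_penalties = torch.tensor([[1.2, 1.2, 1.2]])   # repetition penalty > 1
vocab_mask = torch.tensor([[False, True, False]])     # token 1 forbidden by a regex constraint

logits = logits + logit_bias
logits = logits + linear_penalties
# Repetition penalty: positive logits are divided, negative ones multiplied,
# so penalized tokens always become less likely.
logits = torch.where(logits > 0, logits / scaling_penalties, logits * scaling_penalties)
# Masked tokens are forced to -inf so they can never be sampled.
logits = logits.masked_fill(vocab_mask, float("-inf"))
print(logits)  # tensor([[   1.6667,      -inf, -119.4000]])
```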