sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -27,7 +27,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import psutil
@@ -49,7 +49,6 @@ class ReqToTokenPool:
         size: int,
         max_context_len: int,
         device: str,
-        use_records: bool,
         enable_memory_saver: bool,
     ):
         memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -64,17 +63,9 @@ class ReqToTokenPool:
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))
-        self.write_records = []
-        self.use_records = use_records
-
-        if self.use_records:
-            self.write = self.write_with_records
-        else:
-            self.write = self.write_without_records
 
     def write(self, indices, values):
-
-        raise NotImplementedError()
+        self.req_to_token[indices] = values
 
     def available_size(self):
         return len(self.free_slots)
@@ -96,23 +87,6 @@ class ReqToTokenPool:
 
     def clear(self):
         self.free_slots = list(range(self.size))
-        self.write_records = []
-
-    def write_without_records(self, indices, values):
-        self.req_to_token[indices] = values
-
-    def write_with_records(self, indices, values):
-        self.req_to_token[indices] = values
-        self.write_records.append((indices, values))
-
-    def get_write_records(self):
-        ret = self.write_records
-        self.write_records = []
-        return ret
-
-    def apply_write_records(self, write_records: List[Tuple]):
-        for indices, values in write_records:
-            self.req_to_token[indices] = values
 
 
 class BaseTokenToKVPool:
@@ -296,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
-        k_scale: float =
-        v_scale: float =
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-
-
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
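The set_kv_buffer change above makes the FP8 scales optional: the incoming cache is divided by a scale only when one is passed, then cast to the pool dtype. A minimal standalone sketch of that behavior (the store_kv helper, shapes, and dtypes are illustrative, not sglang APIs):

import torch

def store_kv(cache_k: torch.Tensor, pool_dtype: torch.dtype, k_scale=None) -> torch.Tensor:
    # Mirrors the new logic: divide by the scale only when one is given,
    # then cast to the KV-pool dtype.
    if cache_k.dtype != pool_dtype:
        if k_scale is not None:
            cache_k = cache_k / k_scale
        cache_k = cache_k.to(pool_dtype)
    return cache_k

k = torch.randn(4, 8, dtype=torch.float32)
print(store_kv(k, torch.float16, k_scale=2.0).dtype)  # torch.float16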
sglang/srt/metrics/collector.py
CHANGED
@@ -25,6 +25,7 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    spec_accept_length: float = 0.0
 
 
 class SchedulerMetricsCollector:
@@ -37,42 +38,49 @@ class SchedulerMetricsCollector:
 
         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
-            documentation="The number of running requests",
+            documentation="The number of running requests.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.num_used_tokens = Gauge(
             name="sglang:num_used_tokens",
-            documentation="The number of used tokens",
+            documentation="The number of used tokens.",
            labelnames=labels.keys(),
            multiprocess_mode="sum",
         )
 
         self.token_usage = Gauge(
             name="sglang:token_usage",
-            documentation="The token usage",
+            documentation="The token usage.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
 
         self.gen_throughput = Gauge(
             name="sglang:gen_throughput",
-            documentation="The
+            documentation="The generation throughput (token/s).",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.num_queue_reqs = Gauge(
             name="sglang:num_queue_reqs",
-            documentation="The number of requests in the waiting queue",
+            documentation="The number of requests in the waiting queue.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
-            documentation="The cache hit rate",
+            documentation="The prefix cache hit rate.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.spec_accept_length = Gauge(
+            name="sglang:spec_accept_length",
+            documentation="The average acceptance length of speculative decoding.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
@@ -88,6 +96,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+        self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
 
 
 class TokenizerMetricsCollector:
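The new sglang:spec_accept_length gauge follows the same prometheus_client pattern as the existing scheduler gauges. A minimal sketch of creating and updating such a labeled gauge (the model_name label and the 3.2 value are illustrative; the real collector derives its labels from its labels dict):

from prometheus_client import Gauge

spec_accept_length = Gauge(
    name="sglang:spec_accept_length",
    documentation="The average acceptance length of speculative decoding.",
    labelnames=["model_name"],
)

# The collector's _log_gauge helper amounts to a labeled set() call.
spec_accept_length.labels(model_name="example-model").set(3.2)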
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -33,13 +33,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -48,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
@@ -56,23 +55,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse,
+            _to_torch(sub, reverse, num_tokens)
 
 
 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-
-    tp_group:
+    num_tokens: int,
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
         if enable_compile:
-            _to_torch(model, reverse=False,
-            monkey_patch_vllm_all_gather()
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -87,8 +85,7 @@ def patch_model(
         yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True,
-            monkey_patch_vllm_all_gather(reverse=True)
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
         tp_group.ca_comm = backup_ca_comm
 
 
@@ -122,6 +119,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
         self.tp_size = self.model_runner.tp_size
+        self.dp_size = self.model_runner.server_args.dp_size
 
         # Batch sizes to capture
         self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -151,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
 
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
-
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -165,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )
 
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -182,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
 
@@ -191,14 +187,14 @@ class CudaGraphRunner:
 
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
         # Speculative_inference
         if model_runner.spec_algorithm.is_eagle():
@@ -218,7 +214,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.
+                    self.max_bs * self.dp_size,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
@@ -287,8 +283,8 @@ class CudaGraphRunner:
         with patch_model(
             self.model_runner.model,
             bs in self.compile_bs,
-            bs,
-            self.model_runner.tp_group,
+            num_tokens=bs * self.num_tokens_per_bs,
+            tp_group=self.model_runner.tp_group,
         ) as forward:
             (
                 graph,
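The compile_bs computation moved ahead of the EAGLE branch, but its logic is unchanged: the captured batch sizes are filtered by torch_compile_max_bs when torch.compile is enabled. A standalone illustration with made-up numbers:

capture_bs = [1, 2, 4, 8, 16, 32]
torch_compile_max_bs = 8
use_torch_compile = True

compile_bs = (
    [bs for bs in capture_bs if bs <= torch_compile_max_bs]
    if use_torch_compile
    else []
)
print(compile_bs)  # [1, 2, 4, 8]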
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -282,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ class ForwardBatch:
         if model_runner.model_is_mrope:
             ret.compute_mrope_positions(model_runner, batch)
 
-        # Init attention information
-        ret.req_to_token_pool = model_runner.req_to_token_pool
-        ret.token_to_kv_pool = model_runner.token_to_kv_pool
-        ret.attn_backend = model_runner.attn_backend
-
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
@@ -417,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
-@
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
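clamp_position is now compiled with a backend chosen by get_compiler_backend(). A minimal sketch of the same decorator pattern, with a hard-coded "inductor" string standing in for whatever that helper returns on the target device:

import torch

@torch.compile(dynamic=True, backend="inductor")
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Position of the last generated token for each sequence, clamped at 0.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

print(clamp_position(torch.tensor([0, 3, 7])))  # tensor([0, 2, 6])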
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,20 +21,26 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
-
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    get_attention_tp_size,
+    initialize_dp_attention,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -57,8 +63,8 @@ from sglang.srt.utils import (
     init_custom_process_group,
     is_cuda,
     is_hip,
+    monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
 
@@ -101,8 +107,10 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
-
-            self.server_args.
+            # TODO: add MLA optimization on CPU
+            if self.server_args.device != "cpu":
+                logger.info("MLA optimization is turned on. Use triton backend.")
+                self.server_args.attention_backend = "triton"
 
         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -159,6 +167,7 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
             }
         )
 
@@ -176,9 +185,12 @@ class ModelRunner:
         self.load_model()
 
         # Apply torchao quantization
-
-
-
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )
 
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
@@ -206,7 +218,7 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-
+
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
@@ -216,9 +228,12 @@ class ModelRunner:
             backend = "gloo"
         elif self.device == "hpu":
            backend = "hccl"
+        elif self.device == "cpu":
+            backend = "gloo"
 
         if not self.server_args.enable_p2p_check:
-
+            monkey_patch_p2p_access_check()
+
         if self.server_args.dist_init_addr:
             dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
@@ -226,7 +241,7 @@ class ModelRunner:
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
 
         if not self.is_draft_worker:
-            # Only
+            # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
                 world_size=self.tp_size,
@@ -235,11 +250,18 @@ class ModelRunner:
                 distributed_init_method=dist_init_method,
             )
             initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+            initialize_dp_attention(
+                enable_dp_attention=self.server_args.enable_dp_attention,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+            )
 
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
+        self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
@@ -257,7 +279,8 @@ class ModelRunner:
         )
 
         # This can reduce thread conflicts and speed up weight loading.
-
+        if self.device != "cpu":
+            torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
                 logger.info(
@@ -277,12 +300,15 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
+        # Remove monkey_patch when linear.py quant remove dependencies with vllm
+        monkey_patch_vllm_parallel_state()
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
                 device_config=DeviceConfig(self.device),
             )
+        monkey_patch_vllm_parallel_state(reverse=True)
 
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
@@ -521,7 +547,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
-                self.model_config.get_num_kv_heads(
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
@@ -595,7 +621,6 @@ class ModelRunner:
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
-            use_records=False,
             enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
@@ -615,7 +640,7 @@ class ModelRunner:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
@@ -626,7 +651,7 @@ class ModelRunner:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
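With get_attention_tp_size() in place of the plain TP size, the per-token KV-cache cell size is computed from the KV heads held by each attention-TP rank. A worked example with made-up numbers (fp16 storage at 2 bytes per element and both K and V stored are assumptions of this sketch, not values taken from the diff):

# Hypothetical model: 8 KV heads split across an attention-TP group of 2,
# head_dim 128, 32 layers.
num_kv_heads_per_rank = 8 // 2
head_dim = 128
num_layers = 32
bytes_per_elem = 2  # fp16 (assumed)

cell_size = num_kv_heads_per_rank * head_dim * num_layers * 2 * bytes_per_elem
print(cell_size)  # 65536 bytes of KV cache per token per rank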
sglang/srt/model_loader/loader.py
CHANGED
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()
 
 
+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""
 
@@ -496,7 +568,8 @@ class ShardedStateLoader(BaseModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from safetensors.torch import safe_open
-
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         local_model_path = self._prepare_weights(
             model_config.model_path, model_config.revision
@@ -556,7 +629,8 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
-
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
@@ -1147,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)