sglang 0.3.1.post1__py3-none-any.whl → 0.3.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +11 -2
- sglang/bench_server_latency.py +187 -0
- sglang/bench_serving.py +1 -1
- sglang/srt/layers/activation.py +8 -4
- sglang/srt/layers/attention_backend.py +3 -1
- sglang/srt/layers/layernorm.py +10 -7
- sglang/srt/layers/linear.py +1133 -0
- sglang/srt/layers/quantization/__init__.py +76 -0
- sglang/srt/layers/quantization/base_config.py +122 -0
- sglang/srt/layers/sampler.py +9 -2
- sglang/srt/managers/io_struct.py +3 -0
- sglang/srt/managers/policy_scheduler.py +49 -93
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/tp_worker.py +11 -6
- sglang/srt/model_executor/cuda_graph_runner.py +15 -14
- sglang/srt/model_executor/model_runner.py +13 -5
- sglang/srt/models/baichuan.py +1 -1
- sglang/srt/models/chatglm.py +6 -6
- sglang/srt/models/commandr.py +7 -7
- sglang/srt/models/dbrx.py +7 -7
- sglang/srt/models/deepseek.py +7 -7
- sglang/srt/models/deepseek_v2.py +9 -9
- sglang/srt/models/exaone.py +6 -6
- sglang/srt/models/gemma.py +6 -6
- sglang/srt/models/gemma2.py +6 -6
- sglang/srt/models/gpt_bigcode.py +6 -6
- sglang/srt/models/grok.py +6 -6
- sglang/srt/models/internlm2.py +6 -6
- sglang/srt/models/llama.py +7 -9
- sglang/srt/models/llama_classification.py +3 -4
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +6 -6
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/mixtral.py +6 -6
- sglang/srt/models/mixtral_quant.py +6 -6
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen.py +6 -6
- sglang/srt/models/qwen2.py +6 -6
- sglang/srt/models/qwen2_moe.py +7 -7
- sglang/srt/models/stablelm.py +6 -6
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +2 -5
- sglang/srt/models/yivl.py +1 -1
- sglang/srt/server_args.py +17 -21
- sglang/srt/utils.py +21 -1
- sglang/test/few_shot_gsm8k.py +8 -2
- sglang/test/test_utils.py +5 -2
- sglang/version.py +1 -1
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/METADATA +5 -5
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/RECORD +54 -50
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py
ADDED
@@ -0,0 +1,76 @@
+# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+
+from typing import Dict, Type
+
+from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsConfig,
+)
+from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.layers.quantization.qqq import QQQConfig
+from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
+    "fp8": Fp8Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
+    # The order of gptq methods is important for config.py iteration over
+    # override_quantization_method(..)
+    "marlin": MarlinConfig,
+    "gguf": GGUFConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "squeezellm": SqueezeLLMConfig,
+    "compressed-tensors": CompressedTensorsConfig,
+    "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
+}
+
+
+def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
+    if quantization not in QUANTIZATION_METHODS:
+        raise ValueError(f"Invalid quantization method: {quantization}")
+    return QUANTIZATION_METHODS[quantization]
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quantization_config",
+    "QUANTIZATION_METHODS",
+]
+
+"""
+def fp8_get_quant_method(
+    self, layer: torch.nn.Module, prefix: str
+) -> Optional["QuantizeMethodBase"]:
+    if isinstance(layer, LinearBase):
+        if is_layer_skipped(prefix, self.ignored_layers):
+            return UnquantizedLinearMethod()
+        return Fp8LinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return Fp8MoEMethod(self)
+    return None
+
+
+setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+"""
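For illustration only (not part of the diff): the new module gives sglang a local registry over vllm's quantization config classes. A minimal sketch of how it can be queried, assuming a vllm build that provides the imported classes (the file itself requires one):

    from sglang.srt.layers.quantization import (
        QUANTIZATION_METHODS,
        get_quantization_config,
    )

    # Look up the config class registered under a method name.
    fp8_cls = get_quantization_config("fp8")
    print(fp8_cls.__name__)              # Fp8Config, re-exported from vllm
    print(sorted(QUANTIZATION_METHODS))  # all registered method names
    # An unknown name raises ValueError("Invalid quantization method: ...").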
sglang/srt/layers/quantization/base_config.py
ADDED
@@ -0,0 +1,122 @@
+# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch import nn
+
+
+class QuantizeMethodBase(ABC):
+    """Base class for different quantized methods."""
+
+    @abstractmethod
+    def create_weights(
+        self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
+    ):
+        """Create weights for a layer.
+
+        The weights will be set as attributes of the layer."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
+        raise NotImplementedError
+
+    def process_weights_after_loading(self, layer: nn.Module) -> None:
+        """Process the weight after loading.
+
+        This can be used for example, to transpose weights for computation.
+        """
+        return
+
+
+class QuantizationConfig(ABC):
+    """Base class for quantization configs."""
+
+    @abstractmethod
+    def get_name(self) -> str:
+        """Name of the quantization method."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        """List of supported activation dtypes."""
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """Minimum GPU capability to support the quantization method.
+
+        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+        This requirement is due to the custom CUDA kernels used by the
+        quantization method.
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def get_config_filenames() -> List[str]:
+        """List of filenames to search for in the model directory."""
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
+        """Create a config class from the model's quantization config."""
+        raise NotImplementedError
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        """
+        Detects if this quantization method can support a given checkpoint
+        format by overriding the user specified quantization method --
+        this method should only be overwritten by subclasses in exceptional
+        circumstances
+        """
+        return None
+
+    @staticmethod
+    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
+        """Get a value from the model's quantization config."""
+        for key in keys:
+            if key in config:
+                return config[key]
+        raise ValueError(
+            f"Cannot find any of {keys} in the model's " "quantization config."
+        )
+
+    @staticmethod
+    def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
+        """Get a optional value from the model's quantization config."""
+        try:
+            return QuantizationConfig.get_from_keys(config, keys)
+        except ValueError:
+            return default
+
+    @abstractmethod
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        """Get the quantize method to use for the quantized layer.
+
+        Args:
+            layer: The layer for the quant method.
+            prefix: The full name of the layer in the state dict
+        Returns:
+            The quantize method. None if the given layer doesn't support quant
+            method.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
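For illustration only (not part of the diff): a minimal, hypothetical subclass pair showing the interface base_config.py defines. NoopConfig and NoopLinearMethod are invented names used solely for this sketch:

    from typing import Any, Dict, List, Optional

    import torch

    from sglang.srt.layers.quantization.base_config import (
        QuantizationConfig,
        QuantizeMethodBase,
    )


    class NoopLinearMethod(QuantizeMethodBase):
        """Do-nothing quant method, only to illustrate the interface."""

        def create_weights(self, layer, *weight_args, **extra_weight_attrs):
            pass  # a real method would register (quantized) weight tensors on `layer`

        def apply(self, layer, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            return x  # identity; a real method would run the quantized matmul


    class NoopConfig(QuantizationConfig):
        """Do-nothing config implementing every abstract hook."""

        def get_name(self) -> str:
            return "noop"

        def get_supported_act_dtypes(self) -> List[torch.dtype]:
            return [torch.float16, torch.bfloat16]

        @classmethod
        def get_min_capability(cls) -> int:
            return 70  # e.g. Volta or newer

        @staticmethod
        def get_config_filenames() -> List[str]:
            return []  # no config file needed for this dummy method

        @classmethod
        def from_config(cls, config: Dict[str, Any]) -> "NoopConfig":
            return cls()

        def get_quant_method(self, layer, prefix: str) -> Optional[QuantizeMethodBase]:
            return NoopLinearMethod()

        def get_scaled_act_names(self) -> List[str]:
            return []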
sglang/srt/layers/sampler.py
CHANGED
@@ -31,8 +31,11 @@ class Sampler(nn.Module):
         logits = logits.next_token_logits
 
         # Post process logits
+        logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        probs =
+        probs = torch.softmax(logits, dim=-1)
+        logits = None
+        del logits
 
         if torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
@@ -53,7 +56,11 @@ class Sampler(nn.Module):
             )
         else:
             batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs,
+                probs,
+                uniform_samples,
+                sampling_info.top_ks,
+                sampling_info.top_ps,
+                filter_apply_order="joint",
             )
 
         if not torch.all(success):
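For illustration only (not part of the diff): the rewrite above frees the logits tensor once probabilities exist and forwards per-request top-k/top-p tensors to top_k_top_p_sampling_from_probs with filter_apply_order="joint". A slow, plain-PyTorch paraphrase of what joint top-k/top-p filtering over already-softmaxed probabilities means (scalar top_k/top_p here, not the per-request tensors used above; the function name is invented):

    import torch


    def joint_top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """probs: (batch, vocab), rows already sum to 1."""
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cumsum = sorted_probs.cumsum(dim=-1)
        # top-p: keep a token while the mass before it is < top_p (first token always kept)
        top_p_mask = (cumsum - sorted_probs) < top_p
        # top-k: keep only the k highest-probability tokens
        top_k_mask = torch.arange(probs.shape[-1], device=probs.device) < top_k
        keep = top_p_mask & top_k_mask  # "joint": both filters applied to the same sorted view
        filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)
        sampled = torch.multinomial(filtered, num_samples=1)
        return sorted_idx.gather(-1, sampled).squeeze(-1)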
sglang/srt/managers/io_struct.py
CHANGED
@@ -133,6 +133,9 @@ class GenerateReqInput:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
+            elif isinstance(self.image_data, list):
+                # multi-image with n > 1
+                self.image_data = self.image_data * num
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
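For illustration only (not part of the diff): the new branch covers a request that already carries a list of images (a multi-image prompt) being expanded into num parallel samples; the list is simply repeated. The file names below are made up:

    image_data = ["scene_a.png", "scene_b.png"]  # hypothetical multi-image prompt
    num = 2                                      # e.g. two parallel samples requested
    image_data = image_data * num
    # -> ["scene_a.png", "scene_b.png", "scene_a.png", "scene_b.png"]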
sglang/srt/managers/policy_scheduler.py
CHANGED
@@ -119,19 +119,32 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
-        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
         self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
@@ -141,31 +154,14 @@ class PrefillAdder:
             if self.rem_chunk_tokens is not None
             else False
         )
-
-
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
+            or self.cur_rem_tokens <= 0
         )
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -173,29 +169,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -225,7 +199,7 @@ class PrefillAdder:
         self.rem_total_tokens += delta
 
     def add_one_req_ignore_eos(self, req: Req):
-        def
+        def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
@@ -235,56 +209,38 @@ class PrefillAdder:
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
             if tokens_left > 0:
-
-
-
-
-
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
                             break
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -40,7 +40,7 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "
+    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
 }
 
sglang/srt/managers/tp_worker.py
CHANGED
@@ -445,9 +445,6 @@ class ModelTpServer:
             num_mixed_running,
         )
 
-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
             self.current_inflight_req.init_next_round_input(
@@ -465,9 +462,6 @@ class ModelTpServer:
             )
 
         for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             if (
                 self.lora_paths is not None
                 and len(
@@ -478,6 +472,10 @@ class ModelTpServer:
                 > self.max_loras_per_batch
             ):
                 break
+
+            if adder.no_remaining_tokens():
+                break
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
@@ -507,6 +505,11 @@ class ModelTpServer:
         else:
             tree_cache_hit_rate = 0.0
 
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+        )
+
         if num_mixed_running > 0:
             logger.info(
                 f"Prefill batch"
@@ -515,6 +518,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
         else:
@@ -524,6 +528,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
            )
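For illustration only (not part of the diff): the new "token usage" figure in the prefill log is the occupied fraction of the KV-cache token pool, derived from the pool's free size plus the tree cache's evictable size. As a standalone formula (helper name invented):

    def token_usage(max_total_num_tokens: int, available: int, evictable: int) -> float:
        """Fraction of the KV token pool currently in use."""
        num_used = max_total_num_tokens - (available + evictable)
        return num_used / max_total_num_tokens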
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -108,6 +108,10 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+        self.capture_bs = [
+            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+        ]
         self.compile_bs = (
             [
                 bs
@@ -118,21 +122,8 @@ class CudaGraphRunner:
             else []
         )
 
-        # Common inputs
-        self.max_bs = max(self.capture_bs)
-        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.out_cache_loc = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-
         # Attention backend
+        self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -141,6 +132,16 @@ class CudaGraphRunner:
         if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Common inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+
         # Capture
         try:
             self.capture()
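For illustration only (not part of the diff): the buffer allocation is moved after the attention backend is initialized, apparently so that seq_lens can be pre-filled with seq_len_fill_value, and it is wrapped in a torch.device("cuda") context, which in PyTorch 2.x makes CUDA the default device for tensor factories inside the block. A minimal sketch, assuming a CUDA-capable PyTorch build:

    import torch

    with torch.device("cuda"):
        seq_lens = torch.full((8,), 1, dtype=torch.int32)  # created on the default CUDA device
    print(seq_lens.device)  # cuda:0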
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -86,12 +86,20 @@ class ModelRunner:
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
         )
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not self.server_args.disable_mla
+        ):
+            logger.info("MLA optimization is tunred on. Use triton backend.")
+            self.server_args.attention_backend = "triton"
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "
+                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
             }
         )
@@ -329,7 +337,7 @@ class ModelRunner:
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.
+            and not self.server_args.disable_mla
         ):
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
@@ -392,12 +400,12 @@ class ModelRunner:
         )
 
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs,
-            self.model_config.context_len +
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.
+            and not self.server_args.disable_mla
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
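For illustration only (not part of the diff): taken together, the hunks above switch MLA-architecture models to the triton attention backend whenever server_args.disable_mla is false. A paraphrase of that selection rule (the function name is invented and does not exist in the package):

    def pick_attention_backend(attention_arch: str, disable_mla: bool, requested: str) -> str:
        if attention_arch == "MLA" and not disable_mla:
            return "triton"  # MLA models run on the triton backend unless MLA is disabled
        return requested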
sglang/srt/models/baichuan.py
CHANGED
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/chatglm.py
CHANGED
@@ -24,12 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -40,7 +34,13 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/commandr.py
CHANGED
@@ -50,21 +50,21 @@ from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import set_weight_attrs
 
 
 @torch.compile