sglang 0.3.1.post1__py3-none-any.whl → 0.3.1.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (54)
  1. sglang/bench_latency.py +11 -2
  2. sglang/bench_server_latency.py +187 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/srt/layers/activation.py +8 -4
  5. sglang/srt/layers/attention_backend.py +3 -1
  6. sglang/srt/layers/layernorm.py +10 -7
  7. sglang/srt/layers/linear.py +1133 -0
  8. sglang/srt/layers/quantization/__init__.py +76 -0
  9. sglang/srt/layers/quantization/base_config.py +122 -0
  10. sglang/srt/layers/sampler.py +9 -2
  11. sglang/srt/managers/io_struct.py +3 -0
  12. sglang/srt/managers/policy_scheduler.py +49 -93
  13. sglang/srt/managers/schedule_batch.py +1 -1
  14. sglang/srt/managers/tp_worker.py +11 -6
  15. sglang/srt/model_executor/cuda_graph_runner.py +15 -14
  16. sglang/srt/model_executor/model_runner.py +13 -5
  17. sglang/srt/models/baichuan.py +1 -1
  18. sglang/srt/models/chatglm.py +6 -6
  19. sglang/srt/models/commandr.py +7 -7
  20. sglang/srt/models/dbrx.py +7 -7
  21. sglang/srt/models/deepseek.py +7 -7
  22. sglang/srt/models/deepseek_v2.py +9 -9
  23. sglang/srt/models/exaone.py +6 -6
  24. sglang/srt/models/gemma.py +6 -6
  25. sglang/srt/models/gemma2.py +6 -6
  26. sglang/srt/models/gpt_bigcode.py +6 -6
  27. sglang/srt/models/grok.py +6 -6
  28. sglang/srt/models/internlm2.py +6 -6
  29. sglang/srt/models/llama.py +7 -9
  30. sglang/srt/models/llama_classification.py +3 -4
  31. sglang/srt/models/llava.py +1 -1
  32. sglang/srt/models/llavavid.py +1 -1
  33. sglang/srt/models/minicpm.py +6 -6
  34. sglang/srt/models/minicpm3.py +3 -3
  35. sglang/srt/models/mixtral.py +6 -6
  36. sglang/srt/models/mixtral_quant.py +6 -6
  37. sglang/srt/models/olmoe.py +1 -1
  38. sglang/srt/models/qwen.py +6 -6
  39. sglang/srt/models/qwen2.py +6 -6
  40. sglang/srt/models/qwen2_moe.py +7 -7
  41. sglang/srt/models/stablelm.py +6 -6
  42. sglang/srt/models/xverse.py +2 -4
  43. sglang/srt/models/xverse_moe.py +2 -5
  44. sglang/srt/models/yivl.py +1 -1
  45. sglang/srt/server_args.py +17 -21
  46. sglang/srt/utils.py +21 -1
  47. sglang/test/few_shot_gsm8k.py +8 -2
  48. sglang/test/test_utils.py +5 -2
  49. sglang/version.py +1 -1
  50. {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/METADATA +5 -5
  51. {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/RECORD +54 -50
  52. {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/LICENSE +0 -0
  53. {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/WHEEL +0 -0
  54. {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py (new file)
@@ -0,0 +1,76 @@
+ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+
+ from typing import Dict, Type
+
+ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+ from vllm.model_executor.layers.quantization.awq import AWQConfig
+ from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+ from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+     CompressedTensorsConfig,
+ )
+ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+ from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+ from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+ from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+ from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
+ from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
+ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+ from vllm.model_executor.layers.quantization.qqq import QQQConfig
+ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+     "aqlm": AQLMConfig,
+     "awq": AWQConfig,
+     "deepspeedfp": DeepSpeedFPConfig,
+     "tpu_int8": Int8TpuConfig,
+     "fp8": Fp8Config,
+     "fbgemm_fp8": FBGEMMFp8Config,
+     # The order of gptq methods is important for config.py iteration over
+     # override_quantization_method(..)
+     "marlin": MarlinConfig,
+     "gguf": GGUFConfig,
+     "gptq_marlin_24": GPTQMarlin24Config,
+     "gptq_marlin": GPTQMarlinConfig,
+     "awq_marlin": AWQMarlinConfig,
+     "gptq": GPTQConfig,
+     "squeezellm": SqueezeLLMConfig,
+     "compressed-tensors": CompressedTensorsConfig,
+     "bitsandbytes": BitsAndBytesConfig,
+     "qqq": QQQConfig,
+     "experts_int8": ExpertsInt8Config,
+ }
+
+
+ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
+     if quantization not in QUANTIZATION_METHODS:
+         raise ValueError(f"Invalid quantization method: {quantization}")
+     return QUANTIZATION_METHODS[quantization]
+
+
+ __all__ = [
+     "QuantizationConfig",
+     "get_quantization_config",
+     "QUANTIZATION_METHODS",
+ ]
+
+ """
+ def fp8_get_quant_method(
+     self, layer: torch.nn.Module, prefix: str
+ ) -> Optional["QuantizeMethodBase"]:
+     if isinstance(layer, LinearBase):
+         if is_layer_skipped(prefix, self.ignored_layers):
+             return UnquantizedLinearMethod()
+         return Fp8LinearMethod(self)
+     elif isinstance(layer, FusedMoE):
+         return Fp8MoEMethod(self)
+     return None
+
+
+ setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+ """
sglang/srt/layers/quantization/base_config.py (new file)
@@ -0,0 +1,122 @@
+ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from torch import nn
+
+
+ class QuantizeMethodBase(ABC):
+     """Base class for different quantized methods."""
+
+     @abstractmethod
+     def create_weights(
+         self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
+     ):
+         """Create weights for a layer.
+
+         The weights will be set as attributes of the layer."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+         """Apply the weights in layer to the input tensor.
+
+         Expects create_weights to have been called before on the layer."""
+         raise NotImplementedError
+
+     def process_weights_after_loading(self, layer: nn.Module) -> None:
+         """Process the weight after loading.
+
+         This can be used for example, to transpose weights for computation.
+         """
+         return
+
+
+ class QuantizationConfig(ABC):
+     """Base class for quantization configs."""
+
+     @abstractmethod
+     def get_name(self) -> str:
+         """Name of the quantization method."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_supported_act_dtypes(self) -> List[torch.dtype]:
+         """List of supported activation dtypes."""
+         raise NotImplementedError
+
+     @classmethod
+     @abstractmethod
+     def get_min_capability(cls) -> int:
+         """Minimum GPU capability to support the quantization method.
+
+         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+         This requirement is due to the custom CUDA kernels used by the
+         quantization method.
+         """
+         raise NotImplementedError
+
+     @staticmethod
+     @abstractmethod
+     def get_config_filenames() -> List[str]:
+         """List of filenames to search for in the model directory."""
+         raise NotImplementedError
+
+     @classmethod
+     @abstractmethod
+     def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
+         """Create a config class from the model's quantization config."""
+         raise NotImplementedError
+
+     @classmethod
+     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+         """
+         Detects if this quantization method can support a given checkpoint
+         format by overriding the user specified quantization method --
+         this method should only be overwritten by subclasses in exceptional
+         circumstances
+         """
+         return None
+
+     @staticmethod
+     def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
+         """Get a value from the model's quantization config."""
+         for key in keys:
+             if key in config:
+                 return config[key]
+         raise ValueError(
+             f"Cannot find any of {keys} in the model's " "quantization config."
+         )
+
+     @staticmethod
+     def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
+         """Get a optional value from the model's quantization config."""
+         try:
+             return QuantizationConfig.get_from_keys(config, keys)
+         except ValueError:
+             return default
+
+     @abstractmethod
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[QuantizeMethodBase]:
+         """Get the quantize method to use for the quantized layer.
+
+         Args:
+             layer: The layer for the quant method.
+             prefix: The full name of the layer in the state dict
+         Returns:
+             The quantize method. None if the given layer doesn't support quant
+             method.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_scaled_act_names(self) -> List[str]:
+         """Returns the activation function names that should be post-scaled.
+
+         For now, this is only used by AWQ.
+         """
+         raise NotImplementedError
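QuantizationConfig and QuantizeMethodBase are the abstract interfaces that every entry in the registry above implements. The sketch below shows a minimal conforming implementation; the class names NoopConfig and NoopLinearMethod are invented for illustration and do not exist in the package.

    # Hypothetical example: a do-nothing backend that satisfies the abstract interface.
    from typing import Any, Dict, List, Optional

    import torch

    from sglang.srt.layers.quantization.base_config import (
        QuantizationConfig,
        QuantizeMethodBase,
    )


    class NoopLinearMethod(QuantizeMethodBase):
        def create_weights(self, layer, *weight_args, **extra_weight_attrs):
            # A real method would register quantized weight tensors on `layer`.
            pass

        def apply(self, layer, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            # Pass-through: no quantized matmul is performed.
            return x


    class NoopConfig(QuantizationConfig):
        def get_name(self) -> str:
            return "noop"

        def get_supported_act_dtypes(self) -> List[torch.dtype]:
            return [torch.float16, torch.bfloat16]

        @classmethod
        def get_min_capability(cls) -> int:
            return 70

        @staticmethod
        def get_config_filenames() -> List[str]:
            return []

        @classmethod
        def from_config(cls, config: Dict[str, Any]) -> "NoopConfig":
            return cls()

        def get_quant_method(self, layer, prefix: str) -> Optional[QuantizeMethodBase]:
            return NoopLinearMethod()

        def get_scaled_act_names(self) -> List[str]:
            return []


    cfg = NoopConfig.from_config({})
    print(cfg.get_name(), cfg.get_min_capability())   # noop 70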
sglang/srt/layers/sampler.py
@@ -31,8 +31,11 @@ class Sampler(nn.Module):
          logits = logits.next_token_logits

          # Post process logits
+         logits = logits.contiguous()
          logits.div_(sampling_info.temperatures)
-         probs = logits[:] = torch.softmax(logits, dim=-1)
+         probs = torch.softmax(logits, dim=-1)
+         logits = None
+         del logits

          if torch.any(torch.isnan(probs)):
              logger.warning("Detected errors during sampling! NaN in the probability.")
@@ -53,7 +56,11 @@ class Sampler(nn.Module):
              )
          else:
              batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                 probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                 probs,
+                 uniform_samples,
+                 sampling_info.top_ks,
+                 sampling_info.top_ps,
+                 filter_apply_order="joint",
              )

          if not torch.all(success):
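The second hunk switches flashinfer's top_k_top_p_sampling_from_probs to filter_apply_order="joint", i.e. the top-k and top-p constraints are applied together rather than one after the other. The PyTorch-only sketch below shows what a joint filter over one probability vector means; it is a reference illustration, not the flashinfer kernel, and the vocabulary size and thresholds are made up.

    import torch

    def joint_top_k_top_p_filter(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Zero out tokens that fail either the top-k or the top-p constraint, then renormalize.

        probs: 1-D tensor of probabilities for a single sequence. Reference sketch only.
        """
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        # top-p: keep the smallest prefix whose cumulative mass reaches top_p
        nucleus = cumsum - sorted_probs < top_p
        # top-k: keep only the k highest-probability tokens
        ranks = torch.arange(probs.numel(), device=probs.device)
        joint = nucleus & (ranks < top_k)          # a token must satisfy both filters
        keep = torch.zeros_like(probs, dtype=torch.bool)
        keep[sorted_idx[joint]] = True
        filtered = torch.where(keep, probs, torch.zeros_like(probs))
        return filtered / filtered.sum()

    # Example: temperature-scaled logits -> probs -> jointly filtered probs -> one sample
    logits = torch.randn(32000) / 0.7
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(joint_top_k_top_p_filter(probs, top_k=50, top_p=0.9), 1)
    print(next_token.item())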
sglang/srt/managers/io_struct.py
@@ -133,6 +133,9 @@ class GenerateReqInput:
              self.image_data = [None] * num
          elif not isinstance(self.image_data, list):
              self.image_data = [self.image_data] * num
+         elif isinstance(self.image_data, list):
+             # multi-image with n > 1
+             self.image_data = self.image_data * num

          if self.sampling_params is None:
              self.sampling_params = [{}] * num
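The new branch covers batched requests (n > 1) that already carry a list of images: the list is tiled so every generated sample receives the same image references. A tiny illustration with made-up file names:

    # Hypothetical request: 2 images shared by n = 3 parallel samples.
    image_data = ["img_a.png", "img_b.png"]
    num = 3
    print(image_data * num)
    # ['img_a.png', 'img_b.png', 'img_a.png', 'img_b.png', 'img_a.png', 'img_b.png']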
sglang/srt/managers/policy_scheduler.py
@@ -119,19 +119,32 @@ class PrefillAdder:
          self.running_batch = running_batch
          self.new_token_ratio = new_token_ratio
          self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-         self.rem_total_tokens_ = self.rem_total_tokens
-         self.total_tokens = rem_total_tokens
          self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
          self.rem_chunk_tokens = rem_chunk_tokens
          if self.rem_chunk_tokens is not None:
              self.rem_chunk_tokens -= mixed_with_decode_tokens

+         self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
          self.req_states = None
          self.can_run_list = []
          self.new_inflight_req = None
          self.log_hit_tokens = 0
          self.log_input_tokens = 0

+         if running_batch is not None:
+             # Pre-remove the tokens which will be occupied by the running requests
+             self.rem_total_tokens -= sum(
+                 [
+                     min(
+                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                         CLIP_MAX_NEW_TOKENS,
+                     )
+                     * self.new_token_ratio
+                     for r in running_batch.reqs
+                 ]
+             )
+
      def no_remaining_tokens(self):
          return (
              self.rem_total_tokens <= 0
@@ -141,31 +154,14 @@ class PrefillAdder:
                  if self.rem_chunk_tokens is not None
                  else False
              )
-         )
-
-     def remove_running_tokens(self, running_batch: ScheduleBatch):
-         self.rem_total_tokens -= sum(
-             [
-                 min(
-                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                     CLIP_MAX_NEW_TOKENS,
-                 )
-                 * self.new_token_ratio
-                 for r in running_batch.reqs
-             ]
-         )
-         self.rem_total_tokens_ -= sum(
-             [
-                 r.sampling_params.max_new_tokens - len(r.output_ids)
-                 for r in running_batch.reqs
-             ]
+             or self.cur_rem_tokens <= 0
          )

      def _prefill_one_req(
          self, prefix_len: int, extend_input_len: int, max_new_tokens: int
      ):
          self.rem_total_tokens -= extend_input_len + max_new_tokens
-         self.rem_total_tokens_ -= extend_input_len + max_new_tokens
+         self.cur_rem_tokens -= extend_input_len
          self.rem_input_tokens -= extend_input_len
          if self.rem_chunk_tokens is not None:
              self.rem_chunk_tokens -= extend_input_len
@@ -173,29 +169,7 @@ class PrefillAdder:
          self.log_hit_tokens += prefix_len
          self.log_input_tokens += extend_input_len

-     def add_inflight_req_ignore_eos(self, req: Req):
-         truncated = req.extend_input_len > self.rem_chunk_tokens
-         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-         self.can_run_list.append(req)
-
-         self._prefill_one_req(
-             0,
-             req.extend_input_len,
-             (
-                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                 if not truncated
-                 else 0
-             ),
-         )
-
-         # Return if chunked prefill not finished
-         return req if truncated else None
-
      def add_inflight_req(self, req: Req):
-         if req.sampling_params.ignore_eos:
-             return self.add_inflight_req_ignore_eos(req)
-
          truncated = req.extend_input_len > self.rem_chunk_tokens
          req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
          req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -225,7 +199,7 @@ class PrefillAdder:
          self.rem_total_tokens += delta

      def add_one_req_ignore_eos(self, req: Req):
-         def get_req_state(r):
+         def add_req_state(r, insert_sort=False):
              new_token_ratio = (
                  1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
              )
@@ -235,56 +209,38 @@ class PrefillAdder:
              tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

              if tokens_left > 0:
-                 return (tokens_left, tokens_occupied)
-
-             return None
-
-         # Quick Check
-         can_run = False
-         if (
-             req.extend_input_len + req.sampling_params.max_new_tokens
-             <= self.rem_total_tokens
-         ):
-             can_run = True
-
-         if not can_run:
-             if self.req_states is None:
-                 self.req_states = []
-                 if self.running_batch is not None:
-                     for r in self.running_batch.reqs:
-                         state = get_req_state(r)
-                         if state is not None:
-                             self.req_states.append(state)
-                 for r in self.can_run_list:
-                     state = get_req_state(r)
-                     if state is not None:
-                         self.req_states.append(state)
-                 state = get_req_state(req)
-                 if state is not None:
-                     self.req_states.append(state)
-
-                 self.req_states.sort(key=lambda x: x[0])
-             else:
-                 state = get_req_state(req)
-                 if state is not None:
-                     for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                         if tokens_left >= state[0]:
-                             self.req_states.insert(i, state)
+                 if not insert_sort:
+                     self.req_states.append((tokens_left, tokens_occupied))
+                 else:
+                     for i in range(len(self.req_states)):
+                         if tokens_left <= self.req_states[i][0]:
                              break
-                 else:
-                     self.req_states.append(state)
-
-             tokens_freed = 0
-             for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                 decode_steps = (
-                     self.req_states[i + 1][0]
-                     if i + 1 < len(self.req_states)
-                     else tokens_left
-                 )
-                 bs = len(self.req_states) - i
-                 if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                     return False
-                 tokens_freed += tokens_occupied
+                     self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+         if self.req_states is None:
+             self.req_states = []
+             add_req_state(req)
+             if self.running_batch is not None:
+                 for r in self.running_batch.reqs:
+                     add_req_state(r)
+             for r in self.can_run_list:
+                 add_req_state(r)
+             self.req_states.sort(key=lambda x: x[0])
+         else:
+             add_req_state(req, insert_sort=True)
+
+         cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+         tokens_freed = 0
+         for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+             decode_steps = (
+                 self.req_states[i + 1][0]
+                 if i + 1 < len(self.req_states)
+                 else tokens_left
+             )
+             bs = len(self.req_states) - i
+             if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                 return False
+             tokens_freed += tokens_occupied

          if req.extend_input_len <= self.rem_chunk_tokens:
              self.can_run_list.append(req)
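Taken together, these hunks replace remove_running_tokens and the old quick check with a single feasibility loop: requests are sorted by how many tokens they still have left to decode, and the KV tokens freed by earlier-finishing requests are weighed against what the rest of the batch will keep consuming. A standalone, hedged sketch of that loop with made-up numbers (each pair is (tokens_left, tokens_occupied)):

    # Hedged sketch of the admission check: will the KV-token pool stay positive while
    # every request in the (sorted) batch decodes to completion?
    def fits(cur_rem_tokens: int, req_states: list) -> bool:
        req_states = sorted(req_states)              # ascending by tokens_left
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(req_states):
            # Roughly: before request i finishes, the remaining `bs` requests may each
            # decode up to `decode_steps` more tokens.
            decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
            bs = len(req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
                return False                         # the pool would run dry first
            tokens_freed += tokens_occupied          # request i's KV slots are then released
        return True

    # Made-up numbers: 1000 free tokens, three requests (tokens_left, tokens_occupied).
    print(fits(1000, [(30, 200), (120, 400), (500, 800)]))   # True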
sglang/srt/managers/schedule_batch.py
@@ -40,7 +40,7 @@ global_server_args_dict = {
      "attention_backend": ServerArgs.attention_backend,
      "sampling_backend": ServerArgs.sampling_backend,
      "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-     "enable_mla": ServerArgs.enable_mla,
+     "disable_mla": ServerArgs.disable_mla,
      "torchao_config": ServerArgs.torchao_config,
  }

sglang/srt/managers/tp_worker.py
@@ -445,9 +445,6 @@ class ModelTpServer:
              num_mixed_running,
          )

-         if self.running_batch is not None:
-             adder.remove_running_tokens(self.running_batch)
-
          has_inflight = self.current_inflight_req is not None
          if self.current_inflight_req is not None:
              self.current_inflight_req.init_next_round_input(
@@ -465,9 +462,6 @@ class ModelTpServer:
              )

          for req in self.waiting_queue:
-             if adder.no_remaining_tokens():
-                 break
-             req.init_next_round_input(None if prefix_computed else self.tree_cache)
              if (
                  self.lora_paths is not None
                  and len(
@@ -478,6 +472,10 @@ class ModelTpServer:
                  > self.max_loras_per_batch
              ):
                  break
+
+             if adder.no_remaining_tokens():
+                 break
+             req.init_next_round_input(None if prefix_computed else self.tree_cache)
              res = adder.add_one_req(req)
              if (
                  not res
@@ -507,6 +505,11 @@ class ModelTpServer:
          else:
              tree_cache_hit_rate = 0.0

+         num_used = self.max_total_num_tokens - (
+             self.token_to_kv_pool.available_size()
+             + self.tree_cache.evictable_size()
+         )
+
          if num_mixed_running > 0:
              logger.info(
                  f"Prefill batch"
@@ -515,6 +518,7 @@ class ModelTpServer:
                  f"#new-token: {adder.log_input_tokens}, "
                  f"#cached-token: {adder.log_hit_tokens}, "
                  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                  f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
              )
          else:
@@ -524,6 +528,7 @@ class ModelTpServer:
                  f"#new-token: {adder.log_input_tokens}, "
                  f"#cached-token: {adder.log_hit_tokens}, "
                  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                  f"#running-req: {running_bs}, "
                  f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
              )
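The prefill log line now also reports token usage: the fraction of the KV-token pool that is neither free nor reclaimable from the radix cache. The arithmetic, with illustrative numbers:

    # Illustrative numbers only.
    max_total_num_tokens = 200_000   # capacity of the token-to-KV pool
    available = 60_000               # token_to_kv_pool.available_size(): free slots
    evictable = 20_000               # tree_cache.evictable_size(): cached but reclaimable

    num_used = max_total_num_tokens - (available + evictable)
    print(f"token usage: {num_used / max_total_num_tokens:.2f}")   # token usage: 0.60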
sglang/srt/model_executor/cuda_graph_runner.py
@@ -108,6 +108,10 @@ class CudaGraphRunner:
              self.capture_bs = list(range(1, 32)) + [64, 128]
          else:
              self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+         self.capture_bs = [
+             bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+         ]
          self.compile_bs = (
              [
                  bs
@@ -118,21 +122,8 @@ class CudaGraphRunner:
              else []
          )

-         # Common inputs
-         self.max_bs = max(self.capture_bs)
-         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-         self.req_pool_indices = torch.zeros(
-             (self.max_bs,), dtype=torch.int32, device="cuda"
-         )
-         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-         self.position_ids_offsets = torch.ones(
-             (self.max_bs,), dtype=torch.int32, device="cuda"
-         )
-         self.out_cache_loc = torch.zeros(
-             (self.max_bs,), dtype=torch.int32, device="cuda"
-         )
-
          # Attention backend
+         self.max_bs = max(self.capture_bs)
          self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
          self.seq_len_fill_value = (
              self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -141,6 +132,16 @@ class CudaGraphRunner:
          if self.use_torch_compile:
              set_torch_compile_config()

+         # Common inputs
+         with torch.device("cuda"):
+             self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+             self.seq_lens = torch.full(
+                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+             )
+             self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
+             self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+
          # Capture
          try:
              self.capture()
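Two changes here: capture batch sizes are clamped to the request-to-token pool size so no CUDA graph is captured for a batch that could never be scheduled, and the shared input buffers are now allocated inside a with torch.device("cuda") block, with seq_lens pre-filled using the attention backend's fill value. A minimal sketch of that allocation pattern (pool size and fill value are assumptions, not values from the diff):

    import torch

    # Assumed values for illustration only.
    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
    req_to_token_pool_size = 64      # assumption: the request pool holds 64 slots
    seq_len_fill_value = 1           # assumption: backend's fill value for dummy seq lens

    # Drop batch sizes the server could never schedule.
    capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
    max_bs = max(capture_bs)

    # Tensors created inside this context land on the CUDA device by default.
    with torch.device("cuda"):
        input_ids = torch.zeros((max_bs,), dtype=torch.int32)
        seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
        out_cache_loc = torch.zeros((max_bs,), dtype=torch.int32)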
sglang/srt/model_executor/model_runner.py
@@ -86,12 +86,20 @@ class ModelRunner:
          self.is_multimodal_model = is_multimodal_model(
              self.model_config.hf_config.architectures
          )
+
+         if (
+             self.model_config.attention_arch == AttentionArch.MLA
+             and not self.server_args.disable_mla
+         ):
+             logger.info("MLA optimization is tunred on. Use triton backend.")
+             self.server_args.attention_backend = "triton"
+
          global_server_args_dict.update(
              {
                  "attention_backend": server_args.attention_backend,
                  "sampling_backend": server_args.sampling_backend,
                  "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                 "enable_mla": server_args.enable_mla,
+                 "disable_mla": server_args.disable_mla,
                  "torchao_config": server_args.torchao_config,
              }
          )
@@ -329,7 +337,7 @@ class ModelRunner:
          )
          if (
              self.model_config.attention_arch == AttentionArch.MLA
-             and self.server_args.enable_mla
+             and not self.server_args.disable_mla
          ):
              cell_size = (
                  (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
@@ -392,12 +400,12 @@ class ModelRunner:
          )

          self.req_to_token_pool = ReqToTokenPool(
-             max_num_reqs,
-             self.model_config.context_len + 8,
+             max_num_reqs + 1,
+             self.model_config.context_len + 4,
          )
          if (
              self.model_config.attention_arch == AttentionArch.MLA
-             and self.server_args.enable_mla
+             and not self.server_args.disable_mla
          ):
              self.token_to_kv_pool = MLATokenToKVPool(
                  self.max_total_num_tokens,
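MLA support becomes opt-out instead of opt-in: enable_mla is replaced by disable_mla, so a model whose attention architecture is MLA gets the Triton backend and the MLA KV pool unless the user turns it off. A condensed, self-contained sketch of the new condition (the AttentionArch enum and server-args object below are stand-ins, not the package's classes):

    from enum import Enum, auto
    from types import SimpleNamespace


    class AttentionArch(Enum):       # stand-in for sglang's AttentionArch enum
        MHA = auto()
        MLA = auto()


    def pick_attention_backend(attention_arch, server_args):
        # Mirrors the condition in the hunks: MLA is used unless explicitly disabled.
        if attention_arch == AttentionArch.MLA and not server_args.disable_mla:
            server_args.attention_backend = "triton"
        return server_args.attention_backend


    args = SimpleNamespace(attention_backend="flashinfer", disable_mla=False)
    print(pick_attention_backend(AttentionArch.MLA, args))   # -> triton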
sglang/srt/models/baichuan.py
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

sglang/srt/models/chatglm.py
@@ -24,12 +24,6 @@ from torch import nn
  from torch.nn import LayerNorm
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -40,7 +34,13 @@ from vllm.transformers_utils.configs import ChatGLMConfig

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

sglang/srt/models/commandr.py
@@ -50,21 +50,21 @@ from vllm.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.utils import set_weight_attrs

  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.utils import set_weight_attrs


  @torch.compile
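The remaining model files listed above follow the same pattern as these three: the parallel linear layers and QuantizationConfig (and, where used, set_weight_attrs) are now imported from sglang's own modules, backed by the new sglang/srt/layers/linear.py and sglang/srt/layers/quantization/, instead of from vllm. A typical model file in this release therefore carries an import block along these lines:

    # Import locations used across the model files in this release
    # (previously these came from vllm.model_executor.*).
    from sglang.srt.layers.linear import (
        MergedColumnParallelLinear,
        QKVParallelLinear,
        RowParallelLinear,
    )
    from sglang.srt.layers.quantization.base_config import QuantizationConfig
    from sglang.srt.utils import set_weight_attrs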