sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py CHANGED
@@ -17,6 +17,11 @@ limitations under the License.
 
 import json
 import logging
+import threading
+import time
+from queue import Queue
+
+import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -75,6 +80,7 @@ class TpModelWorker:
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
         )
+        self.device = self.model_runner.device
 
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
@@ -100,6 +106,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        if server_args.enable_overlap_schedule:
+            self.init_overlap_status()
+
     def get_token_and_memory_info(self):
         return (
             self.max_total_num_tokens,
@@ -109,6 +118,81 @@ class TpModelWorker:
             self.random_seed,
         )
 
+    def init_overlap_status(self):
+        self.future_logits_output_dict = dict()
+        self.future_logits_output_ct = 0
+        self.future_token_ids_ct = 0
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_output = dict()
+
+        self.future_event_map = dict()
+        self.forward_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            tic1 = time.time()
+            model_worker_batch, future_logits_output, future_next_token_ids = (
+                self.forward_queue.get()
+            )
+
+            # Resolve future tokens in the input
+            tic2 = time.time()
+            resolved_input_ids = model_worker_batch.input_ids
+            future_mask = resolved_input_ids < 0
+            resolved_input_ids[future_mask] = self.future_token_ids_map[
+                -resolved_input_ids[future_mask]
+            ]
+
+            # Run forward
+            logits_output, next_token_ids = self.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Set future values
+            if model_worker_batch.return_logprob:
+                self.future_logits_output_dict[future_logits_output] = logits_output
+
+            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
+            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
+                torch.int32
+            )
+            # logger.info("Set event")
+            self.future_token_ids_output[model_worker_batch.bid] = (
+                next_token_ids.tolist()
+            )
+            self.future_event_map[model_worker_batch.bid].set()
+
+            if False:
+                tic3 = time.time()
+                self.acc_time_with_waiting += tic3 - tic1
+                self.acc_time_without_waiting += tic3 - tic2
+                if self.forward_queue.qsize() == 0:
+                    logger.info(
+                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
+                    )
+
+    def resolve_future_token_ids(self, bid: int):
+        self.future_event_map[bid].wait()
+        ret = self.future_token_ids_output[bid]
+        del self.future_event_map[bid]
+        return ret
+
+    def resolve_future_logits_output(self, future_obj):
+        return self.future_logits_output_dict.pop(future_obj)
+
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
@@ -118,9 +202,35 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings.tolist()
+        embeddings = logits_output.embeddings
         return embeddings
 
+    def forward_batch_generation_non_blocking(
+        self, model_worker_batch: ModelWorkerBatch
+    ):
+        # Allocate output future objects
+        future_logits_output = self.future_logits_output_ct
+        self.future_logits_output_ct += 1
+
+        bs = len(model_worker_batch.seq_lens)
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        ret = future_logits_output, future_next_token_ids
+
+        self.future_event_map[model_worker_batch.bid] = threading.Event()
+        self.forward_queue.put(
+            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
+        )
+        return ret
+
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
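The overlap schedule added here lets the scheduler enqueue the next batch before the previous batch's sampled tokens exist: it hands out negative placeholder IDs whose absolute value indexes `future_token_ids_map`, and the background forward thread later writes the real token IDs into that map and resolves any placeholders it finds in its own inputs. Below is a minimal, self-contained sketch of just that indexing trick (toy tensors only, not sglang's actual classes or dtypes):

import torch

# Toy version of the future-token-id trick used by the overlap schedule.
# Placeholder ids are negative; their absolute value indexes future_token_ids_map.
future_token_ids_map = torch.zeros(16, dtype=torch.int64)

# Scheduler side: reserve placeholders -1 and -2 for a batch of size 2
# before the real next tokens are known.
bs = 2
future_next_token_ids = -torch.arange(1, 1 + bs, dtype=torch.int64)

# Worker side: after the forward pass produces real token ids, publish them.
next_token_ids = torch.tensor([42, 7], dtype=torch.int64)
future_token_ids_map[-future_next_token_ids] = next_token_ids

# Next step: the scheduler already built input_ids using the placeholders;
# resolve them by looking up the map wherever an id is negative.
input_ids = future_next_token_ids.clone()
mask = input_ids < 0
input_ids[mask] = future_token_ids_map[-input_ids[mask]]
assert input_ids.tolist() == [42, 7]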
sglang/srt/mem_cache/chunk_cache.py CHANGED
@@ -40,10 +40,12 @@ class ChunkCache(BasePrefixCache):
 
     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+        else:
+            token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
@@ -53,10 +55,12 @@ class ChunkCache(BasePrefixCache):
 
     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-            token_ids = req.fill_ids
+            token_id_len = len(req.fill_ids)
+        else:
+            token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :token_id_len
         ]
 
         if req.rid not in self.entries:
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 from typing import List, Tuple, Union
 
-import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
             self.store_dtype = dtype
 
         self.free_slots = None
+        self.is_not_in_free_group = True
+        self.free_group = []
         self.clear()
 
     def available_size(self):
@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        return torch.tensor(select_index, dtype=torch.int32, device=self.device)
+        return select_index.to(self.device, non_blocking=True)
 
     def free(self, free_index: torch.Tensor):
-        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
+        if self.is_not_in_free_group:
+            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+        else:
+            self.free_group.append(free_index)
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.concat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = np.arange(1, self.size + 1)
+        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.is_in_free_group = False
+        self.free_group = []
 
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -231,3 +246,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
             self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
         else:
             self.kv_buffer[layer_id][loc] = cache_k
+
+
+class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+        device: str,
+        heavy_channel_num: int,
+    ):
+        super().__init__(size, dtype, device)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+            for _ in range(layer_num)
+        ]
+
+        # [size, head_num, heavy_channel_num] for each layer
+        self.label_buffer = [
+            torch.empty(
+                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_label_buffer(self, layer_id: int):
+        return self.label_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        cache_label: torch.Tensor,
+    ):
+        # NOTE(Andy): ignore the dtype check
+        self.k_buffer[layer_id][loc] = cache_k
+        self.v_buffer[layer_id][loc] = cache_v
+        self.label_buffer[layer_id][loc] = cache_label
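The `free_group_begin` / `free_group_end` pair added to `BaseTokenToKVPool` above batches KV-slot frees: while a group is open, `free` only collects the index tensors, and closing the group releases them with a single concatenation instead of one concat per request. A rough standalone sketch of the pattern (a toy pool, not the real class or its call sites):

import torch

class TinyPool:
    """Toy pool reproducing the grouped-free pattern shown in the diff."""

    def __init__(self, size: int):
        # Slot 0 is reserved for padding, matching the comment in clear().
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int32)
        self.is_not_in_free_group = True
        self.free_group = []

    def alloc(self, need_size: int):
        out = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return out

    def free(self, free_index: torch.Tensor):
        if self.is_not_in_free_group:
            self.free_slots = torch.concat((self.free_slots, free_index))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

pool = TinyPool(8)
a, b = pool.alloc(3), pool.alloc(2)
pool.free_group_begin()
pool.free(a)           # only collected
pool.free(b)           # only collected
pool.free_group_end()  # one concat returns all 5 slots
assert pool.free_slots.numel() == 8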
sglang/srt/mem_cache/radix_cache.py CHANGED
@@ -99,17 +99,25 @@ class RadixCache(BasePrefixCache):
 
     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
+        if self.disable:
+            if token_ids is None:
+                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+            else:
+                token_ids_len = len(token_ids)
+
+            kv_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, :token_ids_len
+            ]
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]
 
-        if self.disable:
-            self.token_to_kv_pool.free(kv_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-            return
-
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
@@ -229,7 +237,7 @@
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:][0]: child}
+        new_node.children = {key[split_len]: child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -237,7 +245,7 @@
         child.parent = new_node
         child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len][0]] = new_node
+        new_node.parent.children[key[0]] = new_node
         return new_node
 
     def _insert_helper(self, node: TreeNode, key: List, value):
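In `_split_node`, `key[split_len:][0]` and `key[split_len]` select the same token; the new form just drops the intermediate slice. A tiny worked example of the split bookkeeping on a plain Python list (standalone, not the real `TreeNode`):

# Toy illustration of splitting a radix-tree node keyed by token ids.
# A node currently stores the key [10, 11, 12, 13]; we split after 2 tokens.
key = [10, 11, 12, 13]
split_len = 2

new_node_key = key[:split_len]   # [10, 11] stays on the new parent node
child_key = key[split_len:]      # [12, 13] remains on the old child node

# Children dictionaries are keyed by the first token of each edge:
assert key[split_len:][0] == key[split_len] == 12   # edge from new node to child
assert key[:split_len][0] == key[0] == 10           # edge from parent to new node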
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -245,10 +245,10 @@ class CudaGraphRunner:
         self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs] = forward_batch.input_ids
-        self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
-        self.seq_lens[:raw_bs] = forward_batch.seq_lens
-        self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
+        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
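CUDA graphs replay fixed memory addresses, so per-step inputs must be written into the same preallocated buffers the graph was captured with; both the old slice assignment and the new `copy_` do that in place, with `copy_` making the in-place intent explicit and handling dtype conversion uniformly. A generic sketch of the capture-and-replay pattern this relies on (assumes a CUDA device; illustrative only, not the `CudaGraphRunner` API):

import torch

# Generic capture-and-replay pattern: capture once against static buffers,
# then update those buffers in place before every replay.
if torch.cuda.is_available():
    static_x = torch.zeros(4, device="cuda")
    static_y = torch.zeros(4, device="cuda")

    # Warm up on a side stream before capture, as the PyTorch docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_y.copy_(static_x * 2)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y.copy_(static_x * 2)

    # Correct: write the new inputs into the captured storage, then replay.
    static_x.copy_(torch.arange(4.0, device="cuda"))
    g.replay()
    print(static_y)  # tensor([0., 2., 4., 6.], device='cuda:0')
    # Rebinding instead (static_x = torch.arange(...)) would point the name at
    # new memory while the replayed graph keeps reading the old buffer.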
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -118,7 +118,7 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-        device = "cuda"
+        device = model_runner.device
 
         ret = cls(
             forward_mode=batch.forward_mode,
@@ -134,27 +134,23 @@
         )
 
         # Init position information
-        if ret.forward_mode.is_decode():
-            ret.positions = (ret.seq_lens - 1).to(torch.int64)
-        else:
-            ret.positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(prefix_len, prefix_len + extend_len)
-                        for prefix_len, extend_len in zip(
-                            batch.extend_prefix_lens, batch.extend_seq_lens
-                        )
-                    ],
-                    axis=0,
-                ),
-                device=device,
-            ).to(torch.int64)
-
+        if not ret.forward_mode.is_decode():
+            ret.positions = torch.concat(
+                [
+                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
+                    for prefix_len, extend_len in zip(
+                        batch.extend_prefix_lens, batch.extend_seq_lens
+                    )
+                ],
+                axis=0,
+            )
             ret.image_inputs = batch.image_inputs
-            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_seq_lens = torch.tensor(
+                batch.extend_seq_lens, dtype=torch.int32
+            ).to(device, non_blocking=True)
             ret.extend_prefix_lens = torch.tensor(
-                batch.extend_prefix_lens, device=device
-            )
+                batch.extend_prefix_lens, dtype=torch.int32
+            ).to(device, non_blocking=True)
             ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
             ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
@@ -164,7 +160,6 @@
         ret.req_to_token_pool = model_runner.req_to_token_pool
         ret.token_to_kv_pool = model_runner.token_to_kv_pool
         ret.attn_backend = model_runner.attn_backend
-        model_runner.attn_backend.init_forward_metadata(ret)
 
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
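For extend (prefill) batches, positions are now built directly on the device by concatenating one `arange` per request, each starting at that request's cached prefix length, while the decode-position computation moves into `forward_decode` (see the model_runner hunks below). A small worked example of the extend-position computation, run on CPU for clarity:

import torch

# Worked example: each request contributes positions
# [prefix_len, prefix_len + extend_len) to the flattened extend batch.
extend_prefix_lens = [3, 0]
extend_seq_lens = [2, 4]

positions = torch.concat(
    [
        torch.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ]
)
print(positions.tolist())  # [3, 4, 0, 1, 2, 3]

# Matching start offset of each request inside the flattened batch:
extend_seq_lens_t = torch.tensor(extend_seq_lens, dtype=torch.int32)
extend_start_loc = torch.zeros_like(extend_seq_lens_t)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens_t[:-1], dim=0)
print(extend_start_loc.tolist())  # [0, 2]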
sglang/srt/model_executor/model_runner.py CHANGED
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.constrained import disable_cache
+from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
+    DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
@@ -99,6 +102,20 @@ class ModelRunner:
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            self.server_args.attention_backend = "triton"
+            self.server_args.disable_cuda_graph = True
+            if self.server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(
+                self.server_args.ds_heavy_channel_type
+            )
+
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -119,6 +136,8 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
+                "disable_nan_detection": server_args.disable_nan_detection,
             }
         )
 
@@ -138,6 +157,7 @@ class ModelRunner:
            self.init_attention_backend()
            self.init_cuda_graphs()
         else:
+            self.cuda_graph_runner = None
             self.init_attention_backend()
 
     def init_torch_distributed(self):
@@ -146,6 +166,11 @@ class ModelRunner:
         if self.device == "cuda":
             torch.cuda.set_device(self.gpu_id)
             backend = "nccl"
+        # ToDO(liangan1):Just use gloo to bypass the initilization fail
+        # Need to use xccl for xpu backend in the future
+        elif self.device == "xpu":
+            torch.xpu.set_device(self.gpu_id)
+            backend = "gloo"
 
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
@@ -432,6 +457,16 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
             )
+        elif self.server_args.enable_double_sparsity:
+            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+                device=self.device,
+                heavy_channel_num=self.server_args.ds_heavy_channel_num,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -468,12 +503,33 @@ class ModelRunner:
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
-            self.attn_backend = TritonAttnBackend(self)
+            if self.server_args.enable_double_sparsity:
+                self.attn_backend = DoubleSparseAttnBackend(self)
+            else:
+                self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
             )
 
+    def init_double_sparsity_channel_config(self, selected_channel):
+
+        selected_channel = "." + selected_channel + "_proj"
+        self.sorted_channels = []
+        # load channel config
+        with open(self.server_args.ds_channel_config_path, "r") as f:
+            channel_config = json.load(f)
+
+        for i in range(self.model_config.num_hidden_layers):
+            key = "model.layers." + str(i) + ".self_attn" + selected_channel
+            self.sorted_channels.append(
+                torch.tensor(channel_config[key])[
+                    :, : self.server_args.ds_heavy_channel_num
+                ]
+                .contiguous()
+                .cuda()
+            )
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
@@ -496,11 +552,14 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(forward_batch)
 
+        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+        self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
     def forward_extend(self, forward_batch: ForwardBatch):
+        self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
             return self.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
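`init_double_sparsity_channel_config` reads the JSON at `--ds-channel-config-path` and keeps, for every layer, only the first `ds_heavy_channel_num` columns of the entry keyed by `model.layers.{i}.self_attn.<type>_proj`. The exact layout of that file is not shown in this diff; the sketch below assumes a plausible format (a 2D list per key, presumably one row per attention head, columns ordered by importance) with invented numbers:

import torch

# Hypothetical ds channel config entry (shape and values invented for
# illustration): each key maps to a 2D list of channel indices.
channel_config = {
    "model.layers.0.self_attn.k_proj": [
        [7, 3, 120, 64, 1, 9],   # head 0
        [0, 5, 33, 90, 12, 2],   # head 1
    ],
}

ds_heavy_channel_num = 4
sorted_channels = []
for i in range(1):  # a single layer in this toy example
    key = "model.layers." + str(i) + ".self_attn" + ".k_proj"
    sorted_channels.append(
        torch.tensor(channel_config[key])[:, :ds_heavy_channel_num].contiguous()
    )
print(sorted_channels[0].shape)  # torch.Size([2, 4])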
sglang/srt/models/baichuan.py CHANGED
@@ -24,7 +24,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -330,7 +329,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         position_embedding: str,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -404,7 +403,7 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
sglang/srt/models/chatglm.py CHANGED
@@ -22,7 +22,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,7 +51,7 @@ class GLMAttention(nn.Module):
         self,
         config,
         layer_id: int = 0,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -188,7 +187,7 @@ class GLMBlock(nn.Module):
         self,
         config,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -260,7 +259,7 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -308,7 +307,7 @@ class ChatGLMModel(nn.Module):
     def __init__(
         self,
         config,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -359,7 +358,7 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoraConfig] = None,
     ):
sglang/srt/models/commandr.py CHANGED
@@ -45,7 +45,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -320,7 +319,7 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/dbrx.py CHANGED
@@ -20,7 +20,6 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -368,7 +367,7 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ):
         super().__init__()
         self.config = config