sglang 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +40 -18
  5. sglang/global_config.py +21 -16
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +33 -59
  15. sglang/srt/layers/token_attention.py +4 -8
  16. sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
  17. sglang/srt/managers/controller/infer_batch.py +244 -36
  18. sglang/srt/managers/controller/manager_single.py +1 -1
  19. sglang/srt/managers/controller/model_runner.py +69 -284
  20. sglang/srt/managers/controller/tp_worker.py +39 -20
  21. sglang/srt/managers/detokenizer_manager.py +4 -2
  22. sglang/srt/managers/io_struct.py +1 -1
  23. sglang/srt/managers/tokenizer_manager.py +14 -13
  24. sglang/srt/memory_pool.py +33 -6
  25. sglang/srt/model_config.py +6 -0
  26. sglang/srt/models/gemma2.py +436 -0
  27. sglang/srt/models/llama2.py +3 -3
  28. sglang/srt/models/llama_classification.py +10 -7
  29. sglang/srt/models/minicpm.py +373 -0
  30. sglang/srt/models/qwen2_moe.py +454 -0
  31. sglang/srt/openai_api_adapter.py +2 -2
  32. sglang/srt/openai_protocol.py +1 -1
  33. sglang/srt/server.py +18 -8
  34. sglang/srt/server_args.py +24 -20
  35. sglang/srt/utils.py +68 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/METADATA +19 -13
  37. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/RECORD +40 -36
  38. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
  39. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
  40. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
@@ -1,60 +1,42 @@
 """Radix attention."""
 
-import numpy as np
 import torch
+from flashinfer.cascade import merge_state
 from torch import nn
 
 from sglang.global_config import global_config
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
-        layer_id: int, logit_cap: int = -1
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.scaling = scaling
         self.layer_id = layer_id
 
-        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
-
-        from sglang.srt.managers.controller.model_runner import global_server_args_dict
-
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap
-
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            self.logit_cap,
-        )
-        self.store_kv_cache(k, v, input_metadata)
 
-        return o
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -68,14 +50,15 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.prefix_lens,
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
            input_metadata.extend_seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.max_extend_len,
-            self.logit_cap,
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
@@ -91,39 +74,41 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
             v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
 
-        if input_metadata.no_prefix:
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.kv_data[self.layer_id],
                 causal=False,
+                sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
             )
 
-            from flashinfer.cascade import merge_state
             o, _ = merge_state(o1, s1, o2, s2)
 
+        self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
 
@@ -135,6 +120,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
 
@@ -144,25 +130,13 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.PREFILL:
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
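
The new `extend_forward_flashinfer` path above runs attention twice: a ragged pass over the freshly extended tokens (`o1`, `s1`) and a paged pass over the cached prefix (`o2`, `s2`), then fuses the two with `flashinfer.cascade.merge_state`. Below is a minimal PyTorch sketch of that log-sum-exp merge, assuming each partial result carries an attention output plus the natural-log LSE of its logits; it is reference math, not the flashinfer kernel itself.

```python
import torch


def merge_attention_states_reference(o1, s1, o2, s2):
    """Unoptimized reference merge of two partial attention results.

    o1, o2: [num_tokens, num_heads, head_dim] partial attention outputs.
    s1, s2: [num_tokens, num_heads] log-sum-exp of the attention logits
            that produced each partial output (natural log assumed).
    Returns the output equivalent to attending over the union of both
    key/value sets, plus the combined log-sum-exp.
    """
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # un-normalized weight of part 1
    w2 = torch.exp(s2 - s_max)  # un-normalized weight of part 2
    denom = w1 + w2
    o = o1 * (w1 / denom).unsqueeze(-1) + o2 * (w2 / denom).unsqueeze(-1)
    s = s_max + torch.log(denom)  # combined log-sum-exp
    return o, s
```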
sglang/srt/layers/token_attention.py
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
             + cur_batch_req_idx * stride_req_to_token_b
             + (start_n + offs_n),
             mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=other_kv_index,
+            other=0,
         )
 
         qk = tl.load(
@@ -176,6 +175,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    sm_scale,
     logit_cap,
 ):
     BLOCK = 32
@@ -183,7 +183,6 @@
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
     assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
         )
         return
 
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -315,8 +311,8 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):
@@ -334,6 +330,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        sm_scale,
        logit_cap,
     )
     _token_softmax_reducev_fwd(
@@ -344,5 +341,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
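
In the changes above, the softmax scale is no longer derived inside `_token_att_m_fwd` (the removed `sm_scale = 1.0 / (Lk**0.5)`); the caller now passes `sm_scale` explicitly, alongside `logit_cap`, and `RadixAttention` forwards `self.scaling`. The following reference sketch shows how such a scale and soft cap combine on attention logits, assuming the common `cap * tanh(x / cap)` form of soft-capping; it does not reproduce the Triton kernel arithmetic line for line.

```python
import math

import torch


def capped_attention_scores_reference(q, k, sm_scale, logit_cap=0.0):
    """Scale raw dot-product logits and optionally apply a soft cap.

    q: [num_q, head_dim], k: [num_k, head_dim].
    sm_scale is supplied by the caller (e.g. RadixAttention's self.scaling)
    rather than being recomputed as 1/sqrt(head_dim) inside the kernel.
    """
    scores = (q @ k.T) * sm_scale
    if logit_cap > 0:
        # Soft-capping keeps logits in (-logit_cap, logit_cap); the tanh form
        # is assumed here as the usual convention for logit soft caps.
        scores = logit_cap * torch.tanh(scores / logit_cap)
    return scores


# Example: with head_dim=128, sm_scale is typically 1/sqrt(128) ~= 0.0884.
q = torch.randn(4, 128)
k = torch.randn(16, 128)
probs = torch.softmax(
    capped_attention_scores_reference(q, k, 1 / math.sqrt(128), logit_cap=30.0), dim=-1
)
```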
sglang/srt/managers/controller/cuda_graph_runner.py (new file)
@@ -0,0 +1,172 @@
+"""Run the model with cuda graph."""
+
+import bisect
+
+import torch
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.infer_batch import (
+    Batch, ForwardMode, InputMetadata, init_flashinfer_args
+)
+
+
+class CudaGraphRunner:
+    def __init__(self, model_runner, max_batch_size_to_capture):
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.input_buffers = {}
+        self.output_buffers = {}
+        self.flashinfer_handlers = {}
+        self.graph_memory_pool = None
+
+        # Common inputs
+        self.max_bs = max_batch_size_to_capture
+        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+
+        # FlashInfer inputs
+        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+        self.flashinfer_kv_indptr = torch.zeros(
+            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_indices = torch.zeros(
+            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_last_page_len = torch.ones(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+    def can_run(self, batch_size):
+        return batch_size < self.max_bs
+
+    def capture(self, batch_size_list):
+        self.batch_size_list = batch_size_list
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            for bs in batch_size_list:
+                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+                self.graphs[bs] = graph
+                self.input_buffers[bs] = input_buffers
+                self.output_buffers[bs] = output_buffers
+                self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs):
+        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        position_ids_offsets = self.position_ids_offsets[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+
+        # FlashInfer inputs
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
+            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffer, "NHD",
+            use_cuda_graph=True,
+            use_tensor_cores=use_tensor_cores,
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indices_buffer=self.flashinfer_kv_indices,
+            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+        )
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            flashinfer_decode_wrapper,
+        )
+
+        # Run and capture
+        def run_once():
+            input_metadata = InputMetadata.create(
+                self.model_runner,
+                forward_mode=ForwardMode.DECODE,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                prefix_lens=None,
+                position_ids_offsets=position_ids_offsets,
+                out_cache_loc=out_cache_loc,
+                return_logprob=False,
+                top_logprobs_nums=0,
+                skip_flashinfer_init=True,
+            )
+            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
+            return self.model_runner.model.forward(
+                input_ids, input_metadata.positions, input_metadata
+            )
+
+        for _ in range(2):
+            run_once()
+
+        torch.cuda.synchronize()
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+        torch.cuda.synchronize()
+        self.graph_memory_pool = graph.pool()
+        return graph, None, out, flashinfer_decode_wrapper
+
+    def replay(self, batch: Batch):
+        assert batch.out_cache_loc is not None
+        assert not batch.return_logprob
+        raw_bs = len(batch.reqs)
+
+        # Pad
+        index = bisect.bisect_left(self.batch_size_list, raw_bs)
+        bs = self.batch_size_list[index]
+        if bs != raw_bs:
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs] = batch.input_ids
+        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
+        self.seq_lens[:raw_bs] = batch.seq_lens
+        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
+        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+
+        # FlashInfer inputs
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            None,
+            self.flashinfer_handlers[bs],
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+        output = self.output_buffers[bs]
+
+        # Unpad
+        if bs == raw_bs:
+            return output
+        else:
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
+                next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
+            )
+            return output
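
The new `CudaGraphRunner` captures one CUDA graph per decode batch size into static input/output buffers, then replays the matching graph at decode time: the live batch is padded up to the nearest captured size, its tensors are copied into the static buffers, and the resulting `LogitProcessorOutput` is sliced back down. A minimal sketch of that driving flow, with an illustrative batch-size list and a hypothetical `init_cuda_graphs` helper (not part of this diff):

```python
from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner


def init_cuda_graphs(model_runner):
    # Illustrative list; the real server derives it from its own limits.
    batch_size_list = [1, 2, 4, 8, 16, 32]
    runner = CudaGraphRunner(
        model_runner, max_batch_size_to_capture=max(batch_size_list) + 1
    )
    runner.capture(batch_size_list)  # records one graph per decode batch size
    return runner


def forward_decode_with_graph(runner, batch):
    # replay() pads `batch` to the nearest captured size, copies its tensors
    # into the static input buffers, replays the graph, and un-pads the output.
    assert runner.can_run(len(batch.reqs))
    return runner.replay(batch)
```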