sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py

@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict
 
 
 class RadixAttention(nn.Module):
@@ -85,32 +85,47 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-            causal=True,
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)
 
-        if input_metadata.extend_no_prefix:
-            o = o1
-        else:
-            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-                causal=False,
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                causal=True,
                 sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
             )
+        else:
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=True,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
+            )
 
-            o, _ = merge_state(o1, s1, o2, s2)
+            if input_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = (
+                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                        causal=False,
+                        sm_scale=self.scaling,
+                        logits_soft_cap=self.logit_cap,
+                    )
+                )
 
-        self.store_kv_cache(k, v, input_metadata)
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            self.store_kv_cache(k, v, input_metadata)
 
-        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-            torch.cuda.synchronize()
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
@@ -119,7 +134,7 @@ class RadixAttention(nn.Module):
 
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
@@ -136,7 +151,7 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        k_cache[input_metadata.out_cache_loc] = cache_k
+        v_cache[input_metadata.out_cache_loc] = cache_v
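
In the new extend path, the ragged wrapper attends over the freshly extended tokens while the paged wrapper attends over the cached prefix, and the two partial results are combined with flashinfer's merge_state using each part's log-sum-exp (LSE). A pure-PyTorch sketch of that merge for reference only (assuming outputs of shape [tokens, heads, head_dim] and LSEs of shape [tokens, heads] in natural log; the real merge_state is a fused CUDA kernel whose exact layout may differ):

    import torch

    def merge_attention_states(o1, s1, o2, s2):
        # o1, o2: partial attention outputs, shape [tokens, heads, head_dim]
        # s1, s2: log-sum-exp of each part's attention logits, shape [tokens, heads]
        s_max = torch.maximum(s1, s2)
        w1 = torch.exp(s1 - s_max)
        w2 = torch.exp(s2 - s_max)
        # Weight each partial output by its share of the combined softmax mass.
        o = (w1.unsqueeze(-1) * o1 + w2.unsqueeze(-1) * o2) / (w1 + w2).unsqueeze(-1)
        s = s_max + torch.log(w1 + w2)
        return o, s

This is the standard log-sum-exp combination used by two-pass and chunked attention, which is why splitting the prefix and extend parts does not change the result.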
sglang/srt/layers/token_attention.py

@@ -5,8 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.controller.model_runner import global_server_args_dict
-from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.server import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)
 
 
-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
        num_warps = 2
 
-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
 
 
 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(
 
     num_warps = 1
 
-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
        num_stages=3,
    )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
 
 
 def token_attention_fwd(
sglang/srt/managers/controller/cuda_graph_runner.py

@@ -1,11 +1,14 @@
 """Run the model with cuda graph."""
 
 import bisect
+from contextlib import contextmanager
 
 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
+from vllm.model_executor.custom_op import CustomOp
 
-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -13,10 +16,44 @@ from sglang.srt.managers.controller.infer_batch import (
     InputMetadata,
     init_flashinfer_args,
 )
+from sglang.srt.utils import monkey_patch_vllm_all_gather
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool = False):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub._forward_method = sub.forward_cuda
+            else:
+                sub._forward_method = sub.forward_native
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse)
+
+
+@contextmanager
+def patch_model(
+    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+):
+    backup_ca_comm = None
+
+    try:
+        if use_compile:
+            _to_torch(model)
+            monkey_patch_vllm_all_gather()
+            backup_ca_comm = tp_group.ca_comm
+            tp_group.ca_comm = None
+            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+        else:
+            yield model.forward
+    finally:
+        if use_compile:
+            _to_torch(model, reverse=True)
+            monkey_patch_vllm_all_gather(reverse=True)
+            tp_group.ca_comm = backup_ca_comm
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture):
+    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
@@ -54,6 +91,8 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs
 
@@ -62,21 +101,23 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                (
-                    graph,
-                    input_buffers,
-                    output_buffers,
-                    flashinfer_handler,
-                ) = self.capture_one_batch_size(bs)
-                self.graphs[bs] = graph
-                self.input_buffers[bs] = input_buffers
-                self.output_buffers[bs] = output_buffers
-                self.flashinfer_handlers[bs] = flashinfer_handler
-
-    def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    (
+                        graph,
+                        input_buffers,
+                        output_buffers,
+                        flashinfer_handler,
+                    ) = self.capture_one_batch_size(bs, forward)
+                    self.graphs[bs] = graph
+                    self.input_buffers[bs] = input_buffers
+                    self.output_buffers[bs] = output_buffers
+                    self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
@@ -129,9 +170,8 @@ class CudaGraphRunner:
                 skip_flashinfer_init=True,
             )
             input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
-            return self.model_runner.model.forward(
-                input_ids, input_metadata.positions, input_metadata
-            )
+
+            return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
             run_once()
@@ -152,8 +192,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
        bs = self.batch_size_list[index]
        if bs != raw_bs:
-            self.seq_lens.zero_()
-            self.position_ids_offsets.fill_(1)
+            self.seq_lens.fill_(1)
+            self.position_ids_offsets.zero_()
            self.out_cache_loc.zero_()
 
        # Common inputs
@@ -183,14 +223,18 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=output.next_token_logprobs[:raw_bs]
-                if output.next_token_logprobs is not None
-                else None,
+                next_token_logprobs=(
+                    output.next_token_logprobs[:raw_bs]
+                    if output.next_token_logprobs is not None
+                    else None
+                ),
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
-                if output.decode_top_logprobs is not None
-                else None,
+                decode_top_logprobs=(
+                    output.decode_top_logprobs[:raw_bs]
+                    if output.decode_top_logprobs is not None
+                    else None
+                ),
             )
         return output
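
The runner above follows the standard CUDA graph recipe: warm up, record the model call into a graph that reads from and writes to fixed buffers, then replay the graph after copying fresh inputs into those buffers (optionally wrapping the forward in torch.compile via patch_model). A minimal, generic PyTorch sketch of that capture/replay pattern, not sglang's actual runner (the function and buffer names here are illustrative):

    import torch

    def build_graph_replay(fn, example_input):
        static_in = example_input.clone()

        # Warm up on a side stream so one-time allocations stay out of the graph.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(2):
                fn(static_in)
        torch.cuda.current_stream().wait_stream(s)

        # Record one call; outputs land in tensors owned by the graph.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = fn(static_in)

        def replay(new_input):
            static_in.copy_(new_input)  # refresh the captured input buffer in place
            graph.replay()              # relaunch the recorded kernels
            return static_out           # results appear in the captured output buffer

        return replay

Because replay reuses the recorded kernels and buffers, the padding logic in replay() above (rounding raw_bs up to a captured batch size and zero/one-filling the spare slots) is what keeps the fixed-shape graph valid for smaller batches.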
sglang/srt/managers/controller/infer_batch.py

@@ -7,7 +7,9 @@ from typing import List, Union
 
 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
+from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
@@ -15,9 +17,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
-# Store some global server args
-global_server_args_dict = {}
-
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -84,6 +83,15 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
+        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
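
The surr_offset/read_offset fields diagrammed above drive incremental detokenization, which the hunks below rename and simplify. The idea: keep a small "surrounding" window of already-emitted tokens, decode it together with the newer tokens, and only emit the part of the string that extends past the window once it no longer ends in an incomplete multi-byte character. A standalone sketch of the same idea with a Hugging Face tokenizer (illustrative only; the helper name and the gpt2 tokenizer are not part of sglang):

    from transformers import AutoTokenizer

    def incremental_decode(tokenizer, all_ids, surr_offset, read_offset):
        """Return (ok, new_text) for tokens appended after read_offset."""
        surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
        new_text = tokenizer.decode(all_ids[surr_offset:])
        # Emit only once the text extends past the surrounding window and does
        # not end in an incomplete multi-byte sequence (replacement character).
        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text):]
        return False, ""

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ids = tokenizer.encode("Hello, incremental decoding")
    ok, piece = incremental_decode(tokenizer, ids, 0, len(ids) - 1)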
@@ -134,7 +142,7 @@ class Req:
         return self.finished_reason is not None
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
         if first_iter:
@@ -144,13 +152,11 @@
             )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]
 
         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -164,13 +170,7 @@
         )
 
         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]
 
         return False, ""
 
@@ -272,6 +272,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None
 
     # For processing logprobs
     return_logprob: bool = False
@@ -282,10 +283,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -327,6 +324,13 @@ class Batch:
        seq_lens = []
 
        req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        for i in range(bs):
            flatten_input_ids.extend(input_ids[i])
@@ -398,10 +402,10 @@ class Batch:
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
        self.frequency_penalties = torch.tensor(
            [r.sampling_params.frequency_penalty for r in reqs],
            dtype=torch.float,
@@ -428,7 +432,8 @@ class Batch:
 
    def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]
-        # TODO(lsyin): improve the priority of retraction
+
+        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
@@ -440,7 +445,17 @@ class Batch:
        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
+        while (
+            self.token_to_kv_pool.available_size()
+            < len(sorted_indices) * global_config.retract_decode_steps
+        ):
+            if len(sorted_indices) == 1:
+                # Corner case: only one request left
+                assert (
+                    self.token_to_kv_pool.available_size() > 0
+                ), "No space left for only one request"
+                break
+
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)
@@ -465,7 +480,16 @@ class Batch:
 
        self.filter_batch(sorted_indices)
 
-        return retracted_reqs
+        # Reqs in batch are filtered
+        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
+        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
+
+        new_estimate_ratio = (
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
+        new_estimate_ratio = min(1.0, new_estimate_ratio)
+
+        return retracted_reqs, new_estimate_ratio
 
    def check_for_jump_forward(self, model_runner):
        jump_forward_reqs = []
@@ -499,7 +523,7 @@ class Batch:
                    cur_output_ids = req.output_ids
 
                    req.output_ids.extend(suffix_ids)
-                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue
@@ -518,6 +542,9 @@ class Batch:
                        req.output_ids = cur_output_ids
                        continue
 
+                    # The decode status has diverged from detokenizer_manager
+                    req.vid += 1
+
                    # insert the old request into tree_cache
                    if req_pool_indices_cpu is None:
                        req_pool_indices_cpu = self.req_pool_indices.tolist()
@@ -659,20 +686,20 @@ class Batch:
 
        # TODO(lmzheng): apply penalty
        probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
-        except RuntimeError as e:
-            warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(
-                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
-            )
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
+
+        max_top_k_round, batch_size = 32, probs.shape[0]
+        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, self.top_ks, self.top_ps
        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+
+        if torch.any(~success):
+            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
 
        if has_regex:
            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -682,18 +709,7 @@ class Batch:
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )
 
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids
 
 
 @dataclass
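
The sample() change above swaps the hand-rolled _top_p_top_k helper plus torch.multinomial for flashinfer's fused top_k_top_p_sampling_from_probs, with a greedy fallback when the kernel reports failure. For reference, the same joint top-k/top-p sampling can be written in plain PyTorch roughly as below (a sketch for clarity, not the kernel's exact rejection-sampling algorithm; top_ks and top_ps are assumed to be per-request 1-D tensors, as in the batch above):

    import torch

    def ref_top_k_top_p_sample(probs, top_ks, top_ps):
        # probs: [batch, vocab]; top_ks: [batch] ints; top_ps: [batch] floats
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)

        # Drop tokens outside the nucleus (top-p) or beyond the top-k rank.
        ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs)
        mask = ((cumsum - sorted_probs) > top_ps.unsqueeze(-1)) | (
            ranks >= top_ks.unsqueeze(-1)
        )
        sorted_probs = sorted_probs.masked_fill(mask, 0.0)

        # Renormalize, sample one token per row, and map back to vocabulary ids.
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        sampled = torch.multinomial(sorted_probs, num_samples=1)
        return torch.gather(sorted_idx, dim=-1, index=sampled).view(-1)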
@@ -731,6 +747,7 @@ class InputMetadata:
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False
 
    @classmethod
    def create(
@@ -746,7 +763,10 @@ class InputMetadata:
        return_logprob=False,
        skip_flashinfer_init=False,
    ):
+        use_ragged = False
        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
            init_flashinfer_args(
                forward_mode,
                model_runner,
@@ -754,6 +774,7 @@
                seq_lens,
                prefix_lens,
                model_runner.flashinfer_decode_wrapper,
+                use_ragged,
            )
 
        batch_size = len(req_pool_indices)
@@ -808,6 +829,7 @@ class InputMetadata:
            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
        )
 
        if model_runner.server_args.disable_flashinfer:
@@ -828,16 +850,19 @@ def init_flashinfer_args(
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
+    """Init auxiliary variables for FlashInfer attention backend."""
    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
    head_dim = model_runner.model_config.head_dim
    batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
 
-    if forward_mode == ForwardMode.DECODE:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
        paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
 
    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -870,14 +895,15 @@ def init_flashinfer_args(
        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
 
        # cached part
        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
@@ -894,6 +920,7 @@
 
 
 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
    batch_size = len(seq_lens)
    max_seq_len = int(torch.max(seq_lens))
    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")