sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/runtime_endpoint.py +14 -4
  4. sglang/backend/vertexai.py +5 -4
  5. sglang/bench.py +627 -0
  6. sglang/bench_latency.py +22 -20
  7. sglang/bench_serving.py +758 -0
  8. sglang/check_env.py +171 -0
  9. sglang/global_config.py +3 -1
  10. sglang/lang/backend/__init__.py +0 -0
  11. sglang/lang/backend/anthropic.py +77 -0
  12. sglang/lang/backend/base_backend.py +80 -0
  13. sglang/lang/backend/litellm.py +90 -0
  14. sglang/lang/backend/openai.py +438 -0
  15. sglang/lang/backend/runtime_endpoint.py +283 -0
  16. sglang/lang/backend/vertexai.py +149 -0
  17. sglang/lang/chat_template.py +2 -2
  18. sglang/lang/ir.py +3 -3
  19. sglang/lang/tracer.py +1 -1
  20. sglang/launch_server.py +1 -1
  21. sglang/launch_server_llavavid.py +1 -4
  22. sglang/srt/conversation.py +1 -1
  23. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  24. sglang/srt/layers/extend_attention.py +0 -39
  25. sglang/srt/layers/linear.py +869 -0
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +31 -5
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
  31. sglang/srt/managers/controller/infer_batch.py +76 -72
  32. sglang/srt/managers/controller/manager_multi.py +109 -98
  33. sglang/srt/managers/controller/manager_single.py +105 -50
  34. sglang/srt/managers/controller/model_runner.py +42 -18
  35. sglang/srt/managers/controller/radix_cache.py +4 -3
  36. sglang/srt/managers/controller/schedule_heuristic.py +4 -0
  37. sglang/srt/managers/controller/tp_worker.py +143 -156
  38. sglang/srt/managers/detokenizer_manager.py +49 -5
  39. sglang/srt/managers/io_struct.py +36 -17
  40. sglang/srt/managers/tokenizer_manager.py +228 -125
  41. sglang/srt/memory_pool.py +46 -58
  42. sglang/srt/model_loader/model_loader.py +277 -0
  43. sglang/srt/model_loader/utils.py +260 -0
  44. sglang/srt/models/chatglm.py +1 -0
  45. sglang/srt/models/dbrx.py +1 -0
  46. sglang/srt/models/grok.py +1 -0
  47. sglang/srt/models/internlm2.py +317 -0
  48. sglang/srt/models/llama2.py +65 -16
  49. sglang/srt/models/llama_classification.py +1 -0
  50. sglang/srt/models/llava.py +1 -0
  51. sglang/srt/models/llavavid.py +1 -0
  52. sglang/srt/models/minicpm.py +2 -8
  53. sglang/srt/models/mixtral.py +1 -0
  54. sglang/srt/models/mixtral_quant.py +1 -0
  55. sglang/srt/models/qwen.py +1 -0
  56. sglang/srt/models/qwen2.py +6 -0
  57. sglang/srt/models/qwen2_moe.py +130 -108
  58. sglang/srt/models/stablelm.py +1 -0
  59. sglang/srt/openai_api/adapter.py +432 -0
  60. sglang/srt/openai_api/api_adapter.py +432 -0
  61. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  62. sglang/srt/openai_api/openai_protocol.py +207 -0
  63. sglang/srt/openai_api/protocol.py +208 -0
  64. sglang/srt/openai_protocol.py +17 -0
  65. sglang/srt/sampling_params.py +2 -0
  66. sglang/srt/server.py +114 -90
  67. sglang/srt/server_args.py +27 -17
  68. sglang/srt/utils.py +17 -118
  69. sglang/test/test_conversation.py +1 -1
  70. sglang/test/test_openai_protocol.py +1 -1
  71. sglang/test/test_programs.py +1 -1
  72. sglang/test/test_utils.py +2 -2
  73. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
  74. sglang-0.1.22.dist-info/RECORD +103 -0
  75. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  76. sglang-0.1.20.dist-info/RECORD +0 -82
  77. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  78. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict


 class RadixAttention(nn.Module):
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
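Aside: the new `store_kv_cache` path above routes the write through a module-level `_store_kv_cache` registered with `torch.library.custom_op` plus a `register_fake` meta kernel, falling back to a plain function on PyTorch builds that lack the API. A minimal, self-contained sketch of that registration pattern, with illustrative names (`demo::write_slots`) that are not part of sglang:

# Sketch of the custom-op-with-fallback pattern (assumes PyTorch >= 2.4 for the
# first branch). Names are illustrative, not sglang's.
import torch

try:

    @torch.library.custom_op("demo::write_slots", mutates_args={"buf"})
    def write_slots(vals: torch.Tensor, buf: torch.Tensor, loc: torch.Tensor) -> None:
        # Eager kernel: in-place scatter into a pre-allocated pool.
        buf[loc] = vals

    @write_slots.register_fake
    def _(vals, buf, loc):
        # Meta/"fake" kernel: no computation, just declares the op returns nothing.
        pass

except AttributeError:
    # Older PyTorch without torch.library.custom_op: plain function, same behavior.
    def write_slots(vals, buf, loc):
        buf[loc] = vals

buf = torch.zeros(8, 4)
write_slots(torch.ones(2, 4), buf, torch.tensor([1, 5]))
print(buf[1], buf[5])  # both rows are now ones

Declaring `mutates_args` marks the in-place write explicitly for tracing/compilation, which is presumably why the registration is attempted first and a bare function is kept only as a fallback.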
sglang/srt/layers/token_attention.py
@@ -5,8 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.controller.model_runner import global_server_args_dict
-from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.server import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2

-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(

     num_warps = 1

-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
         num_stages=3,
     )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


 def token_attention_fwd(
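Aside: the hunks above drop the hand-rolled `wrap_kernel_launcher` caching and launch the Triton kernels directly on every call; Triton's JIT already caches compiled kernels per specialization, which is presumably why the explicit launcher cache could go. A toy sketch of the direct-launch pattern (not sglang code; needs a CUDA device and `triton` installed):

# Direct Triton launch: the first call compiles, later calls with the same
# dtypes/constexpr values hit Triton's in-process cache.
import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)


def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # No wrapper object is kept around between calls.
    _scale_kernel[grid](x, out, s, n, BLOCK=1024, num_warps=4, num_stages=1)
    return out


if __name__ == "__main__":
    x = torch.randn(4096, device="cuda")
    torch.testing.assert_close(scale(x, 2.0), x * 2)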
sglang/srt/managers/controller/cuda_graph_runner.py
@@ -3,12 +3,16 @@
 import bisect

 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture

-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
-    Batch, ForwardMode, InputMetadata, init_flashinfer_args
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
 )


@@ -24,18 +28,28 @@ class CudaGraphRunner:
         # Common inputs
         self.max_bs = max_batch_size_to_capture
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
         self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
         )
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
@@ -49,16 +63,18 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
                 self.graphs[bs] = graph
                 self.input_buffers[bs] = input_buffers
                 self.output_buffers[bs] = output_buffers
                 self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream

@@ -71,17 +87,19 @@ class CudaGraphRunner:

         # FlashInfer inputs
         if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
+            self.model_runner.model_config.num_attention_heads
+            // self.model_runner.tp_size,
             self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
         ):
             use_tensor_cores = True
         else:
             use_tensor_cores = False
         flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer, "NHD",
+            self.flashinfer_workspace_buffer,
+            "NHD",
             use_cuda_graph=True,
             use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )
@@ -132,8 +150,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
-            self.position_ids_offsets.fill_(1)
+            self.seq_lens.fill_(1)
+            self.position_ids_offsets.zero_()
             self.out_cache_loc.zero_()

         # Common inputs
@@ -163,10 +181,18 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
+                next_token_logprobs=(
+                    output.next_token_logprobs[:raw_bs]
+                    if output.next_token_logprobs is not None
+                    else None
+                ),
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
+                decode_top_logprobs=(
+                    output.decode_top_logprobs[:raw_bs]
+                    if output.decode_top_logprobs is not None
+                    else None
+                ),
             )
         return output
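Aside: `CudaGraphRunner` above follows the standard CUDA-graph recipe: allocate inputs once at the maximum batch size, capture one graph per candidate batch size, then at run time copy the real inputs into the static buffers (padding the unused tail, which is what the `fill_(1)`/`zero_()` fix above corrects) and replay. A stripped-down sketch of that recipe with made-up shapes, not sglang code; needs a CUDA device:

import torch

MAX_BS = 8
static_x = torch.zeros(MAX_BS, 16, device="cuda")
weight = torch.randn(16, 16, device="cuda")

def forward(x):
    return x @ weight

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    forward(static_x)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = forward(static_x)

def run(x):
    bs = x.shape[0]
    static_x[:bs].copy_(x)   # copy real inputs into the captured buffer
    static_x[bs:].zero_()    # pad the unused tail with harmless values
    graph.replay()           # re-launch the recorded kernels
    return static_out[:bs].clone()

print(run(torch.randn(3, 16, device="cuda")).shape)  # torch.Size([3, 16])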
sglang/srt/managers/controller/infer_batch.py
@@ -7,6 +7,7 @@ from typing import List, Union

 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -15,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-# Store some global server args
-global_server_args_dict = {}
-

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -84,6 +82,15 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
+        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -134,7 +141,7 @@ class Req:
         return self.finished_reason is not None

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

         if first_iter:
@@ -144,13 +151,11 @@ class Req:
         )

         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]

         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -164,19 +169,10 @@ class Req:
         )

         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]

         return False, ""

-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
     def check_finished(self):
         if self.finished():
             return
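Aside: the renamed detokenization helpers above split the id stream at `surr_offset`/`read_offset` (see the diagram added to `Req`): decode a short surrounding window and the full read window, then emit only the suffix once it no longer ends in the Unicode replacement character. A toy sketch of that logic with a fake tokenizer, so it runs standalone; names are illustrative, not sglang's:

class ToyTokenizer:
    vocab = {0: "He", 1: "llo", 2: " wor", 3: "ld", 4: "!"}

    def decode(self, ids):
        return "".join(self.vocab[i] for i in ids)

def next_increment(tokenizer, all_ids, surr_offset, read_offset):
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    new_text = tokenizer.decode(all_ids[surr_offset:])
    if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
        return new_text[len(surr_text):]   # only the newly finalized text
    return ""                              # hold until the text stabilizes

tok = ToyTokenizer()
ids, surr, read = [0, 1], 0, 2             # prompt ids, initial offsets
for new_id in [2, 3, 4]:                   # stream of generated ids
    ids.append(new_id)
    delta = next_increment(tok, ids, surr, read)
    if delta:
        surr, read = read, len(ids)        # advance the window
        print(repr(delta))
# prints ' wor', 'ld', '!'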
@@ -275,6 +271,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None

     # For processing logprobs
     return_logprob: bool = False
@@ -285,10 +282,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -330,6 +323,13 @@ class Batch:
         seq_lens = []

         req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
@@ -352,7 +352,7 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
             out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

         if out_cache_loc is None:
@@ -401,10 +401,10 @@ class Batch:
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -422,7 +422,7 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -453,7 +453,7 @@ class Batch:
             token_indices = self.req_to_token_pool.req_to_token[
                 req_pool_indices_cpu[idx]
             ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.dec_refs(token_indices)
+            self.token_to_kv_pool.free(token_indices)

             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
@@ -502,7 +502,7 @@ class Batch:
             cur_output_ids = req.output_ids

             req.output_ids.extend(suffix_ids)
-            decode_res, new_text = req.detokenize_incrementally(inplace=False)
+            decode_res, new_text = req.get_next_inc_detokenization()
             if not decode_res:
                 req.output_ids = cur_output_ids
                 continue
@@ -521,6 +521,9 @@ class Batch:
                 req.output_ids = cur_output_ids
                 continue

+            # The decode status has diverged from detokenizer_manager
+            req.vid += 1
+
             # insert the old request into tree_cache
             if req_pool_indices_cpu is None:
                 req_pool_indices_cpu = self.req_pool_indices.tolist()
@@ -596,8 +599,7 @@ class Batch:
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
+            if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])

     def merge(self, other: "Batch"):
@@ -663,18 +665,21 @@ class Batch:

         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
-        except RuntimeError as e:
-            warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
+
+        max_top_k_round, batch_size = 32, probs.shape[0]
+        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, self.top_ks, self.top_ps
         )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+
+        # FIXME: this is a temporary fix for the illegal token ids
+        illegal_mask = torch.logical_or(
+            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
+        )
+        if torch.any(illegal_mask):
+            warnings.warn("Illegal sampled token ids")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)

         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -684,18 +689,7 @@ class Batch:
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )

-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids


 @dataclass
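Aside: sampling above now calls flashinfer's fused `top_k_top_p_sampling_from_probs` with per-request `top_ks`/`top_ps` (hence the dropped `.view(-1, 1)` earlier) and falls back to a masked argmax when illegal ids come back. For reference, a dense PyTorch sketch that mirrors what the removed `_top_p_top_k` + `multinomial` path computed; illustrative and CPU-runnable, not sglang code:

import torch

def sample_top_k_top_p(probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor):
    # probs: [batch, vocab]; top_ks: [batch] ints; top_ps: [batch] floats
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    cumsum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens outside the nucleus (top-p) ...
    probs_sort[(cumsum - probs_sort) > top_ps.unsqueeze(-1)] = 0.0
    # ... and beyond the top-k rank.
    ranks = torch.arange(probs.shape[-1]).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.unsqueeze(-1)] = 0.0
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    # Map positions in the sorted order back to original token ids.
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)

probs = torch.softmax(torch.randn(2, 10), dim=-1)
ids = sample_top_k_top_p(probs, torch.tensor([5, 1]), torch.tensor([0.9, 1.0]))
print(ids)  # one token id per row; the row with top_k=1 always picks its argmax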
@@ -749,8 +743,14 @@
         skip_flashinfer_init=False,
     ):
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
-                                 model_runner.flashinfer_decode_wrapper)
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )

         batch_size = len(req_pool_indices)

@@ -807,16 +807,25 @@
         )

         if model_runner.server_args.disable_flashinfer:
-            (ret.triton_max_seq_len,
-             ret.triton_max_extend_len,
-             ret.triton_start_loc,
-             ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

         return ret


-def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
-                         flashinfer_decode_wrapper):
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+):
+    """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -827,9 +836,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
     else:
         paged_kernel_lens = prefix_lens

-    kv_indptr = torch.zeros(
-        (batch_size + 1,), dtype=torch.int32, device="cuda"
-    )
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
@@ -842,9 +849,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         ],
         dim=0,
     ).contiguous()
-    kv_last_page_len = torch.ones(
-        (batch_size,), dtype=torch.int32, device="cuda"
-    )
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

     if forward_mode == ForwardMode.DECODE:
         flashinfer_decode_wrapper.end_forward()
@@ -859,9 +864,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )
     else:
         # extend part
-        qo_indptr = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
@@ -888,6 +891,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,


 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
     batch_size = len(seq_lens)
     max_seq_len = int(torch.max(seq_lens))
     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
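Aside: `init_flashinfer_args` above packs the paged KV cache into CSR-style metadata for page size 1: `kv_indptr` is a cumulative sum of per-request lengths, `kv_indices` is the flat list of KV-pool slots per request, and `kv_last_page_len` is all ones because every one-token page is full. A small sketch with made-up slot numbers (not sglang code, runs on CPU):

import torch

seq_lens = torch.tensor([3, 1, 4], dtype=torch.int32)            # tokens per request
req_to_token = torch.tensor([[10, 11, 12,  0],                   # slot table per request
                             [20,  0,  0,  0],
                             [30, 31, 32, 33]], dtype=torch.int32)

batch_size = len(seq_lens)
kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)                    # [0, 3, 4, 8]
kv_indices = torch.cat(
    [req_to_token[i, : seq_lens[i]] for i in range(batch_size)]
)                                                                # [10 11 12 20 30 31 32 33]
kv_last_page_len = torch.ones(batch_size, dtype=torch.int32)     # page_size == 1

print(kv_indptr.tolist(), kv_indices.tolist(), kv_last_page_len.tolist())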