sglang 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +758 -0
  7. sglang/check_env.py +171 -0
  8. sglang/lang/backend/__init__.py +0 -0
  9. sglang/lang/backend/anthropic.py +77 -0
  10. sglang/lang/backend/base_backend.py +80 -0
  11. sglang/lang/backend/litellm.py +90 -0
  12. sglang/lang/backend/openai.py +438 -0
  13. sglang/lang/backend/runtime_endpoint.py +283 -0
  14. sglang/lang/backend/vertexai.py +149 -0
  15. sglang/lang/tracer.py +1 -1
  16. sglang/launch_server.py +1 -1
  17. sglang/launch_server_llavavid.py +1 -4
  18. sglang/srt/conversation.py +1 -1
  19. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  20. sglang/srt/layers/extend_attention.py +0 -39
  21. sglang/srt/layers/linear.py +869 -0
  22. sglang/srt/layers/quantization/__init__.py +49 -0
  23. sglang/srt/layers/quantization/fp8.py +662 -0
  24. sglang/srt/layers/radix_attention.py +31 -5
  25. sglang/srt/layers/token_attention.py +1 -51
  26. sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
  27. sglang/srt/managers/controller/infer_batch.py +47 -49
  28. sglang/srt/managers/controller/manager_multi.py +107 -100
  29. sglang/srt/managers/controller/manager_single.py +76 -96
  30. sglang/srt/managers/controller/model_runner.py +35 -23
  31. sglang/srt/managers/controller/tp_worker.py +127 -138
  32. sglang/srt/managers/detokenizer_manager.py +49 -5
  33. sglang/srt/managers/io_struct.py +36 -17
  34. sglang/srt/managers/tokenizer_manager.py +228 -125
  35. sglang/srt/memory_pool.py +19 -6
  36. sglang/srt/model_loader/model_loader.py +277 -0
  37. sglang/srt/model_loader/utils.py +260 -0
  38. sglang/srt/models/chatglm.py +1 -0
  39. sglang/srt/models/dbrx.py +1 -0
  40. sglang/srt/models/grok.py +1 -0
  41. sglang/srt/models/internlm2.py +317 -0
  42. sglang/srt/models/llama2.py +65 -16
  43. sglang/srt/models/llama_classification.py +1 -0
  44. sglang/srt/models/llava.py +1 -0
  45. sglang/srt/models/llavavid.py +1 -0
  46. sglang/srt/models/minicpm.py +1 -0
  47. sglang/srt/models/mixtral.py +1 -0
  48. sglang/srt/models/mixtral_quant.py +1 -0
  49. sglang/srt/models/qwen.py +1 -0
  50. sglang/srt/models/qwen2.py +6 -0
  51. sglang/srt/models/qwen2_moe.py +7 -4
  52. sglang/srt/models/stablelm.py +1 -0
  53. sglang/srt/openai_api/adapter.py +432 -0
  54. sglang/srt/openai_api/api_adapter.py +432 -0
  55. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  56. sglang/srt/openai_api/openai_protocol.py +207 -0
  57. sglang/srt/openai_api/protocol.py +208 -0
  58. sglang/srt/openai_protocol.py +17 -0
  59. sglang/srt/sampling_params.py +2 -0
  60. sglang/srt/server.py +113 -84
  61. sglang/srt/server_args.py +23 -15
  62. sglang/srt/utils.py +16 -117
  63. sglang/test/test_conversation.py +1 -1
  64. sglang/test/test_openai_protocol.py +1 -1
  65. sglang/test/test_programs.py +1 -1
  66. sglang/test/test_utils.py +2 -2
  67. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
  68. sglang-0.1.22.dist-info/RECORD +103 -0
  69. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  70. sglang-0.1.21.dist-info/RECORD +0 -82
  71. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  72. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict


 class RadixAttention(nn.Module):
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
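Note: the try/except above keeps the in-place KV-cache write working on PyTorch builds that predate `torch.library.custom_op` (registering the write as a custom op marks the mutation explicitly for PyTorch's tracing/compile machinery). A minimal standalone sketch of the same pattern; the `demo::store_kv` op name and tensor shapes are hypothetical, not part of sglang:

import torch

try:
    # Register the in-place scatter as a custom op; the "fake" implementation
    # only describes the signature for tracing and does no work.
    @torch.library.custom_op("demo::store_kv", mutates_args={"buf"})
    def store_kv(x: torch.Tensor, buf: torch.Tensor, loc: torch.Tensor) -> None:
        buf[loc] = x

    @store_kv.register_fake
    def _(x, buf, loc):
        pass

except AttributeError:
    # Older PyTorch: fall back to a plain eager function.
    def store_kv(x: torch.Tensor, buf: torch.Tensor, loc: torch.Tensor) -> None:
        buf[loc] = x

buf = torch.zeros(8, 4)
store_kv(torch.ones(2, 4), buf, torch.tensor([1, 3]))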
sglang/srt/layers/token_attention.py
@@ -5,8 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.controller.model_runner import global_server_args_dict
-from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.server import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2

-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(

     num_warps = 1

-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
         num_stages=3,
     )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


 def token_attention_fwd(
sglang/srt/managers/controller/cuda_graph_runner.py
@@ -3,9 +3,10 @@
 import bisect

 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture

-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -74,9 +75,6 @@ class CudaGraphRunner:
                 self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream

@@ -152,8 +150,8 @@
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
-            self.position_ids_offsets.fill_(1)
+            self.seq_lens.fill_(1)
+            self.position_ids_offsets.zero_()
             self.out_cache_loc.zero_()

         # Common inputs
@@ -183,14 +181,18 @@
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=output.next_token_logprobs[:raw_bs]
-                if output.next_token_logprobs is not None
-                else None,
+                next_token_logprobs=(
+                    output.next_token_logprobs[:raw_bs]
+                    if output.next_token_logprobs is not None
+                    else None
+                ),
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
-                if output.decode_top_logprobs is not None
-                else None,
+                decode_top_logprobs=(
+                    output.decode_top_logprobs[:raw_bs]
+                    if output.decode_top_logprobs is not None
+                    else None
+                ),
             )
         return output
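Note: the replay fix above pads an undersized batch to the nearest captured batch size with neutral values (sequence length 1, position offset 0) and then slices the outputs back to `raw_bs`. For background, a minimal CUDA graph capture/replay sketch with a toy `torch.nn.Linear` model (illustrative only, not the runner's code):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_x = torch.zeros(4, 16, device="cuda")  # fixed input buffer for batch size 4

# Warm up on a side stream, then capture one graph for this batch size.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_x)

# Replay: copy fresh inputs into the captured buffer, launch, read static_out.
static_x.copy_(torch.randn(4, 16, device="cuda"))
graph.replay()
print(static_out[:2])  # e.g. slice back to the real (smaller) batch size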
sglang/srt/managers/controller/infer_batch.py
@@ -7,6 +7,7 @@ from typing import List, Union

 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -15,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-# Store some global server args
-global_server_args_dict = {}
-

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -84,6 +82,15 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- | surr_ids |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
+        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -134,7 +141,7 @@ class Req:
         return self.finished_reason is not None

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

         if first_iter:
@@ -144,13 +151,11 @@ class Req:
             )

         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]

         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -164,13 +169,7 @@ class Req:
         )

         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]

         return False, ""

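Note: with the renamed helpers, `init_incremental_detokenize` returns the token window starting at `surr_offset` plus a relative `read_offset`, and `get_next_inc_detokenization` emits only the text decoded past the surrounding prefix. A toy sketch of that bookkeeping; `toy_decode` is a stand-in for a real `tokenizer.decode` and is purely hypothetical:

from typing import List, Tuple


def toy_decode(ids: List[int]) -> str:
    # Stand-in tokenizer: map each id to a letter.
    return "".join(chr(ord("a") + i % 26) for i in ids)


def next_inc_detokenization(
    all_ids: List[int], surr_offset: int, read_offset: int
) -> Tuple[bool, str]:
    read_ids = all_ids[surr_offset:]
    surr_ids = read_ids[: read_offset - surr_offset]
    surr_text = toy_decode(surr_ids)
    new_text = toy_decode(read_ids)
    if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
        # Emit only the part decoded beyond the surrounding prefix.
        return True, new_text[len(surr_text):]
    return False, ""


print(next_inc_detokenization([0, 1, 2, 3, 4], surr_offset=1, read_offset=3))
# -> (True, 'de')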
@@ -272,6 +271,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None

     # For processing logprobs
     return_logprob: bool = False
@@ -282,10 +282,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-

     # Batched sampling params
     temperatures: torch.Tensor = None
@@ -327,6 +323,13 @@ class Batch:
         seq_lens = []

         req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
@@ -398,10 +401,10 @@ class Batch:
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -499,7 +502,7 @@ class Batch:
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
-                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue
@@ -518,6 +521,9 @@ class Batch:
                        req.output_ids = cur_output_ids
                        continue

+                    # The decode status has diverged from detokenizer_manager
+                    req.vid += 1
+
                    # insert the old request into tree_cache
                    if req_pool_indices_cpu is None:
                        req_pool_indices_cpu = self.req_pool_indices.tolist()
@@ -659,20 +665,21 @@ class Batch:

         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
-        except RuntimeError as e:
-            warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(
-                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
-            )
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
+
+        max_top_k_round, batch_size = 32, probs.shape[0]
+        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, self.top_ks, self.top_ps
+        )
+
+        # FIXME: this is a temporary fix for the illegal token ids
+        illegal_mask = torch.logical_or(
+            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
         )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+        if torch.any(illegal_mask):
+            warnings.warn("Illegal sampled token ids")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)

         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -682,18 +689,7 @@ class Batch:
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )

-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids
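Note: sampling now goes through flashinfer's fused `top_k_top_p_sampling_from_probs` kernel, with an argmax fallback when it returns out-of-range ids. For reference, a rough pure-PyTorch equivalent of the filtering the removed `_top_p_top_k` helper performed (an illustrative sketch, not the fused kernel):

import torch


def top_k_top_p_sample(
    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
) -> torch.Tensor:
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens outside the nucleus (top-p) and beyond the top-k cutoff.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    ranks = torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # Renormalize and sample one token per row, mapping back to vocab ids.
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


probs = torch.softmax(torch.randn(2, 10), dim=-1)
ids = top_k_top_p_sample(probs, torch.tensor([5, 3]), torch.tensor([0.9, 0.8]))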
@@ -829,6 +825,7 @@ def init_flashinfer_args(
     prefix_lens,
     flashinfer_decode_wrapper,
 ):
+    """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -894,6 +891,7 @@ def init_flashinfer_args(


 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
     batch_size = len(seq_lens)
     max_seq_len = int(torch.max(seq_lens))
     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")