sglang 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +3 -2
  3. sglang/global_config.py +1 -1
  4. sglang/lang/backend/runtime_endpoint.py +60 -49
  5. sglang/lang/interpreter.py +4 -2
  6. sglang/lang/ir.py +13 -4
  7. sglang/srt/constrained/jump_forward.py +13 -2
  8. sglang/srt/layers/activation.py +0 -1
  9. sglang/srt/layers/extend_attention.py +3 -1
  10. sglang/srt/layers/fused_moe/__init__.py +1 -0
  11. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  12. sglang/srt/layers/fused_moe/layer.py +587 -0
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/radix_attention.py +38 -14
  15. sglang/srt/managers/schedule_batch.py +9 -14
  16. sglang/srt/managers/tokenizer_manager.py +1 -1
  17. sglang/srt/managers/tp_worker.py +1 -7
  18. sglang/srt/model_executor/cuda_graph_runner.py +48 -17
  19. sglang/srt/model_executor/forward_batch_info.py +132 -58
  20. sglang/srt/model_executor/model_runner.py +61 -28
  21. sglang/srt/models/chatglm.py +2 -2
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/deepseek.py +2 -2
  24. sglang/srt/models/deepseek_v2.py +7 -6
  25. sglang/srt/models/gemma.py +1 -1
  26. sglang/srt/models/gemma2.py +11 -5
  27. sglang/srt/models/grok.py +50 -396
  28. sglang/srt/models/minicpm.py +2 -2
  29. sglang/srt/models/mixtral.py +56 -254
  30. sglang/srt/models/mixtral_quant.py +1 -4
  31. sglang/srt/models/qwen.py +2 -2
  32. sglang/srt/models/qwen2.py +2 -2
  33. sglang/srt/models/qwen2_moe.py +2 -2
  34. sglang/srt/models/stablelm.py +1 -1
  35. sglang/srt/openai_api/adapter.py +32 -21
  36. sglang/srt/sampling_params.py +0 -4
  37. sglang/srt/server.py +23 -15
  38. sglang/srt/server_args.py +7 -1
  39. sglang/srt/utils.py +1 -2
  40. sglang/test/runners.py +18 -10
  41. sglang/test/test_programs.py +32 -5
  42. sglang/test/test_utils.py +5 -1
  43. sglang/version.py +1 -1
  44. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/METADATA +12 -4
  45. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/RECORD +48 -48
  46. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  47. sglang/srt/model_loader/model_loader.py +0 -292
  48. sglang/srt/model_loader/utils.py +0 -275
  49. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  50. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/api.py CHANGED
@@ -62,6 +62,7 @@ def gen(
62
62
  name: Optional[str] = None,
63
63
  max_tokens: Optional[int] = None,
64
64
  stop: Optional[Union[str, List[str]]] = None,
65
+ stop_token_ids: Optional[List[int]] = None,
65
66
  temperature: Optional[float] = None,
66
67
  top_p: Optional[float] = None,
67
68
  top_k: Optional[int] = None,
@@ -72,7 +73,7 @@ def gen(
72
73
  logprob_start_len: Optional[int] = None,
73
74
  top_logprobs_num: Optional[int] = None,
74
75
  return_text_in_logprobs: Optional[bool] = None,
75
- dtype: Optional[type] = None,
76
+ dtype: Optional[Union[type, str]] = None,
76
77
  choices: Optional[List[str]] = None,
77
78
  choices_method: Optional[ChoicesSamplingMethod] = None,
78
79
  regex: Optional[str] = None,
@@ -98,6 +99,7 @@ def gen(
98
99
  name,
99
100
  max_tokens,
100
101
  stop,
102
+ stop_token_ids,
101
103
  temperature,
102
104
  top_p,
103
105
  top_k,
@@ -117,6 +119,7 @@ def gen_int(
117
119
  name: Optional[str] = None,
118
120
  max_tokens: Optional[int] = None,
119
121
  stop: Optional[Union[str, List[str]]] = None,
122
+ stop_token_ids: Optional[List[int]] = None,
120
123
  temperature: Optional[float] = None,
121
124
  top_p: Optional[float] = None,
122
125
  top_k: Optional[int] = None,
@@ -132,6 +135,7 @@ def gen_int(
132
135
  name,
133
136
  max_tokens,
134
137
  stop,
138
+ stop_token_ids,
135
139
  temperature,
136
140
  top_p,
137
141
  top_k,
@@ -151,6 +155,7 @@ def gen_string(
151
155
  name: Optional[str] = None,
152
156
  max_tokens: Optional[int] = None,
153
157
  stop: Optional[Union[str, List[str]]] = None,
158
+ stop_token_ids: Optional[List[int]] = None,
154
159
  temperature: Optional[float] = None,
155
160
  top_p: Optional[float] = None,
156
161
  top_k: Optional[int] = None,
@@ -166,6 +171,7 @@ def gen_string(
166
171
  name,
167
172
  max_tokens,
168
173
  stop,
174
+ stop_token_ids,
169
175
  temperature,
170
176
  top_p,
171
177
  top_k,
sglang/bench_latency.py CHANGED
@@ -64,7 +64,7 @@ class BenchArgs:
64
64
  run_name: str = "before"
65
65
  batch_size: Tuple[int] = (1,)
66
66
  input_len: Tuple[int] = (1024,)
67
- output_len: Tuple[int] = (4,)
67
+ output_len: Tuple[int] = (16,)
68
68
  result_filename: str = ""
69
69
  correctness_test: bool = False
70
70
  # This is only used for correctness test
@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
195
195
  token_to_kv_pool=model_runner.token_to_kv_pool,
196
196
  tree_cache=None,
197
197
  )
198
- batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
198
+ batch.prepare_for_extend(model_runner.model_config.vocab_size)
199
199
  output = model_runner.forward(batch, ForwardMode.EXTEND)
200
200
  next_token_ids = batch.sample(output.next_token_logits)
201
201
  return next_token_ids, output.next_token_logits, batch
@@ -221,6 +221,7 @@ def correctness_test(
221
221
 
222
222
  # Prepare inputs
223
223
  input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
224
+ rank_print(f"{input_ids=}")
224
225
 
225
226
  if bench_args.cut_len > 0:
226
227
  # Prefill
sglang/global_config.py CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:
27
27
  # Runtime constants: others
28
28
  self.num_continue_decode_steps = 10
29
29
  self.retract_decode_steps = 20
30
- self.flashinfer_workspace_size = 192 * 1024 * 1024
30
+ self.flashinfer_workspace_size = 384 * 1024 * 1024
31
31
 
32
32
  # Output tokenization configs
33
33
  self.skip_special_tokens_in_output = True
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -1,21 +1,23 @@
1
1
  import json
2
+ import warnings
2
3
  from typing import List, Optional
3
4
 
4
5
  from sglang.global_config import global_config
5
6
  from sglang.lang.backend.base_backend import BaseBackend
6
7
  from sglang.lang.chat_template import get_chat_template_by_model_path
7
- from sglang.lang.choices import (
8
- ChoicesDecision,
9
- ChoicesSamplingMethod,
10
- token_length_normalized,
11
- )
8
+ from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
12
9
  from sglang.lang.interpreter import StreamExecutor
13
- from sglang.lang.ir import SglSamplingParams
10
+ from sglang.lang.ir import (
11
+ REGEX_BOOL,
12
+ REGEX_FLOAT,
13
+ REGEX_INT,
14
+ REGEX_STR,
15
+ SglSamplingParams,
16
+ )
14
17
  from sglang.utils import http_request
15
18
 
16
19
 
17
20
  class RuntimeEndpoint(BaseBackend):
18
-
19
21
  def __init__(
20
22
  self,
21
23
  base_url: str,
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
95
97
  )
96
98
  self._assert_success(res)
97
99
 
100
+ def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
101
+ if sampling_params.dtype is None:
102
+ return
103
+
104
+ if sampling_params.stop == ():
105
+ sampling_params.stop = []
106
+
107
+ dtype_regex = None
108
+ if sampling_params.dtype in ["int", int]:
109
+
110
+ dtype_regex = REGEX_INT
111
+ sampling_params.stop.extend([" ", "\n"])
112
+ elif sampling_params.dtype in ["float", float]:
113
+
114
+ dtype_regex = REGEX_FLOAT
115
+ sampling_params.stop.extend([" ", "\n"])
116
+ elif sampling_params.dtype in ["str", str]:
117
+
118
+ dtype_regex = REGEX_STR
119
+ elif sampling_params.dtype in ["bool", bool]:
120
+
121
+ dtype_regex = REGEX_BOOL
122
+ else:
123
+ raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
124
+
125
+ if dtype_regex is not None and sampling_params.regex is not None:
126
+ warnings.warn(
127
+ f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
128
+ )
129
+
130
+ sampling_params.regex = dtype_regex
131
+
98
132
  def generate(
99
133
  self,
100
134
  s: StreamExecutor,
101
135
  sampling_params: SglSamplingParams,
102
136
  ):
103
- if sampling_params.dtype is None:
104
- data = {
105
- "text": s.text_,
106
- "sampling_params": {
107
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
108
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
109
- **sampling_params.to_srt_kwargs(),
110
- },
111
- }
112
- elif sampling_params.dtype in [int, "int"]:
113
- data = {
114
- "text": s.text_,
115
- "sampling_params": {
116
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
117
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
118
- "dtype": "int",
119
- **sampling_params.to_srt_kwargs(),
120
- },
121
- }
122
- else:
123
- raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
137
+ self._handle_dtype_to_regex(sampling_params)
138
+ data = {
139
+ "text": s.text_,
140
+ "sampling_params": {
141
+ "skip_special_tokens": global_config.skip_special_tokens_in_output,
142
+ "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
143
+ **sampling_params.to_srt_kwargs(),
144
+ },
145
+ }
124
146
 
125
147
  for item in [
126
148
  "return_logprob",
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
151
173
  s: StreamExecutor,
152
174
  sampling_params: SglSamplingParams,
153
175
  ):
154
- if sampling_params.dtype is None:
155
- data = {
156
- "text": s.text_,
157
- "sampling_params": {
158
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
159
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
160
- **sampling_params.to_srt_kwargs(),
161
- },
162
- }
163
- elif sampling_params.dtype in [int, "int"]:
164
- data = {
165
- "text": s.text_,
166
- "sampling_params": {
167
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
168
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
169
- "dtype": "int",
170
- **sampling_params.to_srt_kwargs(),
171
- },
172
- }
173
- else:
174
- raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
176
+ self._handle_dtype_to_regex(sampling_params)
177
+
178
+ data = {
179
+ "text": s.text_,
180
+ "sampling_params": {
181
+ "skip_special_tokens": global_config.skip_special_tokens_in_output,
182
+ "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
183
+ **sampling_params.to_srt_kwargs(),
184
+ },
185
+ }
175
186
 
176
187
  for item in [
177
188
  "return_logprob",
sglang/lang/interpreter.py CHANGED
@@ -20,7 +20,6 @@ from sglang.lang.ir import (
20
20
  SglConstantText,
21
21
  SglExpr,
22
22
  SglExprList,
23
- SglFunction,
24
23
  SglGen,
25
24
  SglImage,
26
25
  SglRoleBegin,
@@ -181,8 +180,10 @@ class StreamExecutor:
181
180
  num_api_spec_tokens=None,
182
181
  use_thread=True,
183
182
  ):
183
+ from sglang.lang.backend.base_backend import BaseBackend
184
+
184
185
  self.sid = uuid.uuid4().hex
185
- self.backend = backend
186
+ self.backend: BaseBackend = backend
186
187
  self.arguments: Dict[str, Any] = arguments
187
188
  self.default_sampling_para = default_sampling_para
188
189
  self.stream = stream
@@ -658,6 +659,7 @@ class StreamExecutor:
658
659
  for item in [
659
660
  "max_new_tokens",
660
661
  "stop",
662
+ "stop_token_ids",
661
663
  "temperature",
662
664
  "top_p",
663
665
  "top_k",
sglang/lang/ir.py CHANGED
@@ -8,16 +8,17 @@ from typing import List, Optional, Union
8
8
  from sglang.global_config import global_config
9
9
  from sglang.lang.choices import ChoicesSamplingMethod
10
10
 
11
- REGEX_INT = r"[-+]?[0-9]+"
12
- REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
11
+ REGEX_INT = r"[-+]?[0-9]+[ \n]*"
12
+ REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
13
13
  REGEX_BOOL = r"(True|False)"
14
- REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
14
+ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
15
15
 
16
16
 
17
17
  @dataclasses.dataclass
18
18
  class SglSamplingParams:
19
19
  max_new_tokens: int = 128
20
20
  stop: Union[str, List[str]] = ()
21
+ stop_token_ids: Optional[List[int]] = ()
21
22
  temperature: float = 1.0
22
23
  top_p: float = 1.0
23
24
  top_k: int = -1 # -1 means disable
@@ -37,6 +38,7 @@ class SglSamplingParams:
37
38
  return SglSamplingParams(
38
39
  self.max_new_tokens,
39
40
  self.stop,
41
+ self.stop_token_ids,
40
42
  self.temperature,
41
43
  self.top_p,
42
44
  self.top_k,
@@ -108,6 +110,7 @@ class SglSamplingParams:
108
110
  return {
109
111
  "max_new_tokens": self.max_new_tokens,
110
112
  "stop": self.stop,
113
+ "stop_token_ids": self.stop_token_ids,
111
114
  "temperature": self.temperature,
112
115
  "top_p": self.top_p,
113
116
  "top_k": self.top_k,
@@ -141,7 +144,8 @@ class SglFunction:
141
144
  self,
142
145
  *args,
143
146
  max_new_tokens: int = 128,
144
- stop: Union[str, List[str]] = (),
147
+ stop: Union[str, List[str]] = [],
148
+ stop_token_ids: Optional[List[int]] = [],
145
149
  temperature: float = 1.0,
146
150
  top_p: float = 1.0,
147
151
  top_k: int = -1,
@@ -161,6 +165,7 @@ class SglFunction:
161
165
  default_sampling_para = SglSamplingParams(
162
166
  max_new_tokens=max_new_tokens,
163
167
  stop=stop,
168
+ stop_token_ids=stop_token_ids,
164
169
  temperature=temperature,
165
170
  top_p=top_p,
166
171
  top_k=top_k,
@@ -181,6 +186,7 @@ class SglFunction:
181
186
  *,
182
187
  max_new_tokens: int = 128,
183
188
  stop: Union[str, List[str]] = (),
189
+ stop_token_ids: Optional[List[int]] = [],
184
190
  temperature: float = 1.0,
185
191
  top_p: float = 1.0,
186
192
  top_k: int = -1,
@@ -218,6 +224,7 @@ class SglFunction:
218
224
  default_sampling_para = SglSamplingParams(
219
225
  max_new_tokens=max_new_tokens,
220
226
  stop=stop,
227
+ stop_token_ids=stop_token_ids,
221
228
  temperature=temperature,
222
229
  top_p=top_p,
223
230
  top_k=top_k,
@@ -397,6 +404,7 @@ class SglGen(SglExpr):
397
404
  name: Optional[str] = None,
398
405
  max_new_tokens: Optional[int] = None,
399
406
  stop: Optional[Union[str, List[str]]] = None,
407
+ stop_token_ids: Optional[List[int]] = None,
400
408
  temperature: Optional[float] = None,
401
409
  top_p: Optional[float] = None,
402
410
  top_k: Optional[int] = None,
@@ -416,6 +424,7 @@ class SglGen(SglExpr):
416
424
  self.sampling_params = SglSamplingParams(
417
425
  max_new_tokens=max_new_tokens,
418
426
  stop=stop,
427
+ stop_token_ids=stop_token_ids,
419
428
  temperature=temperature,
420
429
  top_p=top_p,
421
430
  top_k=top_k,
sglang/srt/constrained/jump_forward.py CHANGED
@@ -62,16 +62,22 @@ class JumpForwardMap:
62
62
  id_to_symbol.setdefault(id_, []).append(symbol)
63
63
 
64
64
  transitions = fsm_info.transitions
65
+
65
66
  outgoings_ct = defaultdict(int)
66
- state_to_jump_forward = {}
67
+ # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
68
+ for s in fsm_info.finals:
69
+ outgoings_ct[s] = 1
67
70
 
71
+ state_to_jump_forward = {}
68
72
  for (state, id_), next_state in transitions.items():
69
73
  if id_ == fsm_info.alphabet_anything_value:
74
+ # Arbitrarily symbol cannot be recognized as jump forward
70
75
  continue
76
+
71
77
  symbols = id_to_symbol[id_]
72
78
  for c in symbols:
73
79
  if len(c) > 1:
74
- # Skip byte level transitions
80
+ # Skip byte level transitions like c = "5E"
75
81
  continue
76
82
 
77
83
  outgoings_ct[state] += 1
@@ -87,6 +93,9 @@ class JumpForwardMap:
87
93
 
88
94
  # Process the byte level jump forward
89
95
  outgoings_ct = defaultdict(int)
96
+ for s in fsm_info.finals:
97
+ outgoings_ct[s] = 1
98
+
90
99
  for (state, id_), next_state in transitions.items():
91
100
  if id_ == fsm_info.alphabet_anything_value:
92
101
  continue
@@ -177,3 +186,5 @@ if __name__ == "__main__":
177
186
  test_main(r"霍格沃茨特快列车|霍比特人比尔博")
178
187
  # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
179
188
  # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
189
+
190
+ test_main(r"[-+]?[0-9]+[ ]*")
sglang/srt/layers/activation.py CHANGED
@@ -14,7 +14,6 @@ limitations under the License.
14
14
  """Fused operators for activation layers."""
15
15
 
16
16
  import torch
17
- import torch.nn as nn
18
17
  import torch.nn.functional as F
19
18
  from flashinfer.activation import silu_and_mul
20
19
  from vllm.model_executor.custom_op import CustomOp
sglang/srt/layers/extend_attention.py CHANGED
@@ -275,7 +275,9 @@ def extend_attention_fwd(
275
275
  BLOCK_DPE = 0
276
276
  BLOCK_DV = Lv
277
277
 
278
- if CUDA_CAPABILITY[0] >= 8:
278
+ if CUDA_CAPABILITY[0] >= 9:
279
+ BLOCK_M, BLOCK_N = (128, 64)
280
+ elif CUDA_CAPABILITY[0] >= 8:
279
281
  BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
280
282
  else:
281
283
  BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
sglang/srt/layers/fused_moe/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase