sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0

sglang/srt/hf_transformers_utils.py

@@ -1,12 +1,12 @@
 """Utilities for Huggingface Transformers."""
 
+import functools
 import json
 import os
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import AbstractSet, Collection, Literal, Optional, Union
 
 from huggingface_hub import snapshot_download
-from sglang.srt.utils import is_multimodal_model
 from transformers import (
     AutoConfig,
     AutoProcessor,
@@ -15,6 +15,8 @@ from transformers import (
     PreTrainedTokenizerFast,
 )
 
+from sglang.srt.utils import is_multimodal_model
+
 
 def download_from_hf(model_path: str):
     if os.path.exists(model_path):
@@ -29,10 +31,17 @@ def get_config_json(model_path: str):
     return config
 
 
-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config
 
 
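Note: the new `model_overide_args` parameter (spelling as in the source) is merged into the loaded `AutoConfig` via `config.update(...)`, so individual config fields can be patched at load time. A minimal sketch of how a caller might use it; the model name and the overridden field below are only illustrative:

    from sglang.srt.hf_transformers_utils import get_config

    # Hypothetical example: patch the advertised context length of a checkpoint.
    config = get_config(
        "meta-llama/Llama-2-7b-hf",
        trust_remote_code=False,
        model_overide_args={"max_position_embeddings": 8192},
    )
    print(config.max_position_embeddings)  # -> 8192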
@@ -76,6 +85,12 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
+    if tokenizer_name.endswith(".model"):
+        return SentencePieceTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -162,3 +177,129 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        from jinja2 import Template
+
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in tok_dict["default_allowed_special"]
+                ]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._default_allowed_special |= {"<|separator|>"}
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=disallowed_special,
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.vocab_size = tokenizer.n_vocab
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
+
+
+class SentencePieceTokenizer:
+    def __init__(self, tokenizer_path):
+        import sentencepiece as spm
+        from jinja2 import Template
+
+        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_id()
+        self.vocab_size = tokenizer.vocab_size()
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
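Note: the two classes above expose a small Huggingface-tokenizer-like surface (`encode`, `decode`, `batch_decode`, `apply_chat_template`, plus `eos_token_id` and `vocab_size`), and `get_tokenizer` now dispatches on the file suffix: a `.json` path is wrapped by `TiktokenTokenizer`, a `.model` path by `SentencePieceTokenizer`. A minimal usage sketch under that assumption; the file path below is hypothetical:

    from sglang.srt.hf_transformers_utils import get_tokenizer

    # Any *.model path takes the SentencePiece branch of get_tokenizer.
    tok = get_tokenizer("/path/to/tokenizer.model")

    ids = tok.encode("Hello world")   # list of token ids
    text = tok.decode(ids)            # back to a string
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hello"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # prompt ends with "Assistant:" because add_generation_prompt=True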

sglang/srt/layers/context_flashattention_nopad.py

@@ -3,6 +3,7 @@
 import torch
 import triton
 import triton.language as tl
+
 from sglang.srt.utils import wrap_kernel_launcher
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()

sglang/srt/layers/extend_attention.py

@@ -1,12 +1,19 @@
 import torch
 import triton
 import triton.language as tl
+
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.utils import wrap_kernel_launcher
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel(
     Q_Extend,
@@ -38,6 +45,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -89,6 +97,10 @@
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -125,6 +137,10 @@
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
             start_n + offs_n[None, :]
         )
@@ -175,6 +191,8 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    sm_scale=None,
+    logit_cap=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -196,7 +214,7 @@
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
-    sm_scale = 1.0 / (Lq**0.5)
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
@@ -270,6 +288,7 @@
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        logit_cap=logit_cap,
     )
     cached_kernel = wrap_kernel_launcher(_fwd_kernel)
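Note: the new `logit_cap` argument applies a soft clamp to the pre-softmax attention scores, `qk = logit_cap * tanh(qk / logit_cap)`, bounding them to (-logit_cap, logit_cap) while staying close to the identity near zero; the Triton helper rewrites tanh as `2 * sigmoid(2x) - 1`, which is the same function. Attention logit capping of this form is used by models such as Grok and Gemma 2, both added in this release. A small PyTorch sketch of the transform itself, for illustration only (not the kernel):

    import torch

    def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
        # Smoothly limits scores to (-logit_cap, logit_cap); a value <= 0
        # disables capping, matching the kernel's `if logit_cap > 0` guard.
        if logit_cap <= 0:
            return scores
        return logit_cap * torch.tanh(scores / logit_cap)

    qk = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
    print(soft_cap(qk, 30.0))  # large values are pulled into (-30, 30)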