sglang-0.1.21-py3-none-any.whl → sglang-0.1.22-py3-none-any.whl

Files changed (72)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +758 -0
  7. sglang/check_env.py +171 -0
  8. sglang/lang/backend/__init__.py +0 -0
  9. sglang/lang/backend/anthropic.py +77 -0
  10. sglang/lang/backend/base_backend.py +80 -0
  11. sglang/lang/backend/litellm.py +90 -0
  12. sglang/lang/backend/openai.py +438 -0
  13. sglang/lang/backend/runtime_endpoint.py +283 -0
  14. sglang/lang/backend/vertexai.py +149 -0
  15. sglang/lang/tracer.py +1 -1
  16. sglang/launch_server.py +1 -1
  17. sglang/launch_server_llavavid.py +1 -4
  18. sglang/srt/conversation.py +1 -1
  19. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  20. sglang/srt/layers/extend_attention.py +0 -39
  21. sglang/srt/layers/linear.py +869 -0
  22. sglang/srt/layers/quantization/__init__.py +49 -0
  23. sglang/srt/layers/quantization/fp8.py +662 -0
  24. sglang/srt/layers/radix_attention.py +31 -5
  25. sglang/srt/layers/token_attention.py +1 -51
  26. sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
  27. sglang/srt/managers/controller/infer_batch.py +47 -49
  28. sglang/srt/managers/controller/manager_multi.py +107 -100
  29. sglang/srt/managers/controller/manager_single.py +76 -96
  30. sglang/srt/managers/controller/model_runner.py +35 -23
  31. sglang/srt/managers/controller/tp_worker.py +127 -138
  32. sglang/srt/managers/detokenizer_manager.py +49 -5
  33. sglang/srt/managers/io_struct.py +36 -17
  34. sglang/srt/managers/tokenizer_manager.py +228 -125
  35. sglang/srt/memory_pool.py +19 -6
  36. sglang/srt/model_loader/model_loader.py +277 -0
  37. sglang/srt/model_loader/utils.py +260 -0
  38. sglang/srt/models/chatglm.py +1 -0
  39. sglang/srt/models/dbrx.py +1 -0
  40. sglang/srt/models/grok.py +1 -0
  41. sglang/srt/models/internlm2.py +317 -0
  42. sglang/srt/models/llama2.py +65 -16
  43. sglang/srt/models/llama_classification.py +1 -0
  44. sglang/srt/models/llava.py +1 -0
  45. sglang/srt/models/llavavid.py +1 -0
  46. sglang/srt/models/minicpm.py +1 -0
  47. sglang/srt/models/mixtral.py +1 -0
  48. sglang/srt/models/mixtral_quant.py +1 -0
  49. sglang/srt/models/qwen.py +1 -0
  50. sglang/srt/models/qwen2.py +6 -0
  51. sglang/srt/models/qwen2_moe.py +7 -4
  52. sglang/srt/models/stablelm.py +1 -0
  53. sglang/srt/openai_api/adapter.py +432 -0
  54. sglang/srt/openai_api/api_adapter.py +432 -0
  55. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  56. sglang/srt/openai_api/openai_protocol.py +207 -0
  57. sglang/srt/openai_api/protocol.py +208 -0
  58. sglang/srt/openai_protocol.py +17 -0
  59. sglang/srt/sampling_params.py +2 -0
  60. sglang/srt/server.py +113 -84
  61. sglang/srt/server_args.py +23 -15
  62. sglang/srt/utils.py +16 -117
  63. sglang/test/test_conversation.py +1 -1
  64. sglang/test/test_openai_protocol.py +1 -1
  65. sglang/test/test_programs.py +1 -1
  66. sglang/test/test_utils.py +2 -2
  67. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
  68. sglang-0.1.22.dist-info/RECORD +103 -0
  69. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  70. sglang-0.1.21.dist-info/RECORD +0 -82
  71. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  72. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
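The hunk below reproduces the largest of the newly added files. Judging by its +438 line count, it corresponds to entry 12 above, the new frontend backend module sglang/lang/backend/openai.py: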
@@ -0,0 +1,438 @@
+ import dataclasses
+ import logging
+ import time
+ import warnings
+ from typing import Callable, List, Optional, Union
+
+ import numpy as np
+
+ from sglang.lang.backend.base_backend import BaseBackend
+ from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
+ from sglang.lang.interpreter import StreamExecutor
+ from sglang.lang.ir import SglSamplingParams
+
+ try:
+     import openai
+     import tiktoken
+ except ImportError as e:
+     # Defer the ImportError until the backend is actually instantiated.
+     openai = tiktoken = e
+
+
+ logger = logging.getLogger("openai")
+
+
+ def create_logit_bias_int(tokenizer):
+     """Get a logit bias mask that steers sampling toward integer tokens."""
+     int_token_ids = []
+
+     tokens = tokenizer._mergeable_ranks
+     for token, token_id in tokens.items():
+         s = tokenizer.decode([token_id])
+         if all([c.isdigit() for c in s]) or s in [" "]:
+             int_token_ids.append(token_id)
+             if len(int_token_ids) >= 300:  # OpenAI limits logit_bias to 300 entries
+                 break
+     special_tokens = tokenizer._special_tokens
+     # Keep 299 ids so the <|endoftext|> entry below stays within the 300-entry limit.
+     mask = {t: 100 for t in int_token_ids[:299]}
+     mask[special_tokens["<|endoftext|>"]] = 100
+     return mask
+
+
+ INSTRUCT_MODEL_NAMES = [
+     "gpt-3.5-turbo-instruct",
+ ]
+
+
+ @dataclasses.dataclass
+ class TokenUsage:
+     prompt_tokens: int
+     completion_tokens: int
+
+     def reset(self):
+         self.prompt_tokens = self.completion_tokens = 0
+
+
+ class OpenAI(BaseBackend):
+     def __init__(
+         self,
+         model_name: str,
+         is_chat_model: Optional[bool] = None,
+         chat_template: Optional[ChatTemplate] = None,
+         is_azure: bool = False,
+         *args,
+         **kwargs,
+     ):
+         super().__init__()
+
+         if isinstance(openai, Exception):
+             raise openai
+
+         if is_azure:
+             self.client = openai.AzureOpenAI(*args, **kwargs)
+         else:
+             self.client = openai.OpenAI(*args, **kwargs)
+
+         self.model_name = model_name
+         try:
+             self.tokenizer = tiktoken.encoding_for_model(model_name)
+         except KeyError:
+             self.tokenizer = tiktoken.get_encoding("cl100k_base")
+         self.logit_bias_int = create_logit_bias_int(self.tokenizer)
+
+         self.chat_template = chat_template or get_chat_template_by_model_path(
+             model_name
+         )
+
+         if is_chat_model is not None:
+             self.is_chat_model = is_chat_model
+         else:
+             self.is_chat_model = model_name not in INSTRUCT_MODEL_NAMES
+
+         self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+         # Usage
+         self.token_usage = TokenUsage(0, 0)
+
+         # API speculative execution
+         # TODO(ying): This does not support multi-threading (run_batch)
+         self.spec_kwargs = {}
+         self.spec_format = []
+         self.spec_max_num_tries = 3
+
+     def get_chat_template(self):
+         return self.chat_template
+
+     def _prepare_spec_execution(
+         self,
+         sampling_params: SglSamplingParams,
+         num_api_spec_tokens: int,
+         spec_var_name: str,
+     ):
+         if "max_tokens" not in self.spec_kwargs:
+             self.spec_kwargs["max_tokens"] = num_api_spec_tokens
+         else:
+             assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
+
+         params = sampling_params.to_openai_kwargs()
+         for key, value in params.items():
+             if key in ["stop"]:
+                 continue
+             if key in ["max_tokens"]:
+                 warnings.warn(
+                     "The parameter max_tokens will be overwritten by the speculated number of tokens."
+                 )
+                 continue
+             if key not in self.spec_kwargs:
+                 self.spec_kwargs[key] = value
+             else:
+                 assert (
+                     value == self.spec_kwargs[key]
+                 ), "Sampling parameters must be consistent when API speculative execution is on."
+         self.spec_format.append(
+             {"text": "", "stop": params["stop"], "name": spec_var_name}
+         )
+         return "", {}
+
+     def generate(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+         spec_var_name: Optional[str] = None,
+     ):
+         if sampling_params.dtype is None:
+             if self.is_chat_model:
+                 if s.num_api_spec_tokens is None:
+                     if not s.text_.endswith(self.chat_prefix):
+                         raise RuntimeError(
+                             "This use case is not supported if API speculative execution is off. "
+                             "For OpenAI chat models, sgl.gen must come right after sgl.assistant. "
+                             "Example of enabling API speculative execution: @function(num_api_spec_tokens=128)."
+                         )
+                     prompt = s.messages_
+                 else:
+                     return self._prepare_spec_execution(
+                         sampling_params, s.num_api_spec_tokens, spec_var_name
+                     )
+             else:
+                 prompt = s.text_
+
+             kwargs = sampling_params.to_openai_kwargs()
+             comp = openai_completion(
+                 client=self.client,
+                 token_usage=self.token_usage,
+                 is_chat=self.is_chat_model,
+                 model=self.model_name,
+                 prompt=prompt,
+                 **kwargs,
+             )
+         elif sampling_params.dtype in [str, "str", "string"]:
+             assert (
+                 not self.is_chat_model
+             ), "Constrained types are not supported for chat models."
+             kwargs = sampling_params.to_openai_kwargs()
+             kwargs.pop("stop")
+             comp = openai_completion(
+                 client=self.client,
+                 token_usage=self.token_usage,
+                 is_chat=self.is_chat_model,
+                 model=self.model_name,
+                 prompt=s.text_ + '"',
+                 stop='"',
+                 **kwargs,
+             )
+             comp = '"' + comp + '"'
+         elif sampling_params.dtype in [int, "int"]:
+             assert (
+                 not self.is_chat_model
+             ), "Constrained types are not supported for chat models."
+             kwargs = sampling_params.to_openai_kwargs()
+             kwargs.pop("stop")
+             comp = openai_completion(
+                 client=self.client,
+                 token_usage=self.token_usage,
+                 is_chat=self.is_chat_model,
+                 model=self.model_name,
+                 prompt=s.text_,
+                 logit_bias=self.logit_bias_int,
+                 stop=[" "],
+                 **kwargs,
+             )
+         else:
+             raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
+
+         return comp, {}
+
+     def spec_fill(self, value: str):
+         assert self.is_chat_model
+         self.spec_format.append({"text": value, "stop": None, "name": None})
+
+     def spec_pattern_match(self, comp):
+         for i, term in enumerate(self.spec_format):
+             text = term["text"]
+             if text != "":
+                 if comp.startswith(text):
+                     comp = comp[len(text) :]
+                 else:
+                     return False
+             else:
+                 pos = comp.find(term["stop"])
+                 if pos != -1:
+                     term["text"] = comp[:pos]
+                     comp = comp[pos:]
+                 else:
+                     if i == len(self.spec_format) - 1:
+                         term["text"] = comp
+                     else:
+                         return False
+         return True
+
+     def role_end_generate(
+         self,
+         s: StreamExecutor,
+     ):
+         if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+             return
+
+         comp = ""
+         if not all(x["name"] is None for x in self.spec_format):
+             # TODO(ying): throw errors or warnings
+             for i in range(self.spec_max_num_tries):
+                 comp = openai_completion(
+                     client=self.client,
+                     token_usage=self.token_usage,
+                     is_chat=self.is_chat_model,
+                     model=self.model_name,
+                     prompt=s.messages_,
+                     **self.spec_kwargs,
+                 )
+                 if self.spec_pattern_match(comp):
+                     break
+
+         for term in self.spec_format:
+             s.text_ += term["text"]
+             name = term["name"]
+             if name is not None:
+                 s.variables[name] = term["text"]
+                 s.meta_info[name] = {}
+                 s.variable_event[name].set()
+
+         self.spec_kwargs = {}
+         self.spec_format = []
+
+     def generate_stream(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if sampling_params.dtype is None:
+             if self.is_chat_model:
+                 if not s.text_.endswith(self.chat_prefix):
+                     raise RuntimeError(
+                         "This use case is not supported. "
+                         "For OpenAI chat models, sgl.gen must come right after sgl.assistant."
+                     )
+                 prompt = s.messages_
+             else:
+                 prompt = s.text_
+
+             kwargs = sampling_params.to_openai_kwargs()
+             generator = openai_completion_stream(
+                 client=self.client,
+                 token_usage=self.token_usage,
+                 is_chat=self.is_chat_model,
+                 model=self.model_name,
+                 prompt=prompt,
+                 **kwargs,
+             )
+             return generator
+         else:
+             raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
+
+     def select(
+         self,
+         s: StreamExecutor,
+         choices: List[str],
+         temperature: float,
+     ):
+         if self.is_chat_model:
+             raise NotImplementedError(
+                 "select/choices is not supported for chat models. "
+                 "Please try a non-chat model such as gpt-3.5-turbo-instruct."
+             )
+
+         n_choices = len(choices)
+         token_ids = [self.tokenizer.encode(x) for x in choices]
+         scores = [0] * n_choices
+         valid = [len(x) > 0 for x in token_ids]
+         prompt_tokens = self.tokenizer.encode(s.text_)
+
+         max_len = max([len(x) for x in token_ids])
+         for step in range(max_len):
+             # Build a logit bias that restricts sampling to the next token
+             # of each still-valid choice.
+             logit_bias = {}
+             for i in range(n_choices):
+                 if valid[i]:
+                     logit_bias[token_ids[i][step]] = 100
+
+             # Call the API for a single token
+             ret = self.client.completions.create(
+                 model=self.model_name,
+                 prompt=prompt_tokens,
+                 logit_bias=logit_bias,
+                 max_tokens=1,
+                 temperature=temperature,
+             )
+             ret_str = ret.choices[0].text
+             ret_token = self.tokenizer.encode(ret_str)[0]
+             self.token_usage.prompt_tokens += ret.usage.prompt_tokens
+             self.token_usage.completion_tokens += ret.usage.completion_tokens
+
+             # TODO:
+             # 1. return logits as the scores
+             # 2. compute logits of the full choice
+             # 3. consider chunk-based decoding
+
+             # Update the valid set: a choice stays valid only if it matched
+             # the sampled token and still has tokens left.
+             hit = False
+             for i in range(n_choices):
+                 if valid[i]:
+                     if step == len(token_ids[i]) - 1:
+                         valid[i] = False
+
+                     if ret_token == token_ids[i][step]:
+                         scores[i] += 1
+                         hit = True
+                     else:
+                         valid[i] = False
+             assert hit
+
+             if np.sum(valid) <= 1:
+                 break
+
+             prompt_tokens.append(ret_token)
+
+         decision = choices[np.argmax(scores)]
+         return decision, scores, None, None
+
+
+ def openai_completion(
+     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+ ):
+     for attempt in range(retries):
+         try:
+             if is_chat:
+                 if "stop" in kwargs and kwargs["stop"] is None:
+                     kwargs.pop("stop")
+                 ret = client.chat.completions.create(messages=prompt, **kwargs)
+                 comp = ret.choices[0].message.content
+             else:
+                 ret = client.completions.create(prompt=prompt, **kwargs)
+                 if isinstance(prompt, (list, tuple)):
+                     comp = [c.text for c in ret.choices]
+                 else:
+                     comp = ret.choices[0].text
+
+             token_usage.prompt_tokens += ret.usage.prompt_tokens
+             token_usage.completion_tokens += ret.usage.completion_tokens
+             break
+         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
+             logger.error(f"OpenAI API error: {e}. Waiting 5 seconds...")
+             time.sleep(5)
+             if attempt == retries - 1:
+                 raise e
+         except Exception as e:
+             logger.error(f"Unexpected error: {e}.")
+             raise e
+
+     return comp
+
+
+ def openai_completion_stream(
+     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+ ):
+     for attempt in range(retries):
+         try:
+             if is_chat:
+                 if "stop" in kwargs and kwargs["stop"] is None:
+                     kwargs.pop("stop")
+                 generator = client.chat.completions.create(
+                     messages=prompt,
+                     stream=True,
+                     stream_options={"include_usage": True},
+                     **kwargs,
+                 )
+                 for ret in generator:
+                     if len(ret.choices) == 0:
+                         continue
+                     try:
+                         content = ret.choices[0].delta.content
+                     except IndexError:
+                         content = None
+                     yield content or "", {}
+             else:
+                 generator = client.completions.create(
+                     prompt=prompt,
+                     stream=True,
+                     stream_options={"include_usage": True},
+                     **kwargs,
+                 )
+                 for ret in generator:
+                     if len(ret.choices) == 0:
+                         continue
+                     content = ret.choices[0].text
+                     yield content or "", {}
+
+             # With include_usage, the final streamed chunk carries the token usage.
+             token_usage.prompt_tokens += ret.usage.prompt_tokens
+             token_usage.completion_tokens += ret.usage.completion_tokens
+             break
+         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
+             logger.error(f"OpenAI API error: {e}. Waiting 5 seconds...")
+             time.sleep(5)
+             if attempt == retries - 1:
+                 raise e
+         except Exception as e:
+             logger.error(f"Unexpected error: {e}.")
+             raise e
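
For context, here is a minimal sketch of how this backend is driven from the sglang frontend DSL. It assumes an OPENAI_API_KEY in the environment and uses the public sgl.function/sgl.gen API; the question text and variable names are illustrative, not part of the diff:

    import sglang as sgl

    # Route all sgl programs through the OpenAI backend defined above.
    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))

    @sgl.function
    def qa(s, question):
        s += sgl.user(question)
        # For chat models, sgl.gen must come right after sgl.assistant.
        s += sgl.assistant(sgl.gen("answer", max_tokens=64, temperature=0))

    state = qa.run(question="What is the capital of France?")
    print(state["answer"])

And a sketch of the API speculative execution path implemented by generate, _prepare_spec_execution, and role_end_generate: with num_api_spec_tokens set (as the RuntimeError message in generate suggests), consecutive sgl.gen calls inside a single assistant turn are served by one speculated completion and pattern-matched against the stop strings afterwards. The prompt and variable names are again illustrative:

    @sgl.function(num_api_spec_tokens=128)
    def character(s, name):
        s += sgl.user("Describe " + name + " in two short lines.")
        # One API call speculates both fields; spec_pattern_match splits them.
        s += sgl.assistant(
            "Job: " + sgl.gen("job", stop="\n") + "\nCity: " + sgl.gen("city", stop="\n")
        )

    state = character.run(name="Sherlock Holmes")
    print(state["job"], state["city"])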