sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/infer_batch.py

@@ -1,24 +1,36 @@
 from dataclasses import dataclass
-from enum import Enum, auto
+from enum import IntEnum, auto
 from typing import List
 
 import numpy as np
 import torch
+
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 
-class ForwardMode(Enum):
+class ForwardMode(IntEnum):
     PREFILL = auto()
     EXTEND = auto()
     DECODE = auto()
 
 
-class FinishReason(Enum):
-    LENGTH = auto()
+class FinishReason(IntEnum):
     EOS_TOKEN = auto()
+    LENGTH = auto()
     STOP_STR = auto()
 
+    @staticmethod
+    def to_str(reason):
+        if reason == FinishReason.EOS_TOKEN:
+            return None
+        elif reason == FinishReason.LENGTH:
+            return "length"
+        elif reason == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            return None
+
 
 class Req:
     def __init__(self, rid, input_text, input_ids):
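Note: the switch from Enum to IntEnum presumably lets the mode and finish-reason values serialize as plain integers when they cross the router's process boundaries. A minimal usage sketch for the new FinishReason.to_str helper; the finish_reason_field wrapper below is hypothetical, not part of this release:

    from sglang.srt.managers.router.infer_batch import FinishReason

    def finish_reason_field(reason):
        # EOS_TOKEN -> None, LENGTH -> "length", STOP_STR -> "stop"
        return {"finish_reason": FinishReason.to_str(reason)}

    print(finish_reason_field(FinishReason.LENGTH))     # {'finish_reason': 'length'}
    print(finish_reason_field(FinishReason.EOS_TOKEN))  # {'finish_reason': None}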
@@ -30,6 +42,7 @@ class Req:
         # Since jump forward may retokenize the prompt with partial outputs,
         # we maintain the original prompt length to report the correct usage.
         self.prompt_tokens = len(input_ids)
+
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
@@ -40,11 +53,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Sampling parameters
         self.sampling_params = None
-        self.return_logprob = False
-        self.logprob_start_len = 0
         self.stream = False
 
+        # Check finish
         self.tokenizer = None
         self.finished = False
         self.finish_reason = None
@@ -54,11 +67,17 @@ class Req:
         self.prefix_indices = []
         self.last_node = None
 
-        self.logprob = None
-        self.token_logprob = None
-        self.normalized_logprob = None
-
-        # For constrained decoding
+        # Logprobs
+        self.return_logprob = False
+        self.logprob_start_len = 0
+        self.top_logprobs_num = 0
+        self.normalized_prompt_logprob = None
+        self.prefill_token_logprobs = None
+        self.decode_token_logprobs = None
+        self.prefill_top_logprobs = None
+        self.decode_top_logprobs = None
+
+        # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
@@ -77,6 +96,9 @@ class Req:
         )
         if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
+        if self.input_text is None:
+            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
+            self.input_text = self.tokenizer.decode(self.input_ids)
         new_input_string = (
             self.input_text
             + self.output_and_jump_forward_str
@@ -159,7 +181,10 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
+
+    # for processing logprobs
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
 
     # for multimodal
     pixel_values: List[torch.Tensor] = None
@@ -229,12 +254,11 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            if not self.tree_cache.disable:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
             if out_cache_loc is None:
-                print("Prefill out of memory. This should nerver happen.")
+                print("Prefill out of memory. This should never happen.")
                 self.tree_cache.pretty_print()
                 exit()
 
@@ -245,10 +269,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias
-        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
        for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias
 
         # Set fields
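The per-batch logit-bias tensor is now allocated lazily: the dense (bs, vocab_size) buffer is only created when at least one request constrains decoding to integer tokens. A standalone sketch of the same pattern; build_logit_bias is illustrative, not a function in this release:

    import torch

    def build_logit_bias(reqs, vocab_size, int_token_logit_bias, device="cuda"):
        # Returns None when no request is int-constrained, so callers can skip add_().
        logit_bias = None
        for i, req in enumerate(reqs):
            if req.sampling_params.dtype == "int":
                if logit_bias is None:
                    logit_bias = torch.zeros(
                        (len(reqs), vocab_size), dtype=torch.float32, device=device
                    )
                logit_bias[i] = int_token_logit_bias
        return logit_bias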
@@ -266,6 +294,7 @@ class Batch:
         self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
         self.temperatures = torch.tensor(
             [r.sampling_params.temperature for r in reqs],
@@ -295,8 +324,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
-        if not self.tree_cache.disable:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
@@ -310,27 +339,27 @@ class Batch:
         )
 
         retracted_reqs = []
-        seq_lens_np = self.seq_lens.cpu().numpy()
-        req_pool_indices_np = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while self.token_to_kv_pool.available_size() < len(self.reqs):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
-            self.tree_cache.dec_ref_counter(req.last_node)
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = len(req.prefix_indices)
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[idx]
+            ][last_uncached_pos : seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
+
+            self.tree_cache.dec_lock_ref(req.last_node)
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
             req.regex_fsm_state = 0
 
-            # TODO: apply more fine-grained retraction
-
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_np[idx]
-            ][: seq_lens_np[idx]]
-            self.token_to_kv_pool.free(token_indices)
-
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
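Retraction now releases only the uncached suffix of each request through the pool's reference-counting interface (dec_refs) and unlocks the radix-tree node with dec_lock_ref, rather than freeing the whole sequence. A toy reference-counted pool illustrating the dec_refs semantics assumed here; this is not the actual TokenToKVPool implementation:

    import torch

    class ToyRefCountedPool:
        # Slots become reusable only when their reference count drops to zero,
        # so KV entries shared with cached prefixes stay alive after retraction.
        def __init__(self, size):
            self.ref_count = torch.zeros(size, dtype=torch.int32)

        def add_refs(self, indices):          # indices assumed unique
            self.ref_count[indices] += 1

        def dec_refs(self, indices):          # usable as an eviction callback
            self.ref_count[indices] -= 1
            return int((self.ref_count[indices] == 0).sum())

        def available_size(self):
            return int((self.ref_count == 0).sum())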
@@ -349,20 +378,18 @@ class Batch:
                 if len(jump_forward_str) <= 1:
                     continue
 
-                # insert the old request into tree_cache
-                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                 if req_pool_indices_cpu is None:
-                    req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
-                req_pool_idx = req_pool_indices_cpu[i]
-                indices = self.req_to_token_pool.req_to_token[
-                    req_pool_idx, : len(token_ids_in_memory)
-                ]
-                prefix_len = self.tree_cache.insert(
-                    token_ids_in_memory, indices.clone()
+                    req_pool_indices_cpu = self.req_pool_indices.tolist()
+
+                # insert the old request into tree_cache
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )
-                self.token_to_kv_pool.free(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+
+                # unlock the last node
+                self.tree_cache.dec_lock_ref(req.last_node)
 
                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
@@ -391,7 +418,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
         if self.out_cache_loc is None:
-            print("Decode out of memory. This should nerver happen.")
+            print("Decode out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -415,6 +442,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
         for item in [
@@ -425,9 +453,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-            setattr(self, item, getattr(self, item)[new_indices])
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])
 
-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
@@ -439,6 +470,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
         for item in [
@@ -447,17 +479,34 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-            setattr(
-                self, item, torch.concat([getattr(self, item), getattr(other, item)])
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
             )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-        logits.add_(self.logit_bias)
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
 
         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
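In sample(), the bias is now applied only when the tensor exists; temperature scaling is unchanged. A self-contained sketch of the same post-processing order, with top-k/top-p and the regex-constrained path omitted:

    import torch

    def sample_next_tokens(logits, temperatures, logit_bias=None):
        # logits: (bs, vocab_size); temperatures: (bs, 1) of positive floats
        logits = logits.contiguous()
        logits.div_(temperatures)
        if logit_bias is not None:  # may be None after the lazy-allocation change
            logits.add_(logit_bias)
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)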
sglang/srt/managers/router/manager.py

@@ -4,7 +4,8 @@ import logging
 import uvloop
 import zmq
 import zmq.asyncio
-from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
+
+from sglang.global_config import global_config
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback
@@ -29,7 +30,7 @@ class RouterManager:
         self.recv_reqs = []
 
         # Init some configs
-        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+        self.request_dependency_time = global_config.request_dependency_time
 
     async def loop_for_forward(self):
         while True:
@@ -41,12 +42,16 @@ class RouterManager:
                 self.send_to_detokenizer.send_pyobj(obj)
 
             # async sleep for receiving the subsequent request and avoiding cache miss
+            slept = False
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
-                    await asyncio.sleep(self.extend_dependency_time)
+                    if self.request_dependency_time > 0:
+                        slept = True
+                        await asyncio.sleep(self.request_dependency_time)
 
-            await asyncio.sleep(0.0006)
+            if not slept:
+                await asyncio.sleep(0.0006)
 
     async def loop_for_recv_requests(self):
         while True:
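The forward loop now sleeps for request_dependency_time only when some request finished in this step, and skips the short idle sleep in that case. A condensed sketch of the control flow, with step and send standing in for the model RPC call and the detokenizer socket (both hypothetical parameters):

    import asyncio

    async def loop_for_forward(step, send, request_dependency_time):
        while True:
            out_pyobjs = await step()          # one scheduler iteration
            for obj in out_pyobjs:
                send(obj)

            slept = False
            if out_pyobjs and any(obj.finished for obj in out_pyobjs):
                if request_dependency_time > 0:
                    slept = True
                    # give dependent follow-up requests time to arrive (better prefix-cache reuse)
                    await asyncio.sleep(request_dependency_time)
            if not slept:
                await asyncio.sleep(0.0006)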
@@ -55,9 +60,7 @@
 
 
 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -65,7 +68,7 @@ def start_router_process(
     )
 
     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())