sglang 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (57)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +33 -13
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/ir.py +1 -1
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server.py +2 -1
  13. sglang/srt/constrained/fsm_cache.py +15 -3
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/hf_transformers_utils.py +2 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  18. sglang/srt/layers/extend_attention.py +1 -0
  19. sglang/srt/layers/logits_processor.py +114 -54
  20. sglang/srt/layers/radix_attention.py +2 -1
  21. sglang/srt/layers/token_attention.py +1 -0
  22. sglang/srt/managers/detokenizer_manager.py +5 -1
  23. sglang/srt/managers/io_struct.py +12 -0
  24. sglang/srt/managers/router/infer_batch.py +70 -33
  25. sglang/srt/managers/router/manager.py +7 -2
  26. sglang/srt/managers/router/model_rpc.py +116 -73
  27. sglang/srt/managers/router/model_runner.py +121 -155
  28. sglang/srt/managers/router/radix_cache.py +46 -38
  29. sglang/srt/managers/tokenizer_manager.py +56 -11
  30. sglang/srt/memory_pool.py +5 -14
  31. sglang/srt/model_config.py +7 -0
  32. sglang/srt/models/commandr.py +376 -0
  33. sglang/srt/models/dbrx.py +413 -0
  34. sglang/srt/models/dbrx_config.py +281 -0
  35. sglang/srt/models/gemma.py +22 -20
  36. sglang/srt/models/llama2.py +23 -21
  37. sglang/srt/models/llava.py +12 -10
  38. sglang/srt/models/mixtral.py +27 -25
  39. sglang/srt/models/qwen.py +23 -21
  40. sglang/srt/models/qwen2.py +23 -21
  41. sglang/srt/models/stablelm.py +292 -0
  42. sglang/srt/models/yivl.py +6 -5
  43. sglang/srt/openai_api_adapter.py +356 -0
  44. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  45. sglang/srt/sampling_params.py +2 -0
  46. sglang/srt/server.py +68 -439
  47. sglang/srt/server_args.py +76 -49
  48. sglang/srt/utils.py +88 -32
  49. sglang/srt/weight_utils.py +402 -0
  50. sglang/test/test_programs.py +8 -7
  51. sglang/test/test_utils.py +196 -8
  52. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/METADATA +13 -15
  53. sglang-0.1.15.dist-info/RECORD +69 -0
  54. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/WHEEL +1 -1
  55. sglang-0.1.13.dist-info/RECORD +0 -63
  56. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  57. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py

@@ -6,11 +6,15 @@ import warnings
 from concurrent.futures import ThreadPoolExecutor
 from typing import List

-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+try:
+    from vllm.logger import _default_handler as vllm_default_logger
+except ImportError:
+    from vllm.logger import logger as vllm_default_logger
+
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -31,13 +35,15 @@ from sglang.srt.utils import (
     is_multimodal_model,
     set_random_seed,
 )
-from vllm.logger import _default_handler as vllm_default_handler
+

 logger = logging.getLogger("model_rpc")
+vllm_default_logger.setLevel(logging.WARN)
+logging.getLogger("vllm.utils").setLevel(logging.WARN)


-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
@@ -50,9 +56,6 @@ class ModelRpcServer(rpyc.Service):
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_handler.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )

         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -61,7 +64,7 @@
             context_length=server_args.context_length,
         )

-        # for model end global settings
+        # For model end global settings
         server_args_dict = {
             "enable_flashinfer": server_args.enable_flashinfer,
             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
@@ -90,7 +93,6 @@ class ModelRpcServer(rpyc.Service):
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
-        self.eos_token_id = self.tokenizer.eos_token_id
         self.max_total_num_token = self.model_runner.max_total_num_token
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
@@ -111,10 +113,11 @@ class ModelRpcServer(rpyc.Service):
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
         )
-        logger.info(server_args.get_optional_modes_logging())
+        if self.tp_rank == 0:
+            logger.info(f"server_args: {server_args.print_mode_args()}")

         # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -161,7 +164,7 @@ class ModelRpcServer(rpyc.Service):
             logger.info("Cache flushed successfully!")
         else:
             warnings.warn(
-                "Cache not flushed because there are pending requests. "
+                f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.forward_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
@@ -262,6 +265,7 @@ class ModelRpcServer(rpyc.Service):
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

@@ -348,6 +352,7 @@ class ModelRpcServer(rpyc.Service):
                     # Undo the insertion
                     delta = self.tree_cache.dec_ref_counter(req.last_node)
                     available_size += delta
+                    break
                 else:
                     # Add this request to the running batch
                     self.token_to_kv_pool.add_refs(req.prefix_indices)
@@ -356,7 +361,8 @@ class ModelRpcServer(rpyc.Service):
                         req.extend_input_len + req.max_new_tokens()
                     )
                     new_batch_input_tokens += req.extend_input_len
-
+            else:
+                break
         if len(can_run_list) == 0:
             return None

@@ -380,12 +386,12 @@ class ModelRpcServer(rpyc.Service):
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            logger.debug(
-                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            )
+            #logger.debug(
+            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+            #)

         new_batch = Batch.init_new(
             can_run_list,
@@ -402,53 +408,63 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )

-        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (
-                prefill_logprobs,
-                normalized_logprobs,
+                prefill_token_logprobs,
+                normalized_prompt_logprobs,
+                prefill_top_logprobs,
+                decode_top_logprobs,
                 last_logprobs,
-            ) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
-            )
-            if prefill_logprobs is not None:
-                logprobs = prefill_logprobs.cpu().tolist()
-                normalized_logprobs = normalized_logprobs.cpu().tolist()
+            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            if prefill_token_logprobs is not None:
+                prefill_token_logprobs = prefill_token_logprobs.tolist()
+                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

             next_token_ids, _ = batch.sample(logits)
-            next_token_ids = next_token_ids.cpu().tolist()
+
+            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
+            if last_logprobs is not None:
+                last_token_logprobs = (
+                    last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
+                )
+
+            next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = last_logprobs = None
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )

         # Check finish condition
         pt = 0
-        for i, req in enumerate(reqs):
+        for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()

-            if logprobs is not None:
-                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
-                req.normalized_logprob = normalized_logprobs[i]
+            if req.return_logprob:
+                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
-                # will be ignored.
-                prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
-                token_logprobs = req.logprob + [last_logprobs[i]]
-                req.token_logprob = list(zip(token_ids, token_logprobs))
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                req.prefill_token_logprobs = list(
+                    zip(
+                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                        req.input_ids[-req.extend_input_len + 1 :],
+                    )
+                )
+                if req.logprob_start_len == 0:
+                    req.prefill_token_logprobs = [
+                        (None, req.input_ids[0])
+                    ] + req.prefill_token_logprobs
+                req.decode_token_logprobs = [
+                    (last_token_logprobs[i], next_token_ids[i])
+                ]
+
+            if req.top_logprobs_num > 0:
+                req.prefill_top_logprobs = prefill_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
-                pt += req.extend_input_len
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                req.decode_top_logprobs = [decode_top_logprobs[i]]
+
+            pt += req.extend_input_len

         self.handle_finished_requests(batch)

@@ -497,29 +513,33 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()

         # Forward
-        logits, (_, _, last_logprobs) = self.model_runner.forward(
-            batch,
-            ForwardMode.DECODE,
-            batch.return_logprob,
-        )
+        logits, (
+            _,
+            _,
+            _,
+            decode_top_logprobs,
+            last_logprobs,
+        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.cpu().tolist()
+        next_token_ids = next_token_ids.tolist()

         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[
-                torch.arange(len(reqs)), next_token_ids
+            new_token_logprobs = last_logprobs[
+                torch.arange(len(batch.reqs)), next_token_ids
             ].tolist()

         # Check finish condition
-        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_tok_id)
+            req.output_ids.append(next_token_id)
             req.check_finished()

-            if last_logprobs is not None:
-                req.token_logprob.append((next_tok_id, last_logprobs[i]))
+            if req.return_logprob:
+                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+
+            if req.top_logprobs_num > 0:
+                req.decode_top_logprobs.append(decode_top_logprobs[i])

         self.handle_finished_requests(batch)

@@ -529,6 +549,7 @@ class ModelRpcServer(rpyc.Service):
         output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
+        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished = []
         finished_indices = []
@@ -555,6 +576,9 @@ class ModelRpcServer(rpyc.Service):
             output_skip_special_tokens.append(
                 req.sampling_params.skip_special_tokens
             )
+            output_spaces_between_special_tokens.append(
+                req.sampling_params.spaces_between_special_tokens
+            )

             meta_info = {
                 "prompt_tokens": req.prompt_tokens,
@@ -562,11 +586,22 @@ class ModelRpcServer(rpyc.Service):
                 + len(req.output_ids)
                 - req.prompt_tokens,
                 "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                "finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
             }
             if req.return_logprob:
-                meta_info["prompt_logprob"] = req.logprob
-                meta_info["token_logprob"] = req.token_logprob
-                meta_info["normalized_prompt_logprob"] = req.normalized_logprob
+                (
+                    meta_info["prefill_token_logprobs"],
+                    meta_info["decode_token_logprobs"],
+                    meta_info["prefill_top_logprobs"],
+                    meta_info["decode_top_logprobs"],
+                    meta_info["normalized_prompt_logprob"],
+                ) = (
+                    req.prefill_token_logprobs,
+                    req.decode_token_logprobs,
+                    req.prefill_top_logprobs,
+                    req.decode_top_logprobs,
+                    req.normalized_prompt_logprob,
+                )
             output_meta_info.append(meta_info)
             output_finished.append(req.finished)

@@ -579,6 +614,7 @@ class ModelRpcServer(rpyc.Service):
                     output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
+                    output_spaces_between_special_tokens,
                     output_meta_info,
                     output_finished,
                 )
@@ -587,7 +623,7 @@ class ModelRpcServer(rpyc.Service):
         # Remove finished reqs
        if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
                 req_pool_idx = req_pool_indices_cpu[i]
@@ -598,7 +634,7 @@ class ModelRpcServer(rpyc.Service):
                     token_ids[:seq_len], indices.clone()
                 )

-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)

@@ -609,14 +645,19 @@ class ModelRpcServer(rpyc.Service):
                 batch.reqs = []


+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs):
         tp_size = server_args.tp_size

         if tp_size == 1:
             # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args
+            )

             # Wrap functions
             def async_wrap(f):
@@ -630,14 +671,16 @@ class ModelRpcClient:
             with ThreadPoolExecutor(tp_size) as executor:
                 # Launch model processes
                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.model_servers = [x[0] for x in rets]
+                self.remote_services = [x[0] for x in rets]
                 self.procs = [x[1] for x in rets]

                 # Init model
                 def init_model(i):
-                    return self.model_servers[i].init_model(i, server_args, port_args)
+                    return self.remote_services[i].ModelRpcServer(
+                        i, server_args, port_args
+                    )

-                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
+                self.model_servers = executor.map(init_model, range(tp_size))

                 # Wrap functions
                 def async_wrap(func_name):
@@ -655,7 +698,7 @@ class ModelRpcClient:

 def _init_service(port):
     t = ThreadedServer(
-        ModelRpcServer(),
+        ModelRpcService(),
         port=port,
         protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
     )
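
Note on the RPC restructuring above: in 0.1.15 the worker no longer subclasses rpyc.Service with an exposed_init_model method. Instead, ModelRpcServer is a plain class that a thin ModelRpcService wrapper exposes (exposed_ModelRpcServer = ModelRpcServer), so the same __init__ runs whether the object is built in-process (tp_size == 1) or remotely through conn.root (tp_size > 1). The following is a minimal, self-contained sketch of that rpyc pattern only; Worker, WorkerService, and port 18861 are hypothetical names chosen for illustration, not sglang's actual classes or configuration.

import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer


class Worker:
    """Plain class (like ModelRpcServer after this change): no rpyc base class."""

    def __init__(self, rank: int):
        self.rank = rank

    def exposed_step(self, x: int) -> int:
        # The "exposed_" prefix is what rpyc's default protocol lets clients call.
        return x + self.rank


class WorkerService(rpyc.Service):
    # Mirrors `exposed_ModelRpcServer = ModelRpcServer`: the class itself is the
    # exposed attribute, so a client can instantiate it remotely via conn.root.
    exposed_Worker = Worker


def serve(port: int):
    # Same server wiring as _init_service in the diff: a ThreadedServer around the Service.
    ThreadedServer(WorkerService(), port=port,
                   protocol_config={"allow_pickle": True}).start()


if __name__ == "__main__":
    port = 18861  # hypothetical port for this sketch
    threading.Thread(target=serve, args=(port,), daemon=True).start()
    time.sleep(1.0)  # crude wait for the server thread to bind

    # Local path (tp_size == 1 in the diff): construct the class in-process.
    local = WorkerService().exposed_Worker(rank=0)
    assert local.exposed_step(1) == 1

    # Remote path (tp_size > 1): construct the worker on the server via the netref.
    conn = rpyc.connect("localhost", port, config={"allow_pickle": True})
    remote = conn.root.Worker(rank=1)  # resolves to exposed_Worker on the service
    print(remote.step(41))             # executed remotely; prints 42

As the new ModelRpcClient shows, the payoff of this wrapper is that one constructor serves both the single-GPU in-process case and the multi-rank remote case, with only the thin Service differing between them.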