sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/bench_serving.py +11 -3
  3. sglang/lang/backend/openai.py +10 -0
  4. sglang/srt/configs/model_config.py +11 -2
  5. sglang/srt/constrained/xgrammar_backend.py +6 -0
  6. sglang/srt/layers/attention/__init__.py +0 -1
  7. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  9. sglang/srt/layers/logits_processor.py +30 -2
  10. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
  11. sglang/srt/layers/moe/topk.py +14 -0
  12. sglang/srt/layers/quantization/fp8.py +42 -2
  13. sglang/srt/layers/quantization/fp8_kernel.py +91 -18
  14. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  15. sglang/srt/managers/io_struct.py +29 -8
  16. sglang/srt/managers/schedule_batch.py +22 -15
  17. sglang/srt/managers/schedule_policy.py +1 -1
  18. sglang/srt/managers/scheduler.py +71 -34
  19. sglang/srt/managers/session_controller.py +102 -27
  20. sglang/srt/managers/tokenizer_manager.py +95 -55
  21. sglang/srt/managers/tp_worker.py +7 -0
  22. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  23. sglang/srt/model_executor/forward_batch_info.py +42 -3
  24. sglang/srt/model_executor/model_runner.py +4 -6
  25. sglang/srt/model_loader/loader.py +22 -11
  26. sglang/srt/models/gemma2.py +19 -0
  27. sglang/srt/models/llama.py +13 -2
  28. sglang/srt/models/llama_eagle.py +132 -0
  29. sglang/srt/openai_api/adapter.py +79 -2
  30. sglang/srt/openai_api/protocol.py +50 -0
  31. sglang/srt/sampling/sampling_params.py +9 -2
  32. sglang/srt/server.py +45 -39
  33. sglang/srt/server_args.py +17 -30
  34. sglang/srt/speculative/spec_info.py +19 -0
  35. sglang/srt/utils.py +62 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
  38. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
  39. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/scheduler.py
+++ b/sglang/srt/managers/scheduler.py
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import psutil
 import setproctitle
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -88,7 +90,7 @@ from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-# Test retract decode
+# Test retract decode for debugging purposes
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")

@@ -127,12 +129,12 @@ class Scheduler:
         )

         if server_args.skip_tokenizer_init:
-            # Directly send to the tokenizer/api
+            # Directly send to the TokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name
             )
         else:
-            # Send to the detokenizer
+            # Send to the DetokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name
             )
@@ -383,7 +385,8 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-            if self.server_args.enable_dp_attention:
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
                 batch = self.prepare_dp_attn_batch(batch)

             self.cur_batch = batch
@@ -392,7 +395,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -409,12 +412,13 @@ class Scheduler:
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
+
             if batch:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap scheduler.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -424,19 +428,21 @@ class Scheduler:
                     self.process_batch_result(tmp_batch, None)

             if self.last_batch:
+                # Process the results of the last batch
                 tmp_batch, tmp_result = result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch

-    def recv_requests(self):
+    def recv_requests(self) -> List[Req]:
+        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []

@@ -468,9 +474,6 @@ class Scheduler:
                 self.send_to_tokenizer.send_pyobj(
                     UpdateWeightFromDiskReqOutput(success, message)
                 )
-            elif isinstance(recv_req, GetWeightsByNameReqInput):
-                parameter = self.get_weights_by_name(recv_req)
-                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
             elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
                 success, message = self.init_weights_update_group(recv_req)
                 self.send_to_tokenizer.send_pyobj(
@@ -481,6 +484,11 @@ class Scheduler:
                 self.send_to_tokenizer.send_pyobj(
                     UpdateWeightsFromDistributedReqOutput(success, message)
                 )
+            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
+                success, message = self.update_weights_from_tensor(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    UpdateWeightsFromTensorReqOutput(success, message)
+                )
             elif isinstance(recv_req, GetWeightsByNameReqInput):
                 parameter = self.get_weights_by_name(recv_req)
                 self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
@@ -490,8 +498,10 @@ class Scheduler:
             else:
                 self.stop_profile()
         elif isinstance(recv_req, OpenSessionReqInput):
-            session_id = self.open_session(recv_req)
-            self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+            session_id, success = self.open_session(recv_req)
+            self.send_to_tokenizer.send_pyobj(
+                OpenSessionReqOutput(session_id=session_id, success=success)
+            )
         elif isinstance(recv_req, CloseSessionReqInput):
             self.close_session(recv_req)
         else:
@@ -502,7 +512,11 @@ class Scheduler:
         recv_req: TokenizedGenerateReqInput,
     ):
         # Create a new request
-        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+        if (
+            recv_req.session_params is None
+            or recv_req.session_params.id is None
+            or recv_req.session_params.id not in self.sessions
+        ):

             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
@@ -520,18 +534,22 @@ class Scheduler:
                 stream=recv_req.stream,
                 lora_path=recv_req.lora_path,
                 input_embeds=recv_req.input_embeds,
+                eos_token_ids=self.model_config.hf_eos_token_id,
             )
             req.tokenizer = self.tokenizer

-            if recv_req.session_id is not None:
+            if (
+                recv_req.session_params is not None
+                and recv_req.session_params.id is not None
+            ):
                 req.finished_reason = FINISH_ABORT(
-                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
                 self.waiting_queue.append(req)
                 return
         else:
-            # Create a new request from a previsou session
-            session = self.sessions[recv_req.session_id]
+            # Create a new request from a previous session
+            session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
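
Note: the hunks above route generation through the new session_params object instead of a bare session_id. As a hedged sketch of the client side (the HTTP payload layout and endpoint names are assumptions, not part of this diff; only the .id/.rid/.offset/.replace accesses are established above), a session-routed request might look like:

import requests  # illustrative client; assumes a server on localhost:30000

# Open a session (OpenSessionReqInput carries capacity_of_str_len / session_id).
requests.post(
    "http://localhost:30000/open_session",
    json={"capacity_of_str_len": 1000, "session_id": "sess-1"},
)

# Route a generation through the session.
requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hello",
        "session_params": {
            "id": "sess-1",    # must name an opened session, else FINISH_ABORT
            "rid": None,       # earlier request in this session to append to
            "offset": 0,       # splice point into the previous input_ids
            "replace": False,  # abort and replace the branch rooted at rid
        },
        "sampling_params": {"max_new_tokens": 32},
    },
)
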
@@ -565,7 +583,7 @@ class Scheduler:

         if req.logprob_start_len == -1:
             # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(recv_req.input_ids) - 1
+            req.logprob_start_len = len(req.origin_input_ids) - 1

         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
@@ -589,12 +607,15 @@ class Scheduler:
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
+            or req.sampling_params.ebnf is not None
         ):
             assert self.grammar_backend is not None
             if req.sampling_params.json_schema is not None:
                 key = ("json", req.sampling_params.json_schema)
             elif req.sampling_params.regex is not None:
                 key = ("regex", req.sampling_params.regex)
+            elif req.sampling_params.ebnf is not None:
+                key = ("ebnf", req.sampling_params.ebnf)

             req.grammar = self.grammar_backend.get_cached_value(key)
             if not req.grammar:
629
650
  self.waiting_queue.append(req)
630
651
 
631
652
  def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
632
- if isinstance(self.tree_cache, RadixCache):
633
- self.tree_cache_metrics["total"] += (
634
- adder.log_input_tokens + adder.log_hit_tokens
635
- ) / 10**9
636
- self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
637
- tree_cache_hit_rate = (
638
- self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
639
- )
640
- else:
641
- tree_cache_hit_rate = 0.0
653
+ self.tree_cache_metrics["total"] += (
654
+ adder.log_input_tokens + adder.log_hit_tokens
655
+ ) / 10**9
656
+ self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
657
+ tree_cache_hit_rate = (
658
+ self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
659
+ )
642
660
 
643
661
  num_used = self.max_total_num_tokens - (
644
662
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -807,6 +825,8 @@ class Scheduler:
             if res == AddReqResult.NO_TOKEN:
                 self.batch_is_full = True
                 break
+            if self.server_args.prefill_only_one_req:
+                break

         # Update waiting queue
         can_run_list = adder.can_run_list
@@ -1460,6 +1480,17 @@ class Scheduler:
             logger.error(message)
         return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        """Update the online model parameter from tensors."""
+        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
+        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
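
Note: a sketch of driving the new tensor-based weight update end to end. The Engine-level entry point and its signature below are assumptions, as are the model path and parameter shape; this diff establishes only the scheduler handler and the UpdateWeightsFromTensorReqInput/Output structs:

import torch
import sglang as sgl

# Hypothetical driver code for the new update path.
llm = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")

# The tensor must match the named parameter's shape/dtype on every TP rank.
new_weight = torch.zeros(2048, dtype=torch.bfloat16, device="cuda")
llm.update_weights_from_tensor("model.layers.0.input_layernorm.weight", new_weight)
# On success the scheduler flushes the radix/KV cache (the assert above), so
# prefixes cached under the old weights are never reused.
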
@@ -1478,16 +1509,20 @@ class Scheduler:
             )
         logger.info("Profiler is done")

-    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
+            return session_id, False
+        elif session_id is None:
+            logger.warning(f"session id is None, cannot open.")
+            return session_id, False
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
             )
-            return session_id
+            return session_id, True

     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
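
Note: open_session now returns an explicit success flag instead of a bare id. An illustrative sketch of the new contract (constructor form is assumed; the field names follow the OpenSessionReqInput usage in the hunk above):

sid, ok = scheduler.open_session(
    OpenSessionReqInput(capacity_of_str_len=1000, session_id="s1")
)
assert ok                        # first open succeeds
sid, ok = scheduler.open_session(
    OpenSessionReqInput(capacity_of_str_len=1000, session_id="s1")
)
assert not ok                    # duplicate id is rejected; sid still echoes "s1"
sid, ok = scheduler.open_session(
    OpenSessionReqInput(capacity_of_str_len=1000, session_id=None)
)
assert sid is None and not ok    # a missing id is now rejected as well
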
@@ -1512,18 +1547,20 @@ def run_scheduler_process(
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

+    # Configue the logger
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+    suppress_other_loggers()

-    # set cpu affinity to this gpu process
+    # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    suppress_other_loggers()
     parent_process = psutil.Process().parent()

+    # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send(
--- a/sglang/srt/managers/session_controller.py
+++ b/sglang/srt/managers/session_controller.py
@@ -10,41 +10,116 @@
 # limitations under the License.
 # ==============================================================================

+import logging
 import uuid
+from typing import Dict, Optional

 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+from sglang.srt.managers.schedule_batch import Req
+
+
+class SessionReqNode:
+    def __init__(self, req, parent=None, childs=None):
+        self.req = req
+        self.parent = parent
+        if parent is not None:
+            parent.childs.append(self)
+        self.childs = [] if not childs else childs
+
+    def clear_childs(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+        self.childs = []
+
+    def clear(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+        del req_dict[self.req.rid]
+
+    def abort(self):
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+
+    def __str__(self):
+        return self._str_helper(self.req.rid)
+
+    def _str_helper(self, prefix=""):
+        if len(self.childs) == 0:
+            return prefix + "\n"
+        else:
+            origin_prefix = prefix
+            prefix += " -- " + self.childs[0].req.rid
+            ret = self.childs[0]._str_helper(prefix)
+            for child in self.childs[1:]:
+                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                ret += child._str_helper(prefix)
+            return ret


 class Session:
-    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
         self.session_id = session_id if session_id is not None else uuid.uuid4().hex
         self.capacity_of_str_len = capacity_of_str_len
-        self.reqs: List[Req] = []
+        self.req_nodes: Dict[str, SessionReqNode] = {}

     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        if req.session_rid is not None:
-            while len(self.reqs) > 0:
-                if self.reqs[-1].rid == req.session_rid:
-                    break
-                self.reqs = self.reqs[:-1]
+        assert req.session_params is not None
+        session_params = req.session_params
+
+        last_req_node = None
+        last_req = None
+        abort = False
+        if session_params.replace:
+            if session_params.rid is None:
+                for _, req_node in self.req_nodes.items():
+                    req_node.clear(self.req_nodes)
+            else:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req_node.abort()
+                    last_req = last_req_node.req
+                    last_req_node.clear_childs(self.req_nodes)
         else:
-            self.reqs = []
-        if len(self.reqs) > 0:
+            if session_params.rid is not None:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req = last_req_node.req
+                    if not last_req.finished():
+                        logging.warning(
+                            "The request in a session is appending to a request that hasn't finished."
+                        )
+                        abort = True
+
+        if last_req is not None:
+            # trim bos token if it is an append
+            if req.input_ids[0] == tokenizer.bos_token_id:
+                req.input_ids = req.input_ids[1:]
+
             input_ids = (
-                self.reqs[-1].origin_input_ids
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids = input_ids[: session_params.offset] + req.input_ids
+            else:
+                input_ids += req.input_ids
             input_ids_unpadded = (
-                self.reqs[-1].origin_input_ids_unpadded
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids_unpadded
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids_unpadded = (
+                    input_ids_unpadded[: session_params.offset] + req.input_ids
+                )
+            else:
+                input_ids_unpadded += req.input_ids
         else:
             input_ids = req.input_ids
             input_ids_unpadded = req.input_ids
@@ -57,13 +132,13 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
-        if len(self.reqs) > 0:
-            new_req.image_inputs = self.reqs[-1].image_inputs
+        if last_req is not None:
+            new_req.image_inputs = last_req.image_inputs
         new_req.tokenizer = tokenizer
-        if req.session_rid is not None and len(self.reqs) == 0:
-            new_req.finished_reason = FINISH_ABORT(
-                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
-            )
+        if abort:
+            new_req.to_abort = True
         else:
-            self.reqs.append(new_req)
+            new_req_node = SessionReqNode(new_req, last_req_node)
+            self.req_nodes[req.rid] = new_req_node
+
         return new_req
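
Note: a session now keeps a tree of requests rather than a flat list, so two follow-ups can branch from the same parent rid. A toy illustration of SessionReqNode and its __str__ rendering, using a minimal stand-in for Req (the real Req carries scheduler state this sketch does not need):

from types import SimpleNamespace

from sglang.srt.managers.session_controller import SessionReqNode


def fake_req(rid):
    # Minimal stand-in exposing only the attributes SessionReqNode touches.
    return SimpleNamespace(rid=rid, finished_reason=None, to_abort=False)


root = SessionReqNode(fake_req("a"))
b = SessionReqNode(fake_req("b"), parent=root)  # first continuation of "a"
c = SessionReqNode(fake_req("c"), parent=root)  # a second branch off "a"
print(root, end="")
# a -- b
#   \- c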