sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
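
Note on the directory reorganization above (entries 16-25 and 69-71): the MoE layers move from sglang/srt/layers/{ep_moe, fused_moe_triton} to sglang/srt/layers/moe/, and fused_moe_patch.py is removed. Downstream code that imports these kernels has to follow the move. A minimal, illustrative compatibility shim is sketched below; the exact public symbols are not shown in this diff, so the `fused_moe` name is an assumption used only for illustration.

    # Illustrative import shim for the MoE package move (hypothetical symbol name).
    try:
        # sglang >= 0.4.1: MoE kernels live under sglang.srt.layers.moe
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
    except ImportError:
        # sglang <= 0.4.0.post1: previous location
        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe

The detailed hunks below cover sglang/srt/managers/tokenizer_manager.py and the sglang/srt/mem_cache modules.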
sglang/srt/managers/tokenizer_manager.py

@@ -22,7 +22,7 @@ import signal
  import sys
  import time
  import uuid
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

  import fastapi
  import uvloop
@@ -30,6 +30,7 @@ import zmq
  import zmq.asyncio
  from fastapi import BackgroundTasks

+ from sglang.srt.aio_rwlock import RWLock
  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
  from sglang.srt.metrics.collector import TokenizerMetricsCollector
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import get_zmq_socket, kill_process_tree
+ from sglang.srt.utils import (
+     dataclass_to_string_truncated,
+     get_zmq_socket,
+     kill_process_tree,
+ )

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -76,11 +81,15 @@ class ReqState:
      out_list: List
      finished: bool
      event: asyncio.Event
+     obj: Any

      # For metrics
      created_time: float
      first_token_time: Optional[float] = None

+     # For streaming output
+     last_output_offset: int = 0
+

  class TokenizerManager:
      """TokenizerManager is a process that tokenizes the text."""
@@ -119,6 +128,7 @@ class TokenizerManager:

          self.is_generation = self.model_config.is_generation
          self.context_len = self.model_config.context_len
+         self.image_token_id = self.model_config.image_token_id

          # Create image processor placeholder
          self.image_processor = get_dummy_image_processor()
@@ -151,9 +161,12 @@ class TokenizerManager:
          self.to_create_loop = True
          self.rid_to_state: Dict[str, ReqState] = {}

-         # For update model weights
-         self.model_update_lock = asyncio.Lock()
-         self.model_update_result = None
+         # The event to notify the weight sync is finished.
+         self.model_update_lock = RWLock()
+         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+             None
+         )
+         self.asyncio_tasks = set()

          # For session info
          self.session_futures = {} # session_id -> asyncio event
@@ -180,9 +193,6 @@ class TokenizerManager:
          if self.to_create_loop:
              self.create_handle_loop()

-         while self.model_update_lock.locked():
-             await asyncio.sleep(0.001)
-
          if isinstance(obj, EmbeddingReqInput) and self.is_generation:
              raise ValueError(
                  "This model does not appear to be an embedding model by default. "
@@ -190,17 +200,24 @@ class TokenizerManager:
              )

          obj.normalize_batch_and_arguments()
-         is_single = obj.is_single
-         if is_single:
-             tokenized_obj = await self._tokenize_one_request(obj)
-             self.send_to_scheduler.send_pyobj(tokenized_obj)
-             async for response in self._wait_one_response(obj, request, created_time):
-                 yield response
-         else:
-             async for response in self._handle_batch_request(
-                 obj, request, created_time
-             ):
-                 yield response
+
+         if self.server_args.log_requests:
+             logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+         async with self.model_update_lock.reader_lock:
+             is_single = obj.is_single
+             if is_single:
+                 tokenized_obj = await self._tokenize_one_request(obj)
+                 self.send_to_scheduler.send_pyobj(tokenized_obj)
+                 async for response in self._wait_one_response(
+                     obj, request, created_time
+                 ):
+                     yield response
+             else:
+                 async for response in self._handle_batch_request(
+                     obj, request, created_time
+                 ):
+                     yield response

      async def _tokenize_one_request(
          self,
@@ -214,7 +231,7 @@ class TokenizerManager:
              if not self.server_args.disable_radix_cache:
                  raise ValueError(
                      "input_embeds is provided while disable_radix_cache is False. "
-                     "Please add `--disable-radix-cach` when you launch the server "
+                     "Please add `--disable-radix-cache` when you launch the server "
                      "if you want to use input_embeds as inputs."
                  )
              input_embeds = obj.input_embeds
@@ -283,7 +300,7 @@ class TokenizerManager:
      ):
          """Wait for the response of one request."""
          event = asyncio.Event()
-         state = ReqState([], False, event, created_time=created_time)
+         state = ReqState([], False, event, obj, created_time=created_time)
          self.rid_to_state[obj.rid] = state

          while True:
@@ -295,27 +312,25 @@ class TokenizerManager:
                      raise ValueError(f"Abort request {obj.rid}")
                  continue

-             if isinstance(obj, GenerateReqInput):
-                 out = self.convert_logprob_style(
-                     state.out_list[-1],
-                     obj.return_logprob,
-                     obj.top_logprobs_num,
-                     obj.return_text_in_logprobs,
-                 )
-             else: # isinstance(obj, (EmbeddingReqInput,))
-                 out = state.out_list[-1]
+             out = state.out_list[-1]

              state.out_list = []
              if state.finished:
                  if self.server_args.log_requests:
-                     # Log requests
-                     logger.info(f"in={obj}, out={out}")
+                     msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                     logger.info(msg)
                  del self.rid_to_state[obj.rid]
                  yield out
                  break

              state.event.clear()
-             yield out
+
+             if obj.stream:
+                 yield out
+             else:
+                 if request is not None and await request.is_disconnected():
+                     self.abort_request(obj.rid)
+                     raise ValueError(f"Abort request {obj.rid}")

      async def _handle_batch_request(
          self,
@@ -424,55 +439,52 @@ class TokenizerManager:
          self,
          obj: UpdateWeightFromDiskReqInput,
          request: Optional[fastapi.Request] = None,
-     ):
+     ) -> Tuple[bool, str]:
          if self.to_create_loop:
              self.create_handle_loop()

          # default the load format to the server_args
          if obj.load_format is None:
              obj.load_format = self.server_args.load_format
+         logger.info("Start update_weights. Load format=%s", obj.load_format)

-         if not self.model_update_lock.locked():
+         if True:
+             # Hold the lock if it is not async. This means that weight sync
+             # cannot run while requests are in progress.
+             async with self.model_update_lock.writer_lock:
+                 return await self._wait_for_model_update_from_disk(obj)

-             async with self.model_update_lock:
-                 # wait for the previous generation requests to finish
-                 for i in range(3):
-                     while len(self.rid_to_state) > 0:
-                         await asyncio.sleep(0.001)
-                     # FIXME: We add some sleep here to avoid some race conditions.
-                     # We can use a read-write lock as a better fix.
-                     await asyncio.sleep(0.01)
-                 self.send_to_scheduler.send_pyobj(obj)
-                 self.model_update_result = asyncio.Future()
-
-                 if self.server_args.dp_size == 1:
-                     result = await self.model_update_result
-                     if result.success:
-                         self.server_args.model_path = obj.model_path
-                         self.server_args.load_format = obj.load_format
-                         self.model_path = obj.model_path
-                     return result.success, result.message
-                 else: # self.server_args.dp_size > 1
-                     self.model_update_tmp = []
-                     result = await self.model_update_result
-
-                     all_success = all([r.success for r in result])
-                     if all_success is True:
-                         self.server_args.model_path = obj.model_path
-                         self.server_args.load_format = obj.load_format
-                         self.model_path = obj.model_path
-                     all_message = [r.message for r in result]
-                     all_message = " | ".join(all_message)
-                     return all_success, all_message
-
-         else:
-             return False, "Another update is in progress. Please try again later."
+     async def _wait_for_model_update_from_disk(
+         self, obj: UpdateWeightFromDiskReqInput
+     ) -> Tuple[bool, str, int]:
+         self.send_to_scheduler.send_pyobj(obj)
+         self.model_update_result = asyncio.Future()
+         if self.server_args.dp_size == 1:
+             result = await self.model_update_result
+             if result.success:
+                 self.served_model_name = obj.model_path
+                 self.server_args.model_path = obj.model_path
+                 self.server_args.load_format = obj.load_format
+                 self.model_path = obj.model_path
+             return result.success, result.message
+         else: # self.server_args.dp_size > 1
+             self.model_update_tmp = []
+             result = await self.model_update_result
+
+             all_success = all([r.success for r in result])
+             if all_success is True:
+                 self.server_args.model_path = obj.model_path
+                 self.server_args.load_format = obj.load_format
+                 self.model_path = obj.model_path
+             all_message = [r.message for r in result]
+             all_message = " | ".join(all_message)
+             return all_success, all_message

      async def init_weights_update_group(
          self,
          obj: InitWeightsUpdateGroupReqInput,
          request: Optional[fastapi.Request] = None,
-     ) -> bool:
+     ) -> Tuple[bool, str]:
          if self.to_create_loop:
              self.create_handle_loop()
          self.send_to_scheduler.send_pyobj(obj)
@@ -488,25 +500,22 @@ class TokenizerManager:
          self,
          obj: UpdateWeightsFromDistributedReqInput,
          request: Optional[fastapi.Request] = None,
-     ):
+     ) -> Tuple[bool, str]:
          if self.to_create_loop:
              self.create_handle_loop()

-         if not self.model_update_lock.locked():
-             async with self.model_update_lock:
-                 self.send_to_scheduler.send_pyobj(obj)
-                 self.parameter_update_result = asyncio.Future()
-                 assert (
-                     self.server_args.dp_size == 1
-                 ), "dp_size must be for update weights from distributed"
-                 result = await self.parameter_update_result
-                 return result.success, result.message
-         else:
-             logger.error("Another parameter update is in progress in tokenizer manager")
-             return (
-                 False,
-                 "Another parameter update is in progress. Please try again later.",
-             )
+         # This means that weight sync
+         # cannot run while requests are in progress.
+         async with self.model_update_lock.writer_lock:
+             self.send_to_scheduler.send_pyobj(obj)
+             self.parameter_update_result: Awaitable[
+                 UpdateWeightsFromDistributedReqOutput
+             ] = asyncio.Future()
+             assert (
+                 self.server_args.dp_size == 1
+             ), "dp_size must be for update weights from distributed"
+             result = await self.parameter_update_result
+             return result.success, result.message

      async def get_weights_by_name(
          self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -565,15 +574,15 @@ class TokenizerManager:

          self.to_create_loop = False
          loop = asyncio.get_event_loop()
-         loop.create_task(self.handle_loop())
+         self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

          signal_handler = SignalHandler(self)
          loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-         loop.create_task(self.sigterm_watchdog())
+         self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))

      async def sigterm_watchdog(self):
          while not self.gracefully_exit:
-             await asyncio.sleep(60)
+             await asyncio.sleep(5)

          # drain requests
          while True:
@@ -609,29 +618,55 @@ class TokenizerManager:
                      if state is None:
                          continue

-                     recv_obj.meta_info[i]["id"] = rid
+                     meta_info = {
+                         "id": rid,
+                         "finish_reason": recv_obj.finished_reasons[i],
+                         "prompt_tokens": recv_obj.prompt_tokens[i],
+                     }
+
+                     if getattr(state.obj, "return_logprob", False):
+                         self.convert_logprob_style(
+                             meta_info,
+                             state.obj.top_logprobs_num,
+                             state.obj.return_text_in_logprobs,
+                             recv_obj,
+                             i,
+                         )
+
+                     if not isinstance(recv_obj, BatchEmbeddingOut):
+                         meta_info.update(
+                             {
+                                 "completion_tokens": recv_obj.completion_tokens[i],
+                                 "cached_tokens": recv_obj.cached_tokens[i],
+                             }
+                         )
+
                      if isinstance(recv_obj, BatchStrOut):
                          out_dict = {
                              "text": recv_obj.output_strs[i],
-                             "meta_info": recv_obj.meta_info[i],
+                             "meta_info": meta_info,
                          }
                      elif isinstance(recv_obj, BatchTokenIDOut):
                          out_dict = {
                              "token_ids": recv_obj.output_ids[i],
-                             "meta_info": recv_obj.meta_info[i],
+                             "meta_info": meta_info,
                          }
                      else:
                          assert isinstance(recv_obj, BatchEmbeddingOut)
                          out_dict = {
                              "embedding": recv_obj.embeddings[i],
-                             "meta_info": recv_obj.meta_info[i],
+                             "meta_info": meta_info,
                          }
                      state.out_list.append(out_dict)
-                     state.finished = recv_obj.finished_reason[i] is not None
+                     state.finished = recv_obj.finished_reasons[i] is not None
                      state.event.set()

                      if self.enable_metrics:
-                         completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+                         completion_tokens = (
+                             recv_obj.completion_tokens[i]
+                             if recv_obj.completion_tokens
+                             else 0
+                         )

                          if state.first_token_time is None:
                              state.first_token_time = time.time()
@@ -647,7 +682,7 @@ class TokenizerManager:

                          if state.finished:
                              self.metrics_collector.inc_prompt_tokens(
-                                 recv_obj.meta_info[i]["prompt_tokens"]
+                                 recv_obj.prompt_tokens[i]
                              )
                              self.metrics_collector.inc_generation_tokens(
                                  completion_tokens
@@ -696,57 +731,73 @@ class TokenizerManager:

      def convert_logprob_style(
          self,
-         ret: dict,
-         return_logprob: bool,
+         meta_info: dict,
          top_logprobs_num: int,
          return_text_in_logprobs: bool,
+         recv_obj: BatchStrOut,
+         recv_obj_index: int,
      ):
-         if return_logprob:
-             ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
-                 ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
+         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+             recv_obj.input_token_logprobs_val[recv_obj_index],
+             recv_obj.input_token_logprobs_idx[recv_obj_index],
+             return_text_in_logprobs,
+         )
+         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+             recv_obj.output_token_logprobs_val[recv_obj_index],
+             recv_obj.output_token_logprobs_idx[recv_obj_index],
+             return_text_in_logprobs,
+         )
+         meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+             recv_obj_index
+         ]
+
+         if top_logprobs_num > 0:
+             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                 recv_obj.input_top_logprobs_val[recv_obj_index],
+                 recv_obj.input_top_logprobs_idx[recv_obj_index],
+                 return_text_in_logprobs,
              )
-             ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
-                 ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
+             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                 recv_obj.output_top_logprobs_val[recv_obj_index],
+                 recv_obj.output_top_logprobs_idx[recv_obj_index],
+                 return_text_in_logprobs,
              )

-             if top_logprobs_num > 0:
-                 ret["meta_info"]["input_top_logprobs"] = (
-                     self.detokenize_top_logprobs_tokens(
-                         ret["meta_info"]["input_top_logprobs"],
-                         return_text_in_logprobs,
-                     )
-                 )
-                 ret["meta_info"]["output_top_logprobs"] = (
-                     self.detokenize_top_logprobs_tokens(
-                         ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                     )
-                 )
-         return ret
-
      def detokenize_logprob_tokens(
-         self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+         self,
+         token_logprobs_val: List[float],
+         token_logprobs_idx: List[int],
+         decode_to_text: bool,
      ):
-         # TODO(lianmin): This should run on DetokenizerManager
          if not decode_to_text:
-             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
-
-         assert self.tokenizer is not None
-         token_ids = [tid for _, tid in token_logprobs]
-         token_texts = self.tokenizer.batch_decode(token_ids)
-         return [
-             (logprob, token_id, token_text)
-             for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
-         ]
+             return [
+                 (logprob, token_id, None)
+                 for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+             ]
+         else:
+             assert self.tokenizer is not None
+             token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+             return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

-     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
+     def detokenize_top_logprobs_tokens(
+         self,
+         token_logprobs_val: List[float],
+         token_logprobs_idx: List[int],
+         decode_to_text: bool,
+     ):
          # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
          # We should batch all top-k tokens in all positions.
-         for i, token_top_logprobs in enumerate(top_logprobs):
-             if token_top_logprobs:
-                 top_logprobs[i] = self.detokenize_logprob_tokens(
-                     token_top_logprobs, decode_to_text
+         ret = []
+         for i in range(len(token_logprobs_val)):
+             if token_logprobs_val[i]:
+                 ret.append(
+                     self.detokenize_logprob_tokens(
+                         token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
+                     )
                  )
-         return top_logprobs
+             else:
+                 ret.append(None)
+         return ret


  class SignalHandler:
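
The concurrency change running through the hunks above replaces the single asyncio.Lock guarding weight updates with the RWLock from the new sglang/srt/aio_rwlock.py (whose contents are not shown in this diff): generate_request now runs under model_update_lock.reader_lock, while update_weights_from_disk and update_weights_from_distributed take model_update_lock.writer_lock, so a weight sync waits for in-flight requests instead of the old sleep-and-poll loop. The sketch below is a minimal asyncio reader-writer lock written only to illustrate those semantics; it is not the actual aio_rwlock implementation and has no writer-priority or fairness handling.

    import asyncio
    from contextlib import asynccontextmanager


    class SimpleRWLock:
        """Toy reader-writer lock: many concurrent readers, one exclusive writer."""

        def __init__(self) -> None:
            self._cond = asyncio.Condition()  # guards _readers and _writing
            self._readers = 0
            self._writing = False

        @property
        def reader_lock(self):
            @asynccontextmanager
            async def _ctx():
                async with self._cond:
                    # Readers wait only while a writer holds the lock.
                    await self._cond.wait_for(lambda: not self._writing)
                    self._readers += 1
                try:
                    yield
                finally:
                    async with self._cond:
                        self._readers -= 1
                        self._cond.notify_all()

            return _ctx()

        @property
        def writer_lock(self):
            @asynccontextmanager
            async def _ctx():
                async with self._cond:
                    # The writer waits until no reader or other writer is active.
                    await self._cond.wait_for(
                        lambda: self._readers == 0 and not self._writing
                    )
                    self._writing = True
                try:
                    yield
                finally:
                    async with self._cond:
                        self._writing = False
                        self._cond.notify_all()

            return _ctx()


    async def main():
        lock = SimpleRWLock()
        async with lock.reader_lock:   # request path: shared access
            pass
        async with lock.writer_lock:   # weight update: exclusive access
            pass


    asyncio.run(main())

Under this scheme a weight update simply blocks behind the readers already inside generate_request and excludes new ones, which is why the FIXME sleep-and-retry loop in the removed code is no longer needed.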
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,5 @@
  from abc import ABC, abstractmethod
- from typing import Callable
+ from typing import Callable, List, Tuple


  class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
          pass

      @abstractmethod
-     def match_prefix(self, **kwargs):
+     def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
          pass

      @abstractmethod
sglang/srt/mem_cache/chunk_cache.py

@@ -2,7 +2,7 @@ from __future__ import annotations

  """Cache for chunked prefill, used when RadixCache is disabled."""

- from typing import TYPE_CHECKING, Callable, List, Optional
+ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
      def reset(self):
          self.entries = {}

-     def match_prefix(self, rid: int, key: List[int]):
+     def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
          if rid not in self.entries:
              return [], None

sglang/srt/mem_cache/memory_pool.py

@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
          device: str,
      ):
          super().__init__(size, dtype, device)
+         self.head_num = head_num
+         self.head_dim = head_dim
+         self.layer_num = layer_num
+         self._create_buffers()

+     def _create_buffers(self):
          # [size, head_num, head_dim] for each layer
          # The padded slot 0 is used for writing dummy outputs from padded tokens.
          self.k_buffer = [
              torch.empty(
-                 (size + 1, head_num, head_dim),
+                 (self.size + 1, self.head_num, self.head_dim),
                  dtype=self.store_dtype,
-                 device=device,
+                 device=self.device,
              )
-             for _ in range(layer_num)
+             for _ in range(self.layer_num)
          ]
          self.v_buffer = [
              torch.empty(
-                 (size + 1, head_num, head_dim),
+                 (self.size + 1, self.head_num, self.head_dim),
                  dtype=self.store_dtype,
-                 device=device,
+                 device=self.device,
              )
-             for _ in range(layer_num)
+             for _ in range(self.layer_num)
          ]

+     def _clear_buffers(self):
+         del self.k_buffer
+         del self.v_buffer
+
      def get_key_buffer(self, layer_id: int):
          if self.store_dtype != self.dtype:
              return self.k_buffer[layer_id].view(self.dtype)
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):


  class MLATokenToKVPool(BaseTokenToKVPool):
-
      def __init__(
          self,
          size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):


  class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
-
      def __init__(
          self,
          size: int,
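
In the MHATokenToKVPool hunk above, buffer allocation moves out of __init__ into _create_buffers(), with a matching _clear_buffers(), and the shapes now read from attributes saved on self. Below is a minimal sketch of that allocate/release pattern with made-up sizes; it is not sglang's class, only an illustration of why storing head_num/head_dim/layer_num on self lets the buffers be dropped and rebuilt later.

    import torch


    class TinyKVPool:
        """Toy KV pool illustrating the _create_buffers/_clear_buffers split."""

        def __init__(self, size, head_num, head_dim, layer_num,
                     dtype=torch.float16, device="cpu"):
            self.size = size
            self.head_num = head_num
            self.head_dim = head_dim
            self.layer_num = layer_num
            self.dtype = dtype
            self.device = device
            self._create_buffers()

        def _create_buffers(self):
            # Slot 0 is padding, mirroring the "size + 1" in the diff above.
            shape = (self.size + 1, self.head_num, self.head_dim)
            self.k_buffer = [
                torch.empty(shape, dtype=self.dtype, device=self.device)
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.empty(shape, dtype=self.dtype, device=self.device)
                for _ in range(self.layer_num)
            ]

        def _clear_buffers(self):
            # Dropping the references lets the allocator reclaim the memory.
            del self.k_buffer
            del self.v_buffer


    pool = TinyKVPool(size=8, head_num=2, head_dim=4, layer_num=3)
    pool._clear_buffers()   # free the KV storage ...
    pool._create_buffers()  # ... and rebuild it from the saved shape attributes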
sglang/srt/mem_cache/radix_cache.py

@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
  import heapq
  import time
  from collections import defaultdict
- from typing import TYPE_CHECKING, Callable, List, Optional
+ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

  import torch

@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
          self.root_node.lock_ref = 1
          self.evictable_size_ = 0

-     def match_prefix(self, key: List, **kwargs):
+     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+         """Find the matching prefix from the radix tree.
+         Args:
+             key: A list of token IDs to find a matching prefix.
+         Returns:
+             A tuple of a tensor of matching prefix token IDs and
+             the last node that contains the prefix values. Note that
+             this API can modify the internal state of the Radix tree.
+             The last node create a new child if the prefix is shorter
+             than the last node's value.
+         """
          if self.disable:
              return [], self.root_node
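
The new docstring and the Tuple return annotations above pin down the match_prefix contract shared by BasePrefixCache, ChunkCache, and RadixCache: a pair of (matched prefix, last matched node). A hypothetical caller is sketched below against a RadixCache-like object; the plan_prefill function and its variable names are invented for illustration and are not part of sglang.

    from typing import List


    def plan_prefill(cache, input_ids: List[int]):
        # prefix: the cache entries already matched for this key (a torch.Tensor
        # for RadixCache, a plain list when caching is disabled);
        # last_node: the tree node holding the tail of that prefix.
        prefix, last_node = cache.match_prefix(key=input_ids)
        num_cached = len(prefix)
        # Only the unmatched suffix still has to be prefilled.
        remaining = input_ids[num_cached:]
        return num_cached, remaining, last_node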