sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()
 
         # Health check
-        self.health_check_failed = False
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting
 
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -291,8 +290,8 @@ class TokenizerManager:
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self._is_updating = False
-        self._is_updating_cond = asyncio.Condition()
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()
 
         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,16 +475,20 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()
 
-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )
 
+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
+            if self.server_args.enable_lora and obj.lora_path:
+                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
@@ -553,11 +556,6 @@ class TokenizerManager:
         else:
             mm_inputs = None
 
-        if self.server_args.enable_lora and obj.lora_path:
-            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
-            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
-
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -701,7 +699,7 @@ class TokenizerManager:
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self._validate_token_len(obj[i], input_ids_list[i])
+            self._validate_one_request(obj[i], input_ids_list[i])
             tokenized_objs.append(
                 self._create_tokenized_object(
                     req, req.text, input_ids_list[i], None, None
@@ -775,10 +773,6 @@ class TokenizerManager:
                 msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                 logger.info(msg)
 
-            # Mark ongoing LoRA request as finished.
-            if self.server_args.enable_lora and obj.lora_path:
-                await self.lora_registry.release(obj.lora_id)
-
             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
                 finish_reason = out["meta_info"]["finish_reason"]
@@ -797,6 +791,11 @@ class TokenizerManager:
                     # Delete the key to prevent resending abort request to the scheduler and
                     # to ensure aborted request state is cleaned up.
                     del self.rid_to_state[state.obj.rid]
+
+                    # Mark ongoing LoRA request as finished.
+                    if self.server_args.enable_lora and state.obj.lora_path:
+                        await self.lora_registry.release(state.obj.lora_id)
+
                     raise fastapi.HTTPException(
                         status_code=finish_reason["status_code"],
                         detail=finish_reason["message"],
@@ -982,14 +981,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
     async def pause_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = True
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)
 
     async def continue_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = False
-            self._is_updating_cond.notify_all()
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()
 
     async def update_weights_from_disk(
         self,
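The rename makes the intent explicit: `pause_generation` raises a pause flag (and aborts in-flight requests), while `continue_generation` lowers it and wakes every request coroutine parked in the `wait_for` call at the top of `generate_request`. A minimal, self-contained sketch of the same asyncio pattern (the `PauseGate` class is hypothetical, not part of sglang):

```python
import asyncio


class PauseGate:
    def __init__(self):
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

    async def pause(self):
        async with self.is_pause_cond:
            self.is_pause = True

    async def resume(self):
        async with self.is_pause_cond:
            self.is_pause = False
            # Wake every coroutine blocked in wait_until_resumed().
            self.is_pause_cond.notify_all()

    async def wait_until_resumed(self):
        async with self.is_pause_cond:
            # wait_for() re-checks the predicate after each notify,
            # so spurious wakeups are harmless.
            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
```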
@@ -1474,7 +1473,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)
 
-            if self.health_check_failed:
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1530,6 +1529,7 @@ class TokenizerManager:
                 "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
+                "weight_version": self.server_args.weight_version,
             }
 
             if getattr(state.obj, "return_logprob", False):
@@ -1600,6 +1600,10 @@ class TokenizerManager:
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
                 del self.rid_to_state[rid]
 
+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and state.obj.lora_path:
+                    asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
+
             state.out_list.append(out_dict)
             state.event.set()
 
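The release is keyed off the request's `state` object and scheduled with `asyncio.create_task`, so finishing a response never awaits registry bookkeeping inline. Together with the `acquire` call moved into `generate_request` above, the pair amounts to reference-counting adapters with requests in flight. A rough sketch of that contract (a hypothetical `RefCountedRegistry`, much simpler than sglang's actual `LoRARegistry`):

```python
import asyncio
from collections import defaultdict


class RefCountedRegistry:
    def __init__(self):
        self._inflight: dict[str, int] = defaultdict(int)
        self._lock = asyncio.Lock()

    async def acquire(self, name: str) -> str:
        # The real registry also resolves the user-facing name to a unique ID.
        async with self._lock:
            self._inflight[name] += 1
            return name

    async def release(self, ident: str) -> None:
        async with self._lock:
            self._inflight[ident] -= 1
            # An adapter with zero in-flight requests is safe to unload.
```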
@@ -1889,6 +1893,13 @@ class TokenizerManager:
                 f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
             )
 
+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)
@@ -1900,13 +1911,9 @@ class TokenizerManager:
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-            batch_request = GenerateReqInput(
-                text=prompts,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.text = prompts
+
         elif (
             isinstance(query, list)
             and isinstance(items, list)
@@ -1918,13 +1925,8 @@ class TokenizerManager:
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-            batch_request = GenerateReqInput(
-                input_ids=input_ids_list,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
@@ -1936,9 +1938,20 @@ class TokenizerManager:
         for result in results:
             # Get logprobs for each token
             logprobs = {}
-            for logprob, token_id, _ in result["meta_info"].get(
-                "output_token_ids_logprobs", []
-            )[0]:
+
+            # For scoring requests, we read from output_token_ids_logprobs since we want
+            # the logprobs for specific tokens mentioned in the label_token_ids at
+            # the next position after the last token in the prompt
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Throw an error here if output_logprobs is None
+            if output_logprobs is None:
+                raise RuntimeError(
+                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                 if token_id in label_token_ids:
                     logprobs[token_id] = logprob
 
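For context, the structure consumed here is one list of `(logprob, token_id, token_text)` triples per scored position, and scoring reads index 0, the position right after the prompt. An illustrative sketch of the extraction (token IDs and values are made up):

```python
# Hypothetical meta_info payload for one scored prompt.
output_logprobs = [
    [(-0.12, 9454, None), (-2.31, 2822, None)],  # position 0
]
label_token_ids = [9454, 2822]

logprobs = {
    token_id: logprob
    for logprob, token_id, _ in output_logprobs[0]
    if token_id in label_token_ids
}
assert logprobs == {9454: -0.12, 2822: -2.31}
```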
@@ -1965,10 +1978,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up
 
 
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
sglang/srt/mem_cache/allocator.py
@@ -20,7 +20,6 @@ Page-aligned memory pool.
 """
 
 import abc
-import weakref
 from typing import TYPE_CHECKING
 
 import torch
@@ -43,12 +42,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
         self.size = size
         self.page_size = page_size
         self.dtype = dtype
         self.device = device
         self._kvcache = kvcache
+        self.need_sort = need_sort
 
         self.free_pages = None
         self.release_pages = None
@@ -117,8 +118,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
 class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """An allocator managing the indices to kv cache data."""
 
-    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
-        super().__init__(size, 1, dtype, device, kvcache)
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         self.clear()
 
     def clear(self):
@@ -135,8 +143,9 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         return len(self.free_pages) + len(self.release_pages)
 
     def alloc(self, need_size: int):
-        if need_size > len(self.free_pages):
+        if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
+
         if need_size > len(self.free_pages):
             return None
 
@@ -149,7 +158,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return
 
         if self.is_not_in_free_group:
-            self.release_pages = torch.cat((self.release_pages, free_index))
+            if self.need_sort:
+                self.release_pages = torch.cat((self.release_pages, free_index))
+            else:
+                self.free_pages = torch.cat((self.free_pages, free_index))
         else:
             self.free_group.append(free_index)
 
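Together with the `alloc` change above, this implements a deferred-sort free list: with `need_sort` set, freed indices are parked in `release_pages` and folded back (with a single sort) only when an allocation would otherwise run short; without it, freed indices go straight back to `free_pages`, since handout order no longer matters. A standalone sketch of that policy (hypothetical `FreeList` class, 1-token pages assumed):

```python
import torch


class FreeList:
    def __init__(self, num_pages: int, need_sort: bool):
        self.need_sort = need_sort
        self.free_pages = torch.arange(num_pages, dtype=torch.int64)
        self.release_pages = torch.empty(0, dtype=torch.int64)

    def merge_and_sort_free(self):
        # Fold released indices back in and restore sorted order once,
        # instead of paying for a sort on every free().
        merged = torch.cat((self.free_pages, self.release_pages))
        self.free_pages, _ = torch.sort(merged)
        self.release_pages = torch.empty(0, dtype=torch.int64)

    def alloc(self, need_size: int):
        if self.need_sort and need_size > len(self.free_pages):
            self.merge_and_sort_free()
        if need_size > len(self.free_pages):
            return None  # out of free slots
        out = self.free_pages[:need_size]
        self.free_pages = self.free_pages[need_size:]
        return out

    def free(self, indices: torch.Tensor):
        if self.need_sort:
            self.release_pages = torch.cat((self.release_pages, indices))
        else:
            self.free_pages = torch.cat((self.free_pages, indices))
```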
@@ -170,8 +182,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: SWAKVPool,
+        need_sort: bool,
     ):
-        super().__init__(size, 1, dtype, device, kvcache)
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         assert isinstance(kvcache, SWAKVPool)
         self._size_full = size
         self._size_swa = size_swa
@@ -180,12 +193,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             dtype,
             device,
             kvcache.full_kv_pool,
+            need_sort,
         )
         self.swa_attn_allocator = TokenToKVPoolAllocator(
             size_swa,
             dtype,
             device,
             kvcache.swa_kv_pool,
+            need_sort,
         )
         self.full_to_swa_index_mapping = torch.empty(
             size + size_swa + 1,
@@ -418,9 +433,14 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
+        max_num_extend_tokens: int,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
+        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
+            max_num_extend_tokens
+        )
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.clear()
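Precomputing `next_power_of_2(max_num_extend_tokens)` in the constructor pins the kernel's block-size constant to a single value for the allocator's lifetime; since Triton specializes on such constexpr arguments, deriving a batch-dependent power of two per call (as the old code did, see below) can trigger repeated recompilation. For reference, a minimal sketch of what that helper computes (illustrative; the real one is imported by this module):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (1 for n <= 1).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


assert [next_power_of_2(n) for n in (1, 2, 3, 8, 9)] == [1, 2, 4, 8, 16]
```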
@@ -433,7 +453,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"
 
         num_pages = need_size // self.page_size
-        if num_pages > len(self.free_pages):
+        if self.need_sort and num_pages > len(self.free_pages):
             self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None
@@ -460,18 +480,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(prefix_lens)
+        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()
 
-        bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
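The replaced estimate required a tensor reduction plus `.item()`, i.e. a GPU-to-host sync on every extend allocation. The new trigger is a host-side upper bound: request *i* opens `ceil(seq_i/p) - ceil(prefix_i/p) <= floor(extend_i/p) + 1` new pages, so the batch total is at most `extend_num_tokens // page_size + bs`, with the extra `+ 1` as slack. A quick property check of that bound (illustrative only):

```python
import math
import random


def new_pages(prefix: int, seq: int, page: int) -> int:
    return math.ceil(seq / page) - math.ceil(prefix / page)


for _ in range(1000):
    page = random.choice([1, 16, 64])
    prefix = [random.randrange(0, 512) for _ in range(8)]
    extend = [random.randrange(1, 256) for _ in range(8)]
    seq = [p + e for p, e in zip(prefix, extend)]
    exact = sum(new_pages(p, s, page) for p, s in zip(prefix, seq))
    bound = sum(extend) // page + len(prefix) + 1
    assert exact <= bound, (exact, bound)
```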
@@ -484,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.ret_values,
             next_power_of_2(bs),
             self.page_size,
-            next_power_of_2(extend_num_tokens),
+            self.max_num_extend_tokens_next_power_of_2,
         )
 
         if self.debug_mode:
@@ -508,18 +522,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )
 
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(seq_lens)
+        if self.need_sort and bs > len(self.free_pages):
             self.merge_and_sort_free()
 
-        bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
             seq_lens,
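The decode-side bound is simpler: each request appends exactly one token, which can spill into at most one new page, so `bs` alone is a safe page estimate and, as above, no device sync is needed. A tiny check of that single-page claim (illustrative only):

```python
import math

# Appending one token to a sequence of length seq - 1 opens at most one page.
for seq, page in [(17, 16), (16, 16), (1, 64), (128, 64)]:
    assert math.ceil(seq / page) - math.ceil((seq - 1) / page) <= 1
```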
@@ -547,7 +553,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
 
         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            if self.need_sort:
+                self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            else:
+                self.free_pages = torch.cat((free_page_indices, self.free_pages))
         else:
             self.free_group.append(free_index)
 
@@ -568,187 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
-def alloc_extend_kernel_ascend(
-    prefix_lens,
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-    device,
-):
-    extend_lens = seq_lens - prefix_lens
-    end_pos = torch.cumsum(extend_lens, 0)
-    start_pos = end_pos - extend_lens
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    num_full_new_pages = (seq_lens) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    need_page = num_new_pages - num_full_new_pages
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
-    for i in range(len(prefix_lens)):
-        num1 = (
-            min(
-                seq_lens[i],
-                (prefix_lens[i] + page_size - 1) // page_size * page_size,
-            )
-            - prefix_lens[i]
-        )
-        if num1:
-            out_indices[start_pos[i] : start_pos[i] + num1] = (
-                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
-            )
-
-        num2 = (
-            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
-        ) * page_size
-        if num2:
-            pages = (
-                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
-                * page_size
-            )
-            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
-                pages.view(-1, 1) + pos_in_page.view(1, -1)
-            ).view(-1)
-
-        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
-        if num3:
-            out_indices[end_pos[i] - num3 : end_pos[i]] = (
-                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
-            ).view(-1)
-    return num_new_pages
-
-
-def alloc_decode_kernel_ascend(
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-):
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        seq_lens - 1 + page_size - 1
-    ) // page_size
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    for i in range(len(seq_lens)):
-        if num_new_pages[i]:
-            out_indices[i] = free_pages[start_new_pages[i]] * page_size
-        else:
-            out_indices[i] = last_loc[i] + 1
-    return num_new_pages
-
-
-class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
-
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache)
-        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
-
-    def alloc_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        bs = len(prefix_lens)
-        out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
-        )
-
-        self.ret_values = alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def alloc_decode(
-        self,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 2) % self.page_size == seq_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        bs = len(seq_lens)
-        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-
-        self.ret_values = alloc_decode_kernel_ascend(
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def clear(self):
-        super().clear()
-        self.free_pages = self.free_pages.to(torch.int32)
-        self.release_pages = self.release_pages.to(torch.int32)