sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

```diff
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()

         # Health check
-        self.
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting

         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
```
```diff
@@ -291,8 +290,8 @@ class TokenizerManager:
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self.
-        self.
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()

         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
```
```diff
@@ -476,16 +475,20 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
+            if self.server_args.enable_lora and obj.lora_path:
+                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
```
```diff
@@ -553,11 +556,6 @@ class TokenizerManager:
         else:
             mm_inputs = None

-        if self.server_args.enable_lora and obj.lora_path:
-            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
-            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
-
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
```
```diff
@@ -701,7 +699,7 @@ class TokenizerManager:
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self.
+            self._validate_one_request(obj[i], input_ids_list[i])
             tokenized_objs.append(
                 self._create_tokenized_object(
                     req, req.text, input_ids_list[i], None, None
```
```diff
@@ -775,10 +773,6 @@ class TokenizerManager:
             msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
             logger.info(msg)

-        # Mark ongoing LoRA request as finished.
-        if self.server_args.enable_lora and obj.lora_path:
-            await self.lora_registry.release(obj.lora_id)
-
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
             finish_reason = out["meta_info"]["finish_reason"]
```
```diff
@@ -797,6 +791,11 @@ class TokenizerManager:
                 # Delete the key to prevent resending abort request to the scheduler and
                 # to ensure aborted request state is cleaned up.
                 del self.rid_to_state[state.obj.rid]
+
+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and state.obj.lora_path:
+                    await self.lora_registry.release(state.obj.lora_id)
+
                 raise fastapi.HTTPException(
                     status_code=finish_reason["status_code"],
                     detail=finish_reason["message"],
```
```diff
@@ -982,14 +981,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self.
-            self.
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self.
-            self.
-            self.
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()

     async def update_weights_from_disk(
         self,
```
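The pause/resume hunks replace the old `_is_updating` flag with `is_pause` guarded by an `asyncio.Condition`, and `generate_request` now waits on that condition after logging the request and before taking the model-update read lock. Below is a minimal, self-contained sketch of the same gating pattern with simplified names; it is an illustration, not the actual `TokenizerManager` code:

```python
import asyncio


class PauseGate:
    """Minimal sketch of the is_pause / is_pause_cond pattern from the diff."""

    def __init__(self):
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

    async def pause(self):
        async with self.is_pause_cond:
            self.is_pause = True

    async def resume(self):
        async with self.is_pause_cond:
            self.is_pause = False
            # Wake every coroutine parked in wait_for().
            self.is_pause_cond.notify_all()

    async def handle_request(self, i: int):
        # New requests block here while the server is paused.
        async with self.is_pause_cond:
            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
        return f"request {i} admitted"


async def main():
    gate = PauseGate()
    await gate.pause()
    task = asyncio.create_task(gate.handle_request(0))
    await asyncio.sleep(0.1)   # request 0 is parked on the condition
    await gate.resume()        # notify_all() releases it
    print(await task)


asyncio.run(main())
```

The detail carried over from the diff is that `continue_generation` sets the flag and calls `notify_all()` while holding the condition; without the notification, coroutines waiting in `wait_for()` would never be woken.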
```diff
@@ -1474,7 +1473,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)

-            if self.
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
```
```diff
@@ -1530,6 +1529,7 @@ class TokenizerManager:
             "id": rid,
             "finish_reason": recv_obj.finished_reasons[i],
             "prompt_tokens": recv_obj.prompt_tokens[i],
+            "weight_version": self.server_args.weight_version,
         }

         if getattr(state.obj, "return_logprob", False):
```
```diff
@@ -1600,6 +1600,10 @@ class TokenizerManager:
             meta_info["e2e_latency"] = state.finished_time - state.created_time
             del self.rid_to_state[rid]

+            # Mark ongoing LoRA request as finished.
+            if self.server_args.enable_lora and state.obj.lora_path:
+                asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
+
         state.out_list.append(out_dict)
         state.event.set()

```
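Taken together with the earlier hunks, the LoRA bookkeeping now acquires the adapter ID once per request inside `generate_request` and releases it wherever the request terminates: with `await` in the abort/error branch, and via `asyncio.create_task` here in the normal completion path. A rough sketch of that acquire/release lifecycle follows; the counter-based `TinyLoRARegistry` is an illustrative stand-in and only the `acquire`/`release` names are taken from the diff:

```python
import asyncio
from collections import defaultdict


class TinyLoRARegistry:
    """Illustrative stand-in: map adapter paths to IDs and count in-flight requests."""

    def __init__(self):
        self._ids = {}                      # lora_path -> lora_id
        self._inflight = defaultdict(int)   # lora_id -> number of ongoing requests

    async def acquire(self, lora_path: str) -> str:
        lora_id = self._ids.setdefault(lora_path, f"id-{len(self._ids)}")
        self._inflight[lora_id] += 1
        return lora_id

    async def release(self, lora_id: str) -> None:
        self._inflight[lora_id] -= 1


async def main():
    registry = TinyLoRARegistry()
    lora_id = await registry.acquire("my-adapter")   # on request admission
    try:
        pass                                         # ... tokenize, schedule, decode ...
    finally:
        await registry.release(lora_id)              # on completion or abort
    print(registry._inflight[lora_id])               # back to 0


asyncio.run(main())
```

The point of the relocation is that every acquired ID is released exactly once, whether the request finishes normally or is aborted by the scheduler.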
```diff
@@ -1889,6 +1893,13 @@ class TokenizerManager:
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )

+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)
```
```diff
@@ -1900,13 +1911,9 @@ class TokenizerManager:
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-
-
-
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.text = prompts
+
         elif (
             isinstance(query, list)
             and isinstance(items, list)
```
```diff
@@ -1918,13 +1925,8 @@ class TokenizerManager:
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-
-
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
```
|
|
1936
1938
|
for result in results:
|
1937
1939
|
# Get logprobs for each token
|
1938
1940
|
logprobs = {}
|
1939
|
-
|
1940
|
-
|
1941
|
-
|
1941
|
+
|
1942
|
+
# For scoring requests, we read from output_token_ids_logprobs since we want
|
1943
|
+
# the logprobs for specific tokens mentioned in the label_token_ids at
|
1944
|
+
# the next position after the last token in the prompt
|
1945
|
+
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
1946
|
+
|
1947
|
+
# Throw an error here if output_logprobs is None
|
1948
|
+
if output_logprobs is None:
|
1949
|
+
raise RuntimeError(
|
1950
|
+
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
|
1951
|
+
"This usually indicates a problem with the scoring request or the backend output."
|
1952
|
+
)
|
1953
|
+
|
1954
|
+
for logprob, token_id, _ in output_logprobs[0]:
|
1942
1955
|
if token_id in label_token_ids:
|
1943
1956
|
logprobs[token_id] = logprob
|
1944
1957
|
|
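The scoring path now builds one `GenerateReqInput` with `max_new_tokens=0`, fills in `text` or `input_ids` depending on the input types, and reads label-token logprobs from `output_token_ids_logprobs`. The snippet below sketches how per-label probabilities can be derived from a result shaped like the one this loop consumes; the sample `meta_info` payload and token IDs are invented for illustration:

```python
import math

# Hypothetical result in the shape the loop above iterates over:
# output_token_ids_logprobs is a list (one entry per position) of
# (logprob, token_id, token_text) tuples.
result = {
    "meta_info": {
        "id": "req-0",
        "output_token_ids_logprobs": [
            [(-0.105, 9454, "Yes"), (-2.303, 2822, "No")],
        ],
    }
}
label_token_ids = [9454, 2822]

output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
logprobs = {}
for logprob, token_id, _ in output_logprobs[0]:
    if token_id in label_token_ids:
        logprobs[token_id] = logprob

# Convert the label logprobs into probabilities normalized over the labels.
denom = sum(math.exp(lp) for lp in logprobs.values())
scores = {tid: math.exp(lp) / denom for tid, lp in logprobs.items()}
print(scores)  # roughly {9454: 0.90, 2822: 0.10}
```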
```diff
@@ -1965,10 +1978,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up


 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
```
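With `Crashed` and `is_healthy()` removed, callers compare against the remaining members directly, as the SIGTERM watchdog hunk earlier in this diff does with `ServerStatus.UnHealthy`. A standalone rendering of the post-change enum for reference (a minimal sketch, not the full module):

```python
from enum import Enum


class ServerStatus(Enum):
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"


# Status checks now compare members directly instead of calling is_healthy().
server_status = ServerStatus.Starting
print(server_status == ServerStatus.UnHealthy)  # False
```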
sglang/srt/mem_cache/allocator.py

```diff
@@ -20,7 +20,6 @@ Page-aligned memory pool.
 """

 import abc
-import weakref
 from typing import TYPE_CHECKING

 import torch
```
```diff
@@ -43,12 +42,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
         self.size = size
         self.page_size = page_size
         self.dtype = dtype
         self.device = device
         self._kvcache = kvcache
+        self.need_sort = need_sort

         self.free_pages = None
         self.release_pages = None
```
```diff
@@ -117,8 +118,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
 class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """An allocator managing the indices to kv cache data."""

-    def __init__(
-
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         self.clear()

     def clear(self):
```
```diff
@@ -135,8 +143,9 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         return len(self.free_pages) + len(self.release_pages)

     def alloc(self, need_size: int):
-        if need_size > len(self.free_pages):
+        if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
+
         if need_size > len(self.free_pages):
             return None

```
```diff
@@ -149,7 +158,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return

         if self.is_not_in_free_group:
-
+            if self.need_sort:
+                self.release_pages = torch.cat((self.release_pages, free_index))
+            else:
+                self.free_pages = torch.cat((self.free_pages, free_index))
         else:
             self.free_group.append(free_index)

```
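The `need_sort` flag threads through all of the allocators in this file: when it is set, freed indices are parked in `release_pages` and only merged and sorted back into `free_pages` once an allocation would otherwise come up short; when it is unset, freed indices go straight back to `free_pages` and `alloc` never triggers `merge_and_sort_free()`. The toy, list-based sketch below mimics that policy; the real classes operate on `torch` tensors and Triton kernels, so treat this purely as an illustration:

```python
class ToyTokenAllocator:
    """Toy version of the need_sort policy; Python lists instead of torch tensors."""

    def __init__(self, size: int, need_sort: bool):
        self.need_sort = need_sort
        self.free_pages = list(range(size))
        self.release_pages = []

    def merge_and_sort_free(self):
        # Fold released indices back into the free list in sorted order.
        self.free_pages = sorted(self.free_pages + self.release_pages)
        self.release_pages = []

    def alloc(self, need_size: int):
        if self.need_sort and need_size > len(self.free_pages):
            self.merge_and_sort_free()
        if need_size > len(self.free_pages):
            return None
        out, self.free_pages = self.free_pages[:need_size], self.free_pages[need_size:]
        return out

    def free(self, indices):
        if self.need_sort:
            self.release_pages.extend(indices)   # defer the sort
        else:
            self.free_pages.extend(indices)      # reuse immediately, order unchanged


alloc = ToyTokenAllocator(size=4, need_sort=True)
a = alloc.alloc(3)
alloc.free(a)
print(alloc.alloc(2))  # triggers merge_and_sort_free before allocating
```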
```diff
@@ -170,8 +182,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: SWAKVPool,
+        need_sort: bool,
     ):
-        super().__init__(size, 1, dtype, device, kvcache)
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         assert isinstance(kvcache, SWAKVPool)
         self._size_full = size
         self._size_swa = size_swa
```
```diff
@@ -180,12 +193,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             dtype,
             device,
             kvcache.full_kv_pool,
+            need_sort,
         )
         self.swa_attn_allocator = TokenToKVPoolAllocator(
             size_swa,
             dtype,
             device,
             kvcache.swa_kv_pool,
+            need_sort,
         )
         self.full_to_swa_index_mapping = torch.empty(
             size + size_swa + 1,
```
```diff
@@ -418,9 +433,14 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
+        max_num_extend_tokens: int,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
+        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
+            max_num_extend_tokens
+        )
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.clear()
```
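`PagedTokenToKVPoolAllocator.__init__` now precomputes `next_power_of_2(max_num_extend_tokens)` once and passes that constant to the extend kernel in a later hunk of this file, instead of deriving a bound per call; the same helper already appears at the kernel call site as `next_power_of_2(bs)`. A plausible implementation of such a helper and the effect of the rounding, stated as an assumption rather than a quote of the sglang utility:

```python
def next_power_of_2(n: int) -> int:
    # Assumed behavior: smallest power of two greater than or equal to n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


# Precomputing the bound once keeps the kernel's compile-time upper bound stable
# across calls instead of recomputing it from each batch.
print(next_power_of_2(8192))   # 8192
print(next_power_of_2(8193))   # 16384
```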
```diff
@@ -433,7 +453,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"

         num_pages = need_size // self.page_size
-        if num_pages > len(self.free_pages):
+        if self.need_sort and num_pages > len(self.free_pages):
             self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None
```
```diff
@@ -460,18 +480,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )

-
-
-
-
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(prefix_lens)
+        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
```
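The exact per-request page count that the removed lines computed is replaced by a cheaper trigger: `extend_num_tokens // page_size + bs + 1` is an upper bound on the pages a batch can newly require, since each request needs at most one partial page beyond its share of `extend_num_tokens`. A quick numeric check of that bound, written independently of the sglang code and purely for illustration:

```python
import math
import random

P = 64  # page size

random.seed(0)
for _ in range(1000):
    bs = random.randint(1, 8)
    prefix = [random.randint(0, 500) for _ in range(bs)]
    extend = [random.randint(1, 300) for _ in range(bs)]
    seq = [p + e for p, e in zip(prefix, extend)]

    # Exact number of new pages: ceil(seq/P) - ceil(prefix/P), summed over the batch.
    exact = sum(math.ceil(s / P) - math.ceil(p / P) for s, p in zip(seq, prefix))
    # Cheap upper bound used by the new trigger condition.
    bound = sum(extend) // P + bs + 1
    assert exact <= bound, (exact, bound)

print("bound holds on all sampled batches")
```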
```diff
@@ -484,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.ret_values,
             next_power_of_2(bs),
             self.page_size,
-
+            self.max_num_extend_tokens_next_power_of_2,
         )

         if self.debug_mode:
```
```diff
@@ -508,18 +522,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-
-
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(seq_lens)
+        if self.need_sort and bs > len(self.free_pages):
             self.merge_and_sort_free()

-        bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
             seq_lens,
```
```diff
@@ -547,7 +553,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-
+            if self.need_sort:
+                self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            else:
+                self.free_pages = torch.cat((free_page_indices, self.free_pages))
         else:
             self.free_group.append(free_index)

```
```diff
@@ -568,187 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

     def load_cpu_copy(self, kv_cache_cpu, indices):
         return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
-def alloc_extend_kernel_ascend(
-    prefix_lens,
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-    device,
-):
-    extend_lens = seq_lens - prefix_lens
-    end_pos = torch.cumsum(extend_lens, 0)
-    start_pos = end_pos - extend_lens
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    num_full_new_pages = (seq_lens) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    need_page = num_new_pages - num_full_new_pages
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
-    for i in range(len(prefix_lens)):
-        num1 = (
-            min(
-                seq_lens[i],
-                (prefix_lens[i] + page_size - 1) // page_size * page_size,
-            )
-            - prefix_lens[i]
-        )
-        if num1:
-            out_indices[start_pos[i] : start_pos[i] + num1] = (
-                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
-            )
-
-        num2 = (
-            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
-        ) * page_size
-        if num2:
-            pages = (
-                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
-                * page_size
-            )
-            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
-                pages.view(-1, 1) + pos_in_page.view(1, -1)
-            ).view(-1)
-
-        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
-        if num3:
-            out_indices[end_pos[i] - num3 : end_pos[i]] = (
-                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
-            ).view(-1)
-    return num_new_pages
-
-
-def alloc_decode_kernel_ascend(
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-):
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        seq_lens - 1 + page_size - 1
-    ) // page_size
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    for i in range(len(seq_lens)):
-        if num_new_pages[i]:
-            out_indices[i] = free_pages[start_new_pages[i]] * page_size
-        else:
-            out_indices[i] = last_loc[i] + 1
-    return num_new_pages
-
-
-class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
-
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache)
-        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
-
-    def alloc_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        bs = len(prefix_lens)
-        out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
-        )
-
-        self.ret_values = alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def alloc_decode(
-        self,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 2) % self.page_size == seq_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        bs = len(seq_lens)
-        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-
-        self.ret_values = alloc_decode_kernel_ascend(
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def clear(self):
-        super().clear()
-        self.free_pages = self.free_pages.to(torch.int32)
-        self.release_pages = self.release_pages.to(torch.int32)
```
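According to the file list at the top, `sglang/srt/mem_cache/allocator_ascend.py` is a new 158-line module, so the NPU-specific kernels and `AscendPagedTokenToKVPoolAllocator` removed here appear to have been relocated rather than deleted outright. For intuition about the removed decode path, here is its page arithmetic replayed on small plain-Python values; this is an illustrative rendition, not the relocated code:

```python
# Plain-Python rendition of the decode-path page math from the removed
# alloc_decode_kernel_ascend (the real code operates on torch tensors).
page_size = 4
seq_lens = [5, 8, 9]      # lengths after appending the new decode token
last_loc = [3, 22, 15]    # slot index of each request's previous token
free_pages = [7, 9]

num_new_pages = [
    (s + page_size - 1) // page_size - (s - 1 + page_size - 1) // page_size
    for s in seq_lens
]
print(num_new_pages)  # [1, 0, 1]: only requests crossing a page boundary need a fresh page

out_indices = []
next_free = 0
for i, need in enumerate(num_new_pages):
    if need:
        out_indices.append(free_pages[next_free] * page_size)  # first slot of a fresh page
        next_free += 1
    else:
        out_indices.append(last_loc[i] + 1)                    # continue in the current page
print(out_indices)  # [28, 23, 36]
```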