sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cpu_graph_runner.py
@@ -0,0 +1,640 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run the model with cpu torch compile."""
+
+# The implementation of CPUGraphRunner follows the CudaGraphRunner
+
+from __future__ import annotations
+
+import logging
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Optional, Union
+
+import psutil
+import torch
+import tqdm
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import GroupCoordinator
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
+from sglang.srt.patch_torch import monkey_patch_torch_compile
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import (
+    log_info_on_rank0,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+@contextmanager
+def patch_model(
+    model: torch.nn.Module,
+    enable_compile: bool,
+    num_tokens: int,
+    tp_group: GroupCoordinator,
+):
+    """Patch the model to make it compatible with torch.compile"""
+    backup_ca_comm = None
+
+    try:
+        if enable_compile:
+            backup_ca_comm = tp_group.ca_comm
+            # Use custom-allreduce here.
+            # We found the custom allreduce is much faster than the built-in allreduce in torch,
+            # even with ENABLE_INTRA_NODE_COMM=1.
+            # tp_group.ca_comm = None
+            yield torch.compile(
+                torch.no_grad()(model.forward),
+                dynamic=False,
+            )
+        else:
+            yield model.forward
+    finally:
+        if enable_compile:
+            tp_group.ca_comm = backup_ca_comm
+
+
+def set_torch_compile_config():
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+    torch._inductor.config.freezing = True
+    torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024
+    monkey_patch_torch_compile()
+
+
+def get_batch_sizes_to_capture(model_runner: ModelRunner):
+    server_args = model_runner.server_args
+    # cpu torch compile only speeds up decoding by
+    # reducing python overhead when bs is small
+    capture_bs = list(range(1, 17))
+    capture_bs = [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    capture_bs = list(sorted(set(capture_bs)))
+    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
+    return capture_bs
+
+
+def register_fake_ops():
+    """
+    Registers fake/meta implementations for all custom sgl_kernel CPU operators
+    using torch.library.register_fake to support torch.compile
+    """
+
+    none_return_ops = [
+        "shm_allreduce",
+        "bmm_cpu",
+        "fused_add_rmsnorm_cpu",
+        "decode_attention_cpu",
+        "extend_attention_cpu",
+    ]
+    for op in none_return_ops:
+
+        @torch.library.register_fake(f"sgl_kernel::{op}")
+        def _(*args, **kwargs):
+            return
+
+    for op in [
+        "rmsnorm_cpu",
+        "l2norm_cpu",
+        "fused_experts_cpu",
+        "shared_expert_cpu",
+    ]:
+
+        @torch.library.register_fake(f"sgl_kernel::{op}")
+        def _(input, *args, **kwargs):
+            return torch.empty_like(input)
+
+    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope")
+    def _(
+        hidden_states,
+        q_a_proj_weight,
+        q_b_proj_weight,
+        kv_a_proj_weight,
+        w_kc,
+        q_a_layernorm_weight,
+        kv_a_layernorm_weight,
+        positions,
+        cos_sin_cache,
+        eps,
+        use_int8_w8a8,
+        use_fp8_w8a16,
+        q_a_proj_scale,
+        q_b_proj_scale,
+        kv_a_proj_scale,
+        is_vnni,
+        block_size,
+    ):
+        num_seqs = hidden_states.shape[0]
+        num_heads = w_kc.shape[0]
+        kv_lora_rank = w_kc.shape[1]
+        qk_rope_head_dim = kv_a_proj_weight.shape[0] - kv_lora_rank
+        q_input = torch.empty(
+            num_seqs,
+            num_heads,
+            kv_lora_rank + qk_rope_head_dim,
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        k_input = torch.empty(
+            num_seqs,
+            1,
+            kv_lora_rank + qk_rope_head_dim,
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        v_input = k_input.narrow(-1, 0, kv_lora_rank)
+        return q_input, k_input, v_input
+
+    @torch.library.register_fake("sgl_kernel::rotary_embedding_cpu")
+    def _(positions, query, key, head_size, cos_sin_cache, is_neox):
+        if query.ndim == 2:
+            return query, key
+        else:
+            return torch.empty_like(query), torch.empty_like(key)
+
+    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope_fused_weight")
+    def _(
+        hidden_states,
+        q_a_proj_weight,
+        q_b_proj_weight,
+        w_kc,
+        q_a_layernorm_weight,
+        kv_a_layernorm_weight,
+        positions,
+        cos_sin_cache,
+        eps,
+        use_int8_w8a8,
+        use_fp8_w8a16,
+        qkv_a_proj_scale,
+        q_b_proj_scale,
+        is_vnni,
+        block_size,
+        q_lora_rank,
+        kv_lora_rank,
+        qk_rope_head_dim,
+    ):
+        num_seqs = hidden_states.shape[0]
+        num_heads = w_kc.shape[0]
+        kv_lora_rank = w_kc.shape[1]
+        weight_chunks = torch.split(
+            q_a_proj_weight, [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=0
+        )
+        qk_rope_head_dim = weight_chunks[1].shape[0] - kv_lora_rank
+        q_input = torch.empty(
+            num_seqs,
+            num_heads,
+            kv_lora_rank + qk_rope_head_dim,
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        k_input = torch.empty(
+            num_seqs,
+            1,
+            kv_lora_rank + qk_rope_head_dim,
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        v_input = k_input.narrow(-1, 0, kv_lora_rank)
+        return q_input, k_input, v_input
+
+    @torch.library.register_fake("sgl_kernel::weight_packed_linear")
+    def _(x, weight, bias, is_vnni):
+        return x.new_empty(x.shape[0], weight.shape[0])
+
+    @torch.library.register_fake("sgl_kernel::per_token_quant_int8_cpu")
+    def _(input):
+        M = input.shape[0]
+        K = input.shape[1]
+        Aq = input.new_empty(M, K, dtype=torch.int8)
+        As = input.new_empty(M, dtype=torch.float32)
+        return Aq, As
+
+    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_cpu")
+    def _(mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni):
+        M = mat1.shape[0]
+        N = mat2.shape[0]
+        out = mat1.new_empty(M, N, dtype=out_dtype)
+        return out
+
+    @torch.library.register_fake("sgl_kernel::grouped_topk_cpu")
+    def _(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    ):
+        num_tokens = hidden_states.shape[0]
+        shape = (num_tokens, topk)
+        device = hidden_states.device
+        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
+        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
+        return topk_weights, topk_ids
+
+    @torch.library.register_fake("sgl_kernel::biased_grouped_topk_cpu")
+    def _(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    ):
+        num_tokens = hidden_states.shape[0]
+        shape = (num_tokens, topk)
+        device = hidden_states.device
+        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
+        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
+        return topk_weights, topk_ids
+
+    @torch.library.register_fake("sgl_kernel::topk_sigmoid_cpu")
+    def _(hidden_states, gating_output, topk, renormalize):
+        num_tokens = hidden_states.shape[0]
+        shape = (num_tokens, topk)
+        return (
+            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
+            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
+        )
+
+    @torch.library.register_fake("sgl_kernel::topk_softmax_cpu")
+    def _(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+    ):
+        num_tokens = hidden_states.shape[0]
+        shape = (num_tokens, topk)
+        return (
+            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
+            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
+        )
+
+    @torch.library.register_fake("sgl_kernel::silu_and_mul_cpu")
+    def _(input):
+        return input.new_empty(input.shape[0], input.shape[1] // 2)
+
+    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_with_quant")
+    def _(
+        mat1,
+        mat2,
+        scales2,
+        bias,
+        out_dtype,
+        is_vnni,
+    ):
+        M = mat1.shape[0]
+        N = mat2.shape[0]
+        return mat1.new_empty(M, N, dtype=out_dtype)
+
+    @torch.library.register_fake("sgl_kernel::fp8_scaled_mm_cpu")
+    def _(
+        mat1,
+        mat2,
+        scales2,
+        block_size,
+        bias,
+        out_dtype,
+        is_vnni,
+    ):
+        M = mat1.shape[0]
+        N = mat2.shape[0]
+        return mat1.new_empty(M, N, dtype=out_dtype)
+
+
+# TODO Remove unnecessary settings for CPUGraphRunner.
+# Re-abstract the graph runner and restructure CPUGraphRunner to reuse the same logic.
+class CPUGraphRunner:
+    """A CPUGraphRunner runs the forward pass of a model with cpu torch.compile."""
+
+    def __init__(self, model_runner: ModelRunner):
+        # Parse args
+        self.model_runner = model_runner
+        self.device = model_runner.device
+        self.graphs = {}
+        self.output_buffers = {}
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+        self.enable_two_batch_overlap = (
+            model_runner.server_args.enable_two_batch_overlap
+        )
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
+
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.capture_hidden_mode = CaptureHiddenMode.NULL
+        self.num_tokens_per_bs = 1
+
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
+        assert (
+            not self.model_runner.server_args.enable_lora
+        ), "CPUGraphRunner does not support LoRA yet."
+        assert (
+            not self.enable_two_batch_overlap
+        ), "CPUGraphRunner does not support two batch overlap yet."
+        assert (
+            not self.require_mlp_tp_gather
+        ), "CPUGraphRunner does not support MLP TP gather yet."
+        assert (
+            not self.require_mlp_sync
+        ), "CPUGraphRunner does not support MLP sync yet."
+        assert (
+            not self.require_gathered_buffer
+        ), "CPUGraphRunner does not support gathered buffer yet."
+        assert (
+            model_runner.spec_algorithm == SpeculativeAlgorithm.NONE
+        ), "CPUGraphRunner does not support speculative inference yet."
+        # TODO add compile support for encoder-decoder models
+        assert (
+            not self.is_encoder_decoder
+        ), "CPUGraphRunner does not support encoder-decoder models yet."
+        assert self.dp_size == 1, "CPUGraphRunner does not support DP yet."
+        assert self.pp_size == 1, "CPUGraphRunner does not support PP yet."
+
+        # Batch sizes to capture
+        self.capture_bs = get_batch_sizes_to_capture(model_runner)
+        log_info_on_rank0(logger, f"Capture cpu graph bs {self.capture_bs}")
+        # Attention backend
+        self.max_bs = max(self.capture_bs)
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_graph_seq_len_fill_value()
+        )
+
+        if self.enable_torch_compile:
+            register_fake_ops()
+            set_torch_compile_config()
+
+        # Graph inputs
+        with torch.device(self.device):
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64
+            )
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int64)
+            self.custom_mask = torch.ones(
+                (
+                    (self.seq_lens.sum().item() + self.max_num_token)
+                    * self.num_tokens_per_bs
+                ),
+                dtype=torch.bool,
+                device=self.device,
+            )
+
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture CPU graph failed: {e}\n{CPU_GRAPH_CAPTURE_FAILED_MSG}"
+            )
+
+    def can_run(self, forward_batch: ForwardBatch):
+        is_bs_supported = forward_batch.batch_size in self.graphs
+
+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
+
+        return is_bs_supported and capture_hidden_mode_matches
+
+    def capture(self) -> None:
+        capture_range = (
+            tqdm.tqdm(list(reversed(self.capture_bs)))
+            if get_tensor_model_parallel_rank() == 0
+            else reversed(self.capture_bs)
+        )
+        for bs in capture_range:
+            if get_tensor_model_parallel_rank() == 0:
+                avail_mem = psutil.virtual_memory().available / (1 << 30)
+                capture_range.set_description(
+                    f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
+                )
+
+            with patch_model(
+                self.model_runner.model,
+                bs in self.capture_bs,
+                num_tokens=bs * self.num_tokens_per_bs,
+                tp_group=self.model_runner.tp_group,
+            ) as forward:
+                (
+                    graph,
+                    output_buffers,
+                ) = self.capture_one_batch_size(bs, forward)
+                self.graphs[bs] = graph
+                self.output_buffers[bs] = output_buffers
+
+    def capture_one_batch_size(self, bs: int, forward: Callable):
+        num_tokens = bs * self.num_tokens_per_bs
+
+        # Graph inputs
+        input_ids = self.input_ids[:num_tokens]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :bs]
+        self.num_token_non_padded[...] = num_tokens
+
+        spec_info = self.get_spec_info(num_tokens)
+        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
+            self.capture_hidden_mode = (
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            )
+
+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens.sum().item(),
+            return_logprob=False,
+            positions=positions,
+            mrope_positions=mrope_positions,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=spec_info,
+            capture_hidden_mode=self.capture_hidden_mode,
+            num_token_non_padded=self.num_token_non_padded,
+            global_forward_mode=self.capture_forward_mode,
+        )
+
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+        # Do inference to avoid setting attr at runtime, e.g.,
+        # self.attn_mha.kv_b_proj = self.kv_b_proj for full graph compile on CPU
+        self.model_runner.model.forward(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+        )
+
+        # Run and capture
+        def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+            )
+            return logits_output_or_pp_proxy_tensors
+
+        with torch.no_grad():
+            for _ in range(2):
+                self.model_runner.tp_group.barrier()
+                out = run_once()
+        return forward, out
+
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
+
+        # If the required capture_hidden_mode changes, we need to recapture the graph
+
+        # These are the different factors that can influence the capture_hidden_mode
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
+            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+        )
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required
+        # (If we have FULL, we can emulate LAST or NULL)
+        # (If we have LAST, we can emulate NULL)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
+            self.capture()
+
+    # TODO add padding support for CPUGraphRunner
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        assert (
+            pp_proxy_tensors is None
+        ), "PPProxyTensors is not supported in CPUGraphRunner yet."
+        self.recapture_if_needed(forward_batch)
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+        output = self.graphs[forward_batch.batch_size](
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+        )
+        return output
+
+    def get_spec_info(self, num_tokens: int):
+        spec_info = None
+        if self.model_runner.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+
+            if self.model_runner.is_draft_worker:
+                raise RuntimeError("This should not happen.")
+            else:
+                spec_info = EagleVerifyInput(
+                    draft_token=None,
+                    custom_mask=self.custom_mask,
+                    positions=None,
+                    retrive_index=None,
+                    retrive_next_token=None,
+                    retrive_next_sibling=None,
+                    retrive_cum_len=None,
+                    spec_steps=self.model_runner.server_args.speculative_num_steps,
+                    topk=self.model_runner.server_args.speculative_eagle_topk,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
+                    seq_lens_sum=None,
+                    seq_lens_cpu=None,
+                )
+
+        return spec_info
+
+
+CPU_GRAPH_CAPTURE_FAILED_MSG = (
+    "Possible solutions:\n"
+    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+    "2. set --torch-compile-max-bs to a smaller value (e.g., 8)\n"
+    "3. disable torch compile by not using --enable-torch-compile\n"
+    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+)
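Note on the register_fake_ops section above: torch.compile traces with FakeTensors, so every opaque sgl_kernel operator needs a shape-only "fake" implementation or tracing fails. A minimal, self-contained sketch of the same mechanism; the operator mylib::scale and its body are hypothetical, not part of sglang or sgl_kernel:

import torch

@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real eager implementation (stands in for a compiled C++ kernel).
    return x * factor

@torch.library.register_fake("mylib::scale")
def _(x, factor):
    # Fake/meta implementation: no computation, just the output's
    # shape/dtype/device, like the torch.empty_like stubs above.
    return torch.empty_like(x)

compiled = torch.compile(lambda t: torch.ops.mylib.scale(t, 2.0), dynamic=False)
print(compiled(torch.ones(4)))  # tensor([2., 2., 2., 2.])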
sglang/srt/model_executor/cuda_graph_runner.py
@@ -271,7 +271,10 @@ class CudaGraphRunner:
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle()
+            or model_runner.spec_algorithm.is_standalone()
+        ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
             else:
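The new or model_runner.spec_algorithm.is_standalone() branch corresponds to the STANDALONE speculative mode added in this release (see sglang/srt/speculative/standalone_worker.py and sglang/srt/speculative/spec_info.py in the file list above), which reuses the EAGLE capture path. A hedged sketch of how such predicates could be laid out; the enum members and values here are illustrative, not sglang's actual definition:

from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()
    EAGLE3 = auto()
    STANDALONE = auto()

    def is_eagle(self) -> bool:
        # Treat EAGLE3 as part of the EAGLE family.
        return self in (SpeculativeAlgorithm.EAGLE, SpeculativeAlgorithm.EAGLE3)

    def is_standalone(self) -> bool:
        return self == SpeculativeAlgorithm.STANDALONE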
@@ -317,7 +320,9 @@ class CudaGraphRunner:
                 (self.max_num_token,), dtype=self._cache_loc_dtype()
             )
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+            self.mrope_positions = torch.zeros(
+                (3, self.max_num_token), dtype=torch.int64
+            )
             self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
             self.tbo_plugin = TboCudaGraphRunnerPlugin()
 
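The resize from (3, self.max_bs) to (3, self.max_num_token) matters once speculative decoding makes num_tokens_per_bs larger than one: mrope positions are per token (three rotary components per token), so the buffer must be sized in tokens, not sequences. An illustrative sketch with hypothetical numbers:

import torch

max_bs, num_tokens_per_bs = 8, 4  # e.g. a 4-token speculative draft per sequence
max_num_token = max_bs * num_tokens_per_bs

# New sizing: one column per token, as in the hunk above.
mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)

bs = 2
num_tokens = bs * num_tokens_per_bs
view = mrope_positions[:, :num_tokens]  # matches the [:, :num_tokens] slicing in the next hunk
print(view.shape)  # torch.Size([3, 8])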
@@ -532,7 +537,7 @@ class CudaGraphRunner:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-        mrope_positions = self.mrope_positions[:, :bs]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
@@ -751,7 +756,7 @@ class CudaGraphRunner:
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
@@ -825,7 +830,10 @@ class CudaGraphRunner:
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
-        if self.model_runner.spec_algorithm.is_eagle():
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
             from sglang.srt.speculative.eagle_utils import EagleVerifyInput
 
             if self.model_runner.is_draft_worker:
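Across both runners, recapture decisions are made by taking max() over CaptureHiddenMode values (see recapture_if_needed in cpu_graph_runner.py above), which presumes the enum is ordered so that a more inclusive mode compares greater: FULL can emulate LAST, and LAST can emulate NULL. A minimal sketch of that convention; the numeric values are illustrative, not sglang's actual ones:

from enum import IntEnum

class CaptureHiddenMode(IntEnum):
    NULL = 0  # capture no hidden states
    LAST = 1  # capture the last token's hidden state
    FULL = 2  # capture hidden states for every token

# A graph captured at the max() of all requirements can serve each of them:
required = max(CaptureHiddenMode.NULL, CaptureHiddenMode.LAST)
print(required.name)  # LAST
print(CaptureHiddenMode.FULL >= required)  # True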