sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -48,18 +48,22 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
48
48
|
PPProxyTensors,
|
49
49
|
enable_num_token_non_padded,
|
50
50
|
)
|
51
|
-
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
52
51
|
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
53
52
|
from sglang.srt.utils import (
|
54
53
|
empty_context,
|
55
54
|
get_available_gpu_memory,
|
55
|
+
get_bool_env_var,
|
56
56
|
get_device_memory_capacity,
|
57
|
+
is_hip,
|
57
58
|
log_info_on_rank0,
|
58
59
|
require_attn_tp_gather,
|
59
60
|
require_gathered_buffer,
|
60
61
|
require_mlp_sync,
|
61
62
|
require_mlp_tp_gather,
|
62
63
|
)
|
64
|
+
from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
|
65
|
+
|
66
|
+
_is_hip = is_hip()
|
63
67
|
|
64
68
|
logger = logging.getLogger(__name__)
|
65
69
|
|
@@ -100,6 +104,7 @@ def freeze_gc(enable_cudagraph_gc: bool):
|
|
100
104
|
finally:
|
101
105
|
if should_freeze:
|
102
106
|
gc.unfreeze()
|
107
|
+
gc.collect()
|
103
108
|
|
104
109
|
|
105
110
|
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
@@ -136,7 +141,7 @@ def patch_model(
|
|
136
141
|
mode=os.environ.get(
|
137
142
|
"SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
|
138
143
|
),
|
139
|
-
dynamic=
|
144
|
+
dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
|
140
145
|
)
|
141
146
|
else:
|
142
147
|
yield model.forward
|
@@ -166,29 +171,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|
166
171
|
server_args = model_runner.server_args
|
167
172
|
capture_bs = server_args.cuda_graph_bs
|
168
173
|
|
169
|
-
if capture_bs is None:
|
170
|
-
if server_args.speculative_algorithm is None:
|
171
|
-
if server_args.disable_cuda_graph_padding:
|
172
|
-
capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
|
173
|
-
else:
|
174
|
-
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
|
175
|
-
else:
|
176
|
-
# Since speculative decoding requires more cuda graph memory, we
|
177
|
-
# capture less.
|
178
|
-
capture_bs = (
|
179
|
-
list(range(1, 9))
|
180
|
-
+ list(range(10, 33, 2))
|
181
|
-
+ list(range(40, 64, 8))
|
182
|
-
+ list(range(80, 161, 16))
|
183
|
-
)
|
184
|
-
|
185
|
-
gpu_mem = get_device_memory_capacity()
|
186
|
-
if gpu_mem is not None:
|
187
|
-
if gpu_mem > 90 * 1024: # H200, H20
|
188
|
-
capture_bs += list(range(160, 257, 8))
|
189
|
-
if gpu_mem > 160 * 1000: # B200, MI300
|
190
|
-
capture_bs += list(range(256, 513, 16))
|
191
|
-
|
192
174
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
193
175
|
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
194
176
|
# is very small. We add more values here to make sure we capture the maximum bs.
|
@@ -204,12 +186,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|
204
186
|
|
205
187
|
capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
|
206
188
|
|
207
|
-
if server_args.cuda_graph_max_bs:
|
208
|
-
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
|
209
|
-
if max(capture_bs) < server_args.cuda_graph_max_bs:
|
210
|
-
capture_bs += list(
|
211
|
-
range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
|
212
|
-
)
|
213
189
|
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
|
214
190
|
capture_bs = list(sorted(set(capture_bs)))
|
215
191
|
assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
|
@@ -274,6 +250,7 @@ class CudaGraphRunner:
|
|
274
250
|
if (
|
275
251
|
model_runner.spec_algorithm.is_eagle()
|
276
252
|
or model_runner.spec_algorithm.is_standalone()
|
253
|
+
or model_runner.spec_algorithm.is_ngram()
|
277
254
|
):
|
278
255
|
if self.model_runner.is_draft_worker:
|
279
256
|
raise RuntimeError("This should not happen")
|
@@ -440,11 +417,21 @@ class CudaGraphRunner:
|
|
440
417
|
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
441
418
|
)
|
442
419
|
|
420
|
+
is_ngram_supported = (
|
421
|
+
(
|
422
|
+
forward_batch.batch_size * self.num_tokens_per_bs
|
423
|
+
== forward_batch.input_ids.numel()
|
424
|
+
)
|
425
|
+
if self.model_runner.spec_algorithm.is_ngram()
|
426
|
+
else True
|
427
|
+
)
|
428
|
+
|
443
429
|
return (
|
444
430
|
is_bs_supported
|
445
431
|
and is_encoder_lens_supported
|
446
432
|
and is_tbo_supported
|
447
433
|
and capture_hidden_mode_matches
|
434
|
+
and is_ngram_supported
|
448
435
|
)
|
449
436
|
|
450
437
|
def capture(self) -> None:
|
@@ -454,6 +441,7 @@ class CudaGraphRunner:
|
|
454
441
|
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
455
442
|
record_shapes=True,
|
456
443
|
)
|
444
|
+
torch.cuda.memory._record_memory_history()
|
457
445
|
|
458
446
|
# Trigger CUDA graph capture for specific shapes.
|
459
447
|
# Capture the large shapes first so that the smaller shapes
|
@@ -502,6 +490,8 @@ class CudaGraphRunner:
|
|
502
490
|
save_gemlite_cache()
|
503
491
|
|
504
492
|
if self.enable_profile_cuda_graph:
|
493
|
+
torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
|
494
|
+
torch.cuda.memory._record_memory_history(enabled=None)
|
505
495
|
log_message = (
|
506
496
|
"Sorted by CUDA Time:\n"
|
507
497
|
+ prof.key_averages(group_by_input_shape=True).table(
|
@@ -511,6 +501,7 @@ class CudaGraphRunner:
|
|
511
501
|
+ prof.key_averages(group_by_input_shape=True).table(
|
512
502
|
sort_by="cpu_time_total", row_limit=10
|
513
503
|
)
|
504
|
+
+ "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
|
514
505
|
)
|
515
506
|
logger.info(log_message)
|
516
507
|
|
@@ -531,6 +522,7 @@ class CudaGraphRunner:
|
|
531
522
|
input_ids = self.input_ids[:num_tokens]
|
532
523
|
req_pool_indices = self.req_pool_indices[:bs]
|
533
524
|
seq_lens = self.seq_lens[:bs]
|
525
|
+
seq_lens_cpu = self.seq_lens_cpu[:bs]
|
534
526
|
out_cache_loc = self.out_cache_loc[:num_tokens]
|
535
527
|
positions = self.positions[:num_tokens]
|
536
528
|
if self.is_encoder_decoder:
|
@@ -601,6 +593,7 @@ class CudaGraphRunner:
|
|
601
593
|
input_ids=input_ids,
|
602
594
|
req_pool_indices=req_pool_indices,
|
603
595
|
seq_lens=seq_lens,
|
596
|
+
seq_lens_cpu=seq_lens_cpu,
|
604
597
|
next_token_logits_buffer=next_token_logits_buffer,
|
605
598
|
orig_seq_lens=seq_lens,
|
606
599
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
@@ -834,7 +827,7 @@ class CudaGraphRunner:
|
|
834
827
|
self.model_runner.spec_algorithm.is_eagle()
|
835
828
|
or self.model_runner.spec_algorithm.is_standalone()
|
836
829
|
):
|
837
|
-
from sglang.srt.speculative.
|
830
|
+
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
838
831
|
|
839
832
|
if self.model_runner.is_draft_worker:
|
840
833
|
raise RuntimeError("This should not happen.")
|
@@ -855,6 +848,20 @@ class CudaGraphRunner:
|
|
855
848
|
seq_lens_cpu=None,
|
856
849
|
)
|
857
850
|
|
851
|
+
elif self.model_runner.spec_algorithm.is_ngram():
|
852
|
+
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
853
|
+
|
854
|
+
spec_info = NgramVerifyInput(
|
855
|
+
draft_token=None,
|
856
|
+
tree_mask=self.custom_mask,
|
857
|
+
positions=None,
|
858
|
+
retrive_index=None,
|
859
|
+
retrive_next_token=None,
|
860
|
+
retrive_next_sibling=None,
|
861
|
+
draft_token_num=self.num_tokens_per_bs,
|
862
|
+
)
|
863
|
+
spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
|
864
|
+
|
858
865
|
return spec_info
|
859
866
|
|
860
867
|
|
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
|
|
45
45
|
get_attention_tp_size,
|
46
46
|
set_dp_buffer_len,
|
47
47
|
)
|
48
|
-
from sglang.srt.
|
49
|
-
from sglang.srt.utils import (
|
50
|
-
flatten_nested_list,
|
51
|
-
get_compiler_backend,
|
52
|
-
is_npu,
|
53
|
-
support_triton,
|
54
|
-
)
|
48
|
+
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
|
55
49
|
|
56
50
|
if TYPE_CHECKING:
|
57
51
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
|
|
60
54
|
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
61
55
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
62
56
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
63
|
-
from sglang.srt.speculative.
|
64
|
-
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
57
|
+
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
|
65
58
|
|
66
59
|
_is_npu = is_npu()
|
67
60
|
|
@@ -293,13 +286,14 @@ class ForwardBatch:
|
|
293
286
|
global_forward_mode: Optional[ForwardMode] = None
|
294
287
|
|
295
288
|
# Speculative decoding
|
296
|
-
spec_info: Optional[
|
289
|
+
spec_info: Optional[SpecInput] = None
|
297
290
|
spec_algorithm: SpeculativeAlgorithm = None
|
298
291
|
capture_hidden_mode: CaptureHiddenMode = None
|
299
292
|
|
300
293
|
# For padding
|
301
294
|
padded_static_len: int = -1 # -1 if not padded
|
302
295
|
num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
|
296
|
+
num_token_non_padded_cpu: int = None
|
303
297
|
|
304
298
|
# For Qwen2-VL
|
305
299
|
mrope_positions: torch.Tensor = None
|
@@ -361,36 +355,18 @@ class ForwardBatch:
|
|
361
355
|
ret.num_token_non_padded = torch.tensor(
|
362
356
|
len(batch.input_ids), dtype=torch.int32
|
363
357
|
).to(device, non_blocking=True)
|
358
|
+
ret.num_token_non_padded_cpu = len(batch.input_ids)
|
364
359
|
|
365
360
|
# For MLP sync
|
366
361
|
if batch.global_num_tokens is not None:
|
367
|
-
from sglang.srt.speculative.eagle_utils import (
|
368
|
-
EagleDraftInput,
|
369
|
-
EagleVerifyInput,
|
370
|
-
)
|
371
|
-
|
372
362
|
assert batch.global_num_tokens_for_logprob is not None
|
363
|
+
|
373
364
|
# process global_num_tokens and global_num_tokens_for_logprob
|
374
365
|
if batch.spec_info is not None:
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
]
|
380
|
-
global_num_tokens_for_logprob = [
|
381
|
-
x * batch.spec_info.num_tokens_for_logprob_per_batch
|
382
|
-
for x in batch.global_num_tokens_for_logprob
|
383
|
-
]
|
384
|
-
else:
|
385
|
-
assert isinstance(batch.spec_info, EagleVerifyInput)
|
386
|
-
global_num_tokens = [
|
387
|
-
x * batch.spec_info.draft_token_num
|
388
|
-
for x in batch.global_num_tokens
|
389
|
-
]
|
390
|
-
global_num_tokens_for_logprob = [
|
391
|
-
x * batch.spec_info.draft_token_num
|
392
|
-
for x in batch.global_num_tokens_for_logprob
|
393
|
-
]
|
366
|
+
spec_info: SpecInput = batch.spec_info
|
367
|
+
global_num_tokens, global_num_tokens_for_logprob = (
|
368
|
+
spec_info.get_spec_adjusted_global_num_tokens(batch)
|
369
|
+
)
|
394
370
|
else:
|
395
371
|
global_num_tokens = batch.global_num_tokens
|
396
372
|
global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
|
@@ -669,9 +645,6 @@ class ForwardBatch:
|
|
669
645
|
)
|
670
646
|
|
671
647
|
def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
|
672
|
-
|
673
|
-
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
674
|
-
|
675
648
|
assert self.global_num_tokens_cpu is not None
|
676
649
|
assert self.global_num_tokens_for_logprob_cpu is not None
|
677
650
|
|
@@ -768,7 +741,8 @@ class ForwardBatch:
|
|
768
741
|
if self.extend_seq_lens is not None:
|
769
742
|
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
|
770
743
|
|
771
|
-
if self.spec_info is not None and
|
744
|
+
if self.spec_info is not None and self.spec_info.is_draft_input():
|
745
|
+
# FIXME(lsyin): remove this isinstance logic
|
772
746
|
spec_info = self.spec_info
|
773
747
|
self.output_cache_loc_backup = self.out_cache_loc
|
774
748
|
self.hidden_states_backup = spec_info.hidden_states
|
@@ -928,6 +902,17 @@ class ForwardBatch:
|
|
928
902
|
return self.tbo_split_seq_index is not None
|
929
903
|
|
930
904
|
|
905
|
+
@dataclass
|
906
|
+
class ForwardBatchOutput:
|
907
|
+
# FIXME(lsyin): unify the forward batch output between different spec and parallelism
|
908
|
+
# need to be more organized
|
909
|
+
logits_output: Optional[torch.Tensor] = None
|
910
|
+
next_token_ids: Optional[torch.Tensor] = None
|
911
|
+
num_accepted_tokens: Optional[int] = None
|
912
|
+
pp_proxy_tensors: Optional[PPProxyTensors] = None
|
913
|
+
can_run_cuda_graph: bool = False
|
914
|
+
|
915
|
+
|
931
916
|
def enable_num_token_non_padded(server_args):
|
932
917
|
return get_moe_expert_parallel_world_size() > 1
|
933
918
|
|