sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/{utils.py → utils/common.py}

@@ -22,6 +22,7 @@ import ctypes
 import dataclasses
 import functools
 import importlib
+import inspect
 import io
 import ipaddress
 import itertools
@@ -82,11 +83,9 @@ from packaging import version as pkg_version
 from PIL import Image
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import FileCacheManager
 from typing_extensions import Literal

 from sglang.srt.metrics.func_timer import enable_func_timer
@@ -167,6 +166,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)


+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
@@ -175,6 +175,8 @@ def is_blackwell():

 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -182,6 +184,8 @@ def is_sm100_supported(device=None) -> bool:

 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
@@ -191,6 +195,7 @@ _warned_bool_env_var_keys = set()


 def get_bool_env_var(name: str, default: str = "false") -> bool:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name, default)
     value = value.lower()

@@ -208,6 +213,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:


 def get_int_env_var(name: str, default: int = 0) -> int:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name)
     if value is None or not value.strip():
         return default
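The two helpers treat unset and blank values differently. A minimal usage sketch, assuming they stay importable from sglang.srt.utils after the utils.py → utils/common.py move, and assuming the usual truthy strings ("true"/"1" — the middle of get_bool_env_var is not shown in this hunk); the SGLANG_DEMO_* names are invented:

    import os
    from sglang.srt.utils import get_bool_env_var, get_int_env_var

    os.environ["SGLANG_DEMO_FLAG"] = "TRUE"  # lowered before comparison, so case is ignored
    os.environ["SGLANG_DEMO_PORT"] = "   "   # blank/whitespace falls through to the default

    assert get_bool_env_var("SGLANG_DEMO_FLAG") is True
    assert get_int_env_var("SGLANG_DEMO_PORT", 30000) == 30000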
@@ -465,7 +471,7 @@ def is_pin_memory_available() -> bool:

 class LayerFn(Protocol):

-    def __call__(self,
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


 def make_layers(
@@ -476,7 +482,7 @@ def make_layers(
     prefix: str = "",
     return_tuple: bool = False,
     offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
@@ -512,6 +518,50 @@ def make_layers(
     return modules, start_layer, end_layer


+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation(CMO).
+    Launch a new stream to prefetch the weight of matmul when running other
+    AIV or communication kernels, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
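The three CMO helpers are meant to be used as a prefetch → overlapped work → join sequence. A rough sketch of that pattern on an Ascend NPU build with torch_npu available (layer, handle, and other_kernel are hypothetical placeholders, not sglang names):

    from sglang.srt.utils import prepare_weight_cache, wait_cmo_stream

    def forward_with_prefetch(layer, handle, x, other_kernel):
        # Kick off weight prefetch on the side stream...
        prepare_weight_cache(handle, layer.weights)
        # ...while unrelated AIV/communication work runs on the current stream.
        y = other_kernel(x)
        # Join before the matmul actually consumes the prefetched weights.
        wait_cmo_stream()
        return layer.matmul(y)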
@@ -749,6 +799,25 @@ def load_image(
     return image, image_size


+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
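get_image_bytes normalizes several input forms to raw bytes. The branches above, illustrated with placeholder inputs (same import-path assumption as earlier):

    from sglang.srt.utils import get_image_bytes

    raw = get_image_bytes(b"\x89PNG...")                      # bytes pass through unchanged
    raw = get_image_bytes("https://example.com/cat.png")      # HTTP fetch, REQUEST_TIMEOUT seconds (default 3)
    raw = get_image_bytes("/tmp/cat.jpg")                     # read from disk when the extension matches
    raw = get_image_bytes("data:image/png;base64,iVBORw0KG")  # payload after the comma, base64-decoded
    raw = get_image_bytes("iVBORw0KG")                        # any other string is treated as bare base64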
@@ -804,6 +873,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         os.unlink(tmp_file.name)


+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
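encode_video keeps roughly one frame per second (round(avg_fps / 1) is the sampling stride) and then thins uniformly down to frame_count_limit. A usage sketch with an invented path:

    from sglang.srt.utils import encode_video

    # A 10 s clip at 30 fps yields ~10 one-per-second candidates; the limit
    # then uniformly keeps 4 of them. Returns [] if the file is missing.
    frames = encode_video("/tmp/clip.mp4", frame_count_limit=4)
    print(len(frames))  # -> 4 PIL.Image frames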
@@ -946,6 +1042,13 @@ def set_ulimit(target_soft_limit=65535):
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")


+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
@@ -1404,6 +1507,32 @@ def get_npu_memory_capacity():
         raise ImportError("torch_npu is required when run on npu device.")


+def get_cpu_memory_capacity():
+    # Per-rank memory capacity cannot be determined for customized core settings
+    if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
+        return None
+    n_numa_node: int = len(get_cpu_ids_by_node())
+    if n_numa_node == 0:
+        # Cannot determine NUMA config, fallback to total memory and avoid ZeroDivisionError.
+        return float(psutil.virtual_memory().total // (1 << 20))
+    try:
+        numa_mem_list = list()
+        file_prefix = "/sys/devices/system/node/"
+        for numa_id in range(n_numa_node):
+            file_meminfo = f"node{numa_id}/meminfo"
+            with open(os.path.join(file_prefix, file_meminfo), "r") as f:
+                # 1st line contains 'MemTotal'
+                line = f.read().split("\n")[0]
+                numa_mem_list.append(int(line.split()[3]))
+        # Retrieved value in KB, need MB
+        numa_mem = float(min(numa_mem_list) // 1024)
+        return numa_mem
+    except FileNotFoundError:
+        numa_mem = psutil.virtual_memory().total / n_numa_node
+        # Retrieved value in Byte, need MB
+        return float(numa_mem // (1 << 20))
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
@@ -1413,6 +1542,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_hpu_memory_capacity()
     elif device == "npu":
         gpu_mem = get_npu_memory_capacity()
+    elif device == "cpu":
+        gpu_mem = get_cpu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
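With the new branch, get_device_memory_capacity reports CPU hosts in MB like the accelerator paths, taking the smallest per-NUMA-node total from the helper above. A sanity-check sketch:

    from sglang.srt.utils import get_device_memory_capacity

    # Returns None when SGLANG_CPU_OMP_THREADS_BIND customizes core binding;
    # otherwise MB, e.g. roughly 128 * 1024 on a two-node, 256 GiB host.
    mem_mb = get_device_memory_capacity(device="cpu")
    if mem_mb is not None:
        print(f"per-NUMA-node capacity: {mem_mb / 1024:.1f} GiB")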
@@ -1951,50 +2082,6 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"


-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignore
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try using hostname
-    hostname = socket.gethostname()
-    try:
-        ip_addr = socket.gethostbyname(hostname)
-        warnings.warn("using local ip address: {}".format(ip_addr))
-        return ip_addr
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
@@ -2251,16 +2338,9 @@ def bind_or_assign(target, source):
     return source


-def get_local_ip_auto() -> str:
-    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
-    return (
-        get_local_ip_by_nic(interface)
-        if interface is not None
-        else get_local_ip_by_remote()
-    )
-
-
-def get_local_ip_by_nic(interface: str) -> str:
+def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
+    if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
+        return None
     try:
         import netifaces
     except ImportError as e:
@@ -2281,15 +2361,13 @@ def get_local_ip_by_nic(interface: str) -> str:
             if ip and not ip.startswith("fe80::") and ip != "::1":
                 return ip.split("%")[0]
     except (ValueError, OSError) as e:
-        raise ValueError(
-            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        logger.warning(
+            f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         )
-
-    # Fallback
-    return get_local_ip_by_remote()
+    return None


-def get_local_ip_by_remote() -> str:
+def get_local_ip_by_remote() -> Optional[str]:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -2314,7 +2392,51 @@ def get_local_ip_by_remote() -> str:
         s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
     except Exception:
-        raise ValueError("Can not get local ip")
+        logger.warning("Can not get local ip by remote")
+        return None
+
+
+def get_local_ip_auto(fallback: str = None) -> str:
+    """
+    Automatically detect the local IP address using multiple fallback strategies.
+
+    This function attempts to obtain the local IP address through several methods.
+    If all methods fail, it returns the specified fallback value or raises an exception.
+
+    Args:
+        fallback (str, optional): Fallback IP address to return if all detection
+            methods fail. For server applications, explicitly set this to
+            "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
+            Defaults to None.
+
+    Returns:
+        str: The detected local IP address, or the fallback value if detection fails.
+
+    Raises:
+        ValueError: If IP detection fails and no fallback value is provided.
+
+    Note:
+        The function tries detection methods in the following order:
+        1. Direct IP detection via get_ip()
+        2. Network interface enumeration via get_local_ip_by_nic()
+        3. Remote connection method via get_local_ip_by_remote()
+    """
+    # Try environment variable
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    logger.debug("get_ip failed")
+    # Fallback
+    if ip := get_local_ip_by_nic():
+        return ip
+    logger.debug("get_local_ip_by_nic failed")
+    # Fallback
+    if ip := get_local_ip_by_remote():
+        return ip
+    logger.debug("get_local_ip_by_remote failed")
+    if fallback:
+        return fallback
+    raise ValueError("Can not get local ip")


 def is_page_size_one(server_args):
@@ -2366,7 +2488,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)

@@ -2496,14 +2618,6 @@ def read_system_prompt_from_file(model_name: str) -> str:
     return ""


-def bind_or_assign(target, source):
-    if target is not None:
-        target.copy_(source)
-        return target
-    else:
-        return source
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -3042,6 +3156,44 @@ def check_cuda_result(raw_output):
     return results


+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
 def numa_bind_to_node(node: int):
     libnuma = ctypes.CDLL("libnuma.so")
     if libnuma.numa_available() < 0:
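Both new device helpers are scheduler-side conveniences. A usage sketch, assuming the process was launched with CUDA_VISIBLE_DEVICES pinned to exactly one device:

    from sglang.srt.utils import get_device_sm_nvidia_smi, get_physical_device_id

    # Under CUDA_VISIBLE_DEVICES=3, PyTorch's logical device 0 is physical 3.
    physical_id = get_physical_device_id(0)    # -> 3 (asserts exactly one visible device)
    major, minor = get_device_sm_nvidia_smi()  # -> e.g. (9, 0) on Hopper; (0, 0) on failure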
@@ -3058,3 +3210,176 @@ def json_list_type(value):
     raise argparse.ArgumentTypeError(
         f"Invalid JSON list: {value}. Please provide a valid JSON list."
     )
+
+
+@contextmanager
+def temp_set_cuda_visible_devices(gpu_id: int):
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+    yield
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    # and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    # so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    # in that case, each prefill contains chunked_prefill_size tokens,
+    # and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
+
+
+def get_num_new_pages(
+    seq_lens: torch.Tensor,
+    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
+    decode: bool = False,
+) -> torch.Tensor:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert seq_lens.device == cpu_device
+
+    if prefix_lens is None or decode:
+        # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    return sum_num_new_pages.item()
+
+
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+        original_fn = fn.fn
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel
+
+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+            The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator