sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
|
@@ -17,7 +17,7 @@ from __future__ import annotations
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
|
|
20
|
-
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
|
|
20
|
+
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
|
|
21
21
|
from sglang.srt.layers.attention.nsa import index_buf_accessor
|
|
22
22
|
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
|
23
23
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
|
|
|
33
33
|
|
|
34
34
|
import abc
|
|
35
35
|
import logging
|
|
36
|
-
from contextlib import nullcontext
|
|
36
|
+
from contextlib import contextmanager, nullcontext
|
|
37
37
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|
38
38
|
|
|
39
39
|
import numpy as np
|
|
@@ -59,7 +59,9 @@ if _is_npu:
|
|
|
59
59
|
import torch_npu
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def get_tensor_size_bytes(t: torch.Tensor):
|
|
62
|
+
def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
|
|
63
|
+
if isinstance(t, list):
|
|
64
|
+
return sum(get_tensor_size_bytes(x) for x in t)
|
|
63
65
|
return np.prod(t.shape) * t.dtype.itemsize
|
|
64
66
|
|
|
65
67
|
|
|
@@ -116,10 +118,15 @@ class ReqToTokenPool:
|
|
|
116
118
|
class MambaPool:
|
|
117
119
|
@dataclass(frozen=True, kw_only=True)
|
|
118
120
|
class State:
|
|
119
|
-
conv: torch.Tensor
|
|
121
|
+
conv: Union[torch.Tensor, List[torch.Tensor]]
|
|
120
122
|
temporal: torch.Tensor
|
|
121
123
|
|
|
122
124
|
def at_layer_idx(self, layer: int):
|
|
125
|
+
if isinstance(self.conv, list):
|
|
126
|
+
return type(self)(
|
|
127
|
+
conv=[v[layer] for v in self.conv],
|
|
128
|
+
temporal=self.temporal[layer],
|
|
129
|
+
)
|
|
123
130
|
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
|
|
124
131
|
|
|
125
132
|
def mem_usage_bytes(self):
|
|
@@ -127,14 +134,14 @@ class MambaPool:
|
|
|
127
134
|
|
|
128
135
|
@dataclass(frozen=True, kw_only=True)
|
|
129
136
|
class SpeculativeState(State):
|
|
130
|
-
intermediate_ssm: torch.Tensor
|
|
137
|
+
intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
|
|
131
138
|
intermediate_conv_window: torch.Tensor
|
|
132
139
|
|
|
133
140
|
def __init__(
|
|
134
141
|
self,
|
|
135
142
|
*,
|
|
136
143
|
size: int,
|
|
137
|
-
cache_params: "Mamba2CacheParams",
|
|
144
|
+
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
|
|
138
145
|
device: str,
|
|
139
146
|
speculative_num_draft_tokens: Optional[int] = None,
|
|
140
147
|
):
|
|
@@ -157,18 +164,29 @@ class MambaPool:
|
|
|
157
164
|
else:
|
|
158
165
|
self.custom_mem_pool = None
|
|
159
166
|
|
|
167
|
+
self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
|
|
160
168
|
with (
|
|
161
169
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
|
162
170
|
if self.enable_custom_mem_pool
|
|
163
171
|
else nullcontext()
|
|
164
172
|
):
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
173
|
+
if self.is_kda_cache:
|
|
174
|
+
conv_state = [
|
|
175
|
+
torch.zeros(
|
|
176
|
+
size=(num_mamba_layers, size + 1) + conv_shape,
|
|
177
|
+
dtype=conv_dtype,
|
|
178
|
+
device=device,
|
|
179
|
+
)
|
|
180
|
+
for conv_shape in conv_state_shape
|
|
181
|
+
]
|
|
182
|
+
else:
|
|
183
|
+
# assume conv_state = (dim, state_len)
|
|
184
|
+
assert conv_state_shape[0] > conv_state_shape[1]
|
|
185
|
+
conv_state = torch.zeros(
|
|
186
|
+
size=(num_mamba_layers, size + 1) + conv_state_shape,
|
|
187
|
+
dtype=conv_dtype,
|
|
188
|
+
device=device,
|
|
189
|
+
)
|
|
172
190
|
temporal_state = torch.zeros(
|
|
173
191
|
size=(num_mamba_layers, size + 1) + temporal_state_shape,
|
|
174
192
|
dtype=ssm_dtype,
|
|
@@ -191,17 +209,34 @@ class MambaPool:
|
|
|
191
209
|
)
|
|
192
210
|
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
|
193
211
|
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
212
|
+
|
|
213
|
+
if self.is_kda_cache:
|
|
214
|
+
intermediate_conv_window_cache = [
|
|
215
|
+
torch.zeros(
|
|
216
|
+
size=(
|
|
217
|
+
num_mamba_layers,
|
|
218
|
+
size + 1,
|
|
219
|
+
speculative_num_draft_tokens,
|
|
220
|
+
conv_shape[0],
|
|
221
|
+
conv_shape[1],
|
|
222
|
+
),
|
|
223
|
+
dtype=conv_dtype,
|
|
224
|
+
device="cuda",
|
|
225
|
+
)
|
|
226
|
+
for conv_shape in conv_state_shape
|
|
227
|
+
]
|
|
228
|
+
else:
|
|
229
|
+
intermediate_conv_window_cache = torch.zeros(
|
|
230
|
+
size=(
|
|
231
|
+
num_mamba_layers,
|
|
232
|
+
size + 1,
|
|
233
|
+
speculative_num_draft_tokens,
|
|
234
|
+
conv_state_shape[0],
|
|
235
|
+
conv_state_shape[1],
|
|
236
|
+
),
|
|
237
|
+
dtype=conv_dtype,
|
|
238
|
+
device="cuda",
|
|
239
|
+
)
|
|
205
240
|
self.mamba_cache = self.SpeculativeState(
|
|
206
241
|
conv=conv_state,
|
|
207
242
|
temporal=temporal_state,
|
|
@@ -255,15 +290,25 @@ class MambaPool:
|
|
|
255
290
|
if free_index.numel() == 0:
|
|
256
291
|
return
|
|
257
292
|
self.free_slots = torch.cat((self.free_slots, free_index))
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
293
|
+
if self.is_kda_cache:
|
|
294
|
+
for i in range(len(self.mamba_cache.conv)):
|
|
295
|
+
self.mamba_cache.conv[i][:, free_index] = 0
|
|
296
|
+
else:
|
|
297
|
+
self.mamba_cache.conv[:, free_index] = 0
|
|
298
|
+
self.mamba_cache.temporal[:, free_index] = 0
|
|
261
299
|
|
|
262
300
|
def clear(self):
|
|
263
301
|
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
|
|
264
302
|
|
|
265
303
|
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
|
|
266
|
-
|
|
304
|
+
if self.is_kda_cache:
|
|
305
|
+
for i in range(len(self.mamba_cache.conv)):
|
|
306
|
+
self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
|
|
307
|
+
:, src_index
|
|
308
|
+
]
|
|
309
|
+
else:
|
|
310
|
+
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
|
|
311
|
+
|
|
267
312
|
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
|
|
268
313
|
:, src_index
|
|
269
314
|
]
|
|
@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
|
|
304
349
|
max_context_len: int,
|
|
305
350
|
device: str,
|
|
306
351
|
enable_memory_saver: bool,
|
|
307
|
-
cache_params: "Mamba2CacheParams",
|
|
352
|
+
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
|
|
308
353
|
speculative_num_draft_tokens: int = None,
|
|
309
354
|
):
|
|
310
355
|
super().__init__(
|
|
@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
|
|
323
368
|
def _init_mamba_pool(
|
|
324
369
|
self,
|
|
325
370
|
size: int,
|
|
326
|
-
cache_params: "Mamba2CacheParams",
|
|
371
|
+
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
|
|
327
372
|
device: str,
|
|
328
373
|
speculative_num_draft_tokens: int = None,
|
|
329
374
|
):
|
|
@@ -509,6 +554,7 @@ class MHATokenToKVPool(KVCache):
|
|
|
509
554
|
enable_memory_saver: bool,
|
|
510
555
|
start_layer: Optional[int] = None,
|
|
511
556
|
end_layer: Optional[int] = None,
|
|
557
|
+
enable_alt_stream: bool = True,
|
|
512
558
|
enable_kv_cache_copy: bool = False,
|
|
513
559
|
):
|
|
514
560
|
super().__init__(
|
|
@@ -527,7 +573,9 @@ class MHATokenToKVPool(KVCache):
|
|
|
527
573
|
self._create_buffers()
|
|
528
574
|
|
|
529
575
|
self.device_module = torch.get_device_module(self.device)
|
|
530
|
-
self.alt_stream =
|
|
576
|
+
self.alt_stream = (
|
|
577
|
+
self.device_module.Stream() if _is_cuda and enable_alt_stream else None
|
|
578
|
+
)
|
|
531
579
|
|
|
532
580
|
if enable_kv_cache_copy:
|
|
533
581
|
self._init_kv_copy_and_warmup()
|
|
@@ -809,6 +857,10 @@ class HybridLinearKVPool(KVCache):
|
|
|
809
857
|
enable_kvcache_transpose: bool,
|
|
810
858
|
device: str,
|
|
811
859
|
mamba_pool: MambaPool,
|
|
860
|
+
# TODO: refactor mla related args
|
|
861
|
+
use_mla: bool = False,
|
|
862
|
+
kv_lora_rank: int = None,
|
|
863
|
+
qk_rope_head_dim: int = None,
|
|
812
864
|
):
|
|
813
865
|
self.size = size
|
|
814
866
|
self.dtype = dtype
|
|
@@ -822,25 +874,42 @@ class HybridLinearKVPool(KVCache):
|
|
|
822
874
|
self.mamba_pool = mamba_pool
|
|
823
875
|
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
|
824
876
|
assert not enable_kvcache_transpose
|
|
825
|
-
|
|
826
|
-
|
|
877
|
+
self.use_mla = use_mla
|
|
878
|
+
if not use_mla:
|
|
879
|
+
if _is_npu:
|
|
880
|
+
TokenToKVPoolClass = AscendTokenToKVPool
|
|
881
|
+
else:
|
|
882
|
+
TokenToKVPoolClass = MHATokenToKVPool
|
|
883
|
+
self.full_kv_pool = TokenToKVPoolClass(
|
|
884
|
+
size=size,
|
|
885
|
+
page_size=self.page_size,
|
|
886
|
+
dtype=dtype,
|
|
887
|
+
head_num=head_num,
|
|
888
|
+
head_dim=head_dim,
|
|
889
|
+
layer_num=self.full_layer_nums,
|
|
890
|
+
device=device,
|
|
891
|
+
enable_memory_saver=False,
|
|
892
|
+
)
|
|
827
893
|
else:
|
|
828
|
-
TokenToKVPoolClass =
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
894
|
+
TokenToKVPoolClass = MLATokenToKVPool
|
|
895
|
+
self.full_kv_pool = TokenToKVPoolClass(
|
|
896
|
+
size=size,
|
|
897
|
+
page_size=self.page_size,
|
|
898
|
+
dtype=dtype,
|
|
899
|
+
layer_num=self.full_layer_nums,
|
|
900
|
+
device=device,
|
|
901
|
+
kv_lora_rank=kv_lora_rank,
|
|
902
|
+
qk_rope_head_dim=qk_rope_head_dim,
|
|
903
|
+
enable_memory_saver=False,
|
|
904
|
+
)
|
|
839
905
|
self.full_attention_layer_id_mapping = {
|
|
840
906
|
id: i for i, id in enumerate(full_attention_layer_ids)
|
|
841
907
|
}
|
|
842
|
-
|
|
843
|
-
|
|
908
|
+
if use_mla:
|
|
909
|
+
self.mem_usage = self.get_kv_size_bytes() / GB
|
|
910
|
+
else:
|
|
911
|
+
k_size, v_size = self.get_kv_size_bytes()
|
|
912
|
+
self.mem_usage = (k_size + v_size) / GB
|
|
844
913
|
|
|
845
914
|
def get_kv_size_bytes(self):
|
|
846
915
|
return self.full_kv_pool.get_kv_size_bytes()
|
|
@@ -876,6 +945,21 @@ class HybridLinearKVPool(KVCache):
|
|
|
876
945
|
layer_id = self._transfer_full_attention_id(layer_id)
|
|
877
946
|
return self.full_kv_pool.get_kv_buffer(layer_id)
|
|
878
947
|
|
|
948
|
+
@contextmanager
|
|
949
|
+
def _transfer_id_context(self, layer: RadixAttention):
|
|
950
|
+
|
|
951
|
+
@contextmanager
|
|
952
|
+
def _patch_layer_id(layer):
|
|
953
|
+
original_layer_id = layer.layer_id
|
|
954
|
+
layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
|
|
955
|
+
try:
|
|
956
|
+
yield
|
|
957
|
+
finally:
|
|
958
|
+
layer.layer_id = original_layer_id
|
|
959
|
+
|
|
960
|
+
with _patch_layer_id(layer):
|
|
961
|
+
yield
|
|
962
|
+
|
|
879
963
|
def set_kv_buffer(
|
|
880
964
|
self,
|
|
881
965
|
layer: RadixAttention,
|
|
@@ -886,19 +970,49 @@ class HybridLinearKVPool(KVCache):
|
|
|
886
970
|
v_scale: float = 1.0,
|
|
887
971
|
):
|
|
888
972
|
layer_id = self._transfer_full_attention_id(layer.layer_id)
|
|
889
|
-
self.
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
973
|
+
if not self.use_mla:
|
|
974
|
+
self.full_kv_pool.set_kv_buffer(
|
|
975
|
+
None,
|
|
976
|
+
loc,
|
|
977
|
+
cache_k,
|
|
978
|
+
cache_v,
|
|
979
|
+
k_scale,
|
|
980
|
+
v_scale,
|
|
981
|
+
layer_id_override=layer_id,
|
|
982
|
+
)
|
|
983
|
+
else:
|
|
984
|
+
with self._transfer_id_context(layer):
|
|
985
|
+
self.full_kv_pool.set_kv_buffer(
|
|
986
|
+
layer,
|
|
987
|
+
loc,
|
|
988
|
+
cache_k,
|
|
989
|
+
cache_v,
|
|
990
|
+
)
|
|
898
991
|
|
|
899
992
|
def get_v_head_dim(self):
|
|
900
993
|
return self.full_kv_pool.get_value_buffer(0).shape[-1]
|
|
901
994
|
|
|
995
|
+
def set_mla_kv_buffer(
|
|
996
|
+
self,
|
|
997
|
+
layer: RadixAttention,
|
|
998
|
+
loc: torch.Tensor,
|
|
999
|
+
cache_k_nope: torch.Tensor,
|
|
1000
|
+
cache_k_rope: torch.Tensor,
|
|
1001
|
+
):
|
|
1002
|
+
assert self.use_mla, "set_mla_kv_buffer called when use_mla is False"
|
|
1003
|
+
with self._transfer_id_context(layer):
|
|
1004
|
+
self.full_kv_pool.set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)
|
|
1005
|
+
|
|
1006
|
+
def get_mla_kv_buffer(
|
|
1007
|
+
self,
|
|
1008
|
+
layer: RadixAttention,
|
|
1009
|
+
loc: torch.Tensor,
|
|
1010
|
+
dst_dtype: Optional[torch.dtype] = None,
|
|
1011
|
+
):
|
|
1012
|
+
assert self.use_mla, "get_mla_kv_buffer called when use_mla is False"
|
|
1013
|
+
with self._transfer_id_context(layer):
|
|
1014
|
+
return self.full_kv_pool.get_mla_kv_buffer(layer, loc, dst_dtype)
|
|
1015
|
+
|
|
902
1016
|
|
|
903
1017
|
class SWAKVPool(KVCache):
|
|
904
1018
|
"""KV cache with separate pools for full and SWA attention layers."""
|
|
@@ -1137,10 +1251,10 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
|
|
1137
1251
|
torch_npu._npu_reshape_and_cache(
|
|
1138
1252
|
key=cache_k,
|
|
1139
1253
|
value=cache_v,
|
|
1140
|
-
key_cache=self.k_buffer[layer_id].view(
|
|
1254
|
+
key_cache=self.k_buffer[layer_id - self.start_layer].view(
|
|
1141
1255
|
-1, self.page_size, self.head_num, self.head_dim
|
|
1142
1256
|
),
|
|
1143
|
-
value_cache=self.v_buffer[layer_id].view(
|
|
1257
|
+
value_cache=self.v_buffer[layer_id - self.start_layer].view(
|
|
1144
1258
|
-1, self.page_size, self.head_num, self.head_dim
|
|
1145
1259
|
),
|
|
1146
1260
|
slot_indices=loc,
|
|
@@ -1213,6 +1327,65 @@ def set_mla_kv_buffer_triton(
|
|
|
1213
1327
|
)
|
|
1214
1328
|
|
|
1215
1329
|
|
|
1330
|
+
@triton.jit
|
|
1331
|
+
def get_mla_kv_buffer_kernel(
|
|
1332
|
+
kv_buffer_ptr,
|
|
1333
|
+
cache_k_nope_ptr,
|
|
1334
|
+
cache_k_rope_ptr,
|
|
1335
|
+
loc_ptr,
|
|
1336
|
+
buffer_stride: tl.constexpr,
|
|
1337
|
+
nope_stride: tl.constexpr,
|
|
1338
|
+
rope_stride: tl.constexpr,
|
|
1339
|
+
nope_dim: tl.constexpr,
|
|
1340
|
+
rope_dim: tl.constexpr,
|
|
1341
|
+
):
|
|
1342
|
+
pid_loc = tl.program_id(0)
|
|
1343
|
+
loc = tl.load(loc_ptr + pid_loc)
|
|
1344
|
+
loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
|
|
1345
|
+
|
|
1346
|
+
nope_offs = tl.arange(0, nope_dim)
|
|
1347
|
+
nope_src_ptr = loc_src_ptr + nope_offs
|
|
1348
|
+
nope_src = tl.load(nope_src_ptr)
|
|
1349
|
+
|
|
1350
|
+
tl.store(
|
|
1351
|
+
cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
|
|
1352
|
+
nope_src,
|
|
1353
|
+
)
|
|
1354
|
+
|
|
1355
|
+
rope_offs = tl.arange(0, rope_dim)
|
|
1356
|
+
rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
|
|
1357
|
+
rope_src = tl.load(rope_src_ptr)
|
|
1358
|
+
tl.store(
|
|
1359
|
+
cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
|
|
1360
|
+
rope_src,
|
|
1361
|
+
)
|
|
1362
|
+
|
|
1363
|
+
|
|
1364
|
+
def get_mla_kv_buffer_triton(
|
|
1365
|
+
kv_buffer: torch.Tensor,
|
|
1366
|
+
loc: torch.Tensor,
|
|
1367
|
+
cache_k_nope: torch.Tensor,
|
|
1368
|
+
cache_k_rope: torch.Tensor,
|
|
1369
|
+
):
|
|
1370
|
+
# The source data type will be implicitly converted to the target data type.
|
|
1371
|
+
nope_dim = cache_k_nope.shape[-1] # 512
|
|
1372
|
+
rope_dim = cache_k_rope.shape[-1] # 64
|
|
1373
|
+
n_loc = loc.numel()
|
|
1374
|
+
grid = (n_loc,)
|
|
1375
|
+
|
|
1376
|
+
get_mla_kv_buffer_kernel[grid](
|
|
1377
|
+
kv_buffer,
|
|
1378
|
+
cache_k_nope,
|
|
1379
|
+
cache_k_rope,
|
|
1380
|
+
loc,
|
|
1381
|
+
kv_buffer.stride(0),
|
|
1382
|
+
cache_k_nope.stride(0),
|
|
1383
|
+
cache_k_rope.stride(0),
|
|
1384
|
+
nope_dim,
|
|
1385
|
+
rope_dim,
|
|
1386
|
+
)
|
|
1387
|
+
|
|
1388
|
+
|
|
1216
1389
|
class MLATokenToKVPool(KVCache):
|
|
1217
1390
|
def __init__(
|
|
1218
1391
|
self,
|
|
@@ -1363,6 +1536,29 @@ class MLATokenToKVPool(KVCache):
|
|
|
1363
1536
|
cache_k_rope,
|
|
1364
1537
|
)
|
|
1365
1538
|
|
|
1539
|
+
def get_mla_kv_buffer(
|
|
1540
|
+
self,
|
|
1541
|
+
layer: RadixAttention,
|
|
1542
|
+
loc: torch.Tensor,
|
|
1543
|
+
dst_dtype: Optional[torch.dtype] = None,
|
|
1544
|
+
):
|
|
1545
|
+
# get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype.
|
|
1546
|
+
layer_id = layer.layer_id
|
|
1547
|
+
kv_buffer = self.get_key_buffer(layer_id)
|
|
1548
|
+
dst_dtype = dst_dtype or self.dtype
|
|
1549
|
+
cache_k_nope = torch.empty(
|
|
1550
|
+
(loc.shape[0], 1, self.kv_lora_rank),
|
|
1551
|
+
dtype=dst_dtype,
|
|
1552
|
+
device=kv_buffer.device,
|
|
1553
|
+
)
|
|
1554
|
+
cache_k_rope = torch.empty(
|
|
1555
|
+
(loc.shape[0], 1, self.qk_rope_head_dim),
|
|
1556
|
+
dtype=dst_dtype,
|
|
1557
|
+
device=kv_buffer.device,
|
|
1558
|
+
)
|
|
1559
|
+
get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
|
|
1560
|
+
return cache_k_nope, cache_k_rope
|
|
1561
|
+
|
|
1366
1562
|
def get_cpu_copy(self, indices):
|
|
1367
1563
|
torch.cuda.synchronize()
|
|
1368
1564
|
kv_cache_cpu = []
|
|
@@ -238,12 +238,16 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
|
238
238
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
239
239
|
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
|
|
240
240
|
self.layout_dim = self.token_stride_size * self.layer_num
|
|
241
|
-
|
|
241
|
+
buffer = torch.empty(
|
|
242
242
|
dims,
|
|
243
243
|
dtype=self.dtype,
|
|
244
244
|
device=self.device,
|
|
245
|
-
pin_memory=self.pin_memory,
|
|
246
245
|
)
|
|
246
|
+
if self.pin_memory:
|
|
247
|
+
torch.cuda.cudart().cudaHostRegister(
|
|
248
|
+
buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
|
|
249
|
+
)
|
|
250
|
+
return buffer
|
|
247
251
|
|
|
248
252
|
@property
|
|
249
253
|
def k_buffer(self):
|
|
@@ -551,13 +555,16 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
|
551
555
|
self.kv_lora_rank + self.qk_rope_head_dim
|
|
552
556
|
) * self.dtype.itemsize
|
|
553
557
|
self.layout_dim = self.token_stride_size * self.layer_num
|
|
554
|
-
|
|
555
|
-
return torch.empty(
|
|
558
|
+
buffer = torch.empty(
|
|
556
559
|
dims,
|
|
557
560
|
dtype=self.dtype,
|
|
558
561
|
device=self.device,
|
|
559
|
-
pin_memory=self.pin_memory,
|
|
560
562
|
)
|
|
563
|
+
if self.pin_memory:
|
|
564
|
+
torch.cuda.cudart().cudaHostRegister(
|
|
565
|
+
buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
|
|
566
|
+
)
|
|
567
|
+
return buffer
|
|
561
568
|
|
|
562
569
|
def load_to_device_per_layer(
|
|
563
570
|
self, device_pool, host_indices, device_indices, layer_id, io_backend
|
|
@@ -533,6 +533,10 @@ class RadixCache(BasePrefixCache):
|
|
|
533
533
|
self.protected_size_ -= len(node.key)
|
|
534
534
|
delta += len(node.key)
|
|
535
535
|
node.lock_ref -= 1
|
|
536
|
+
if node.parent is None:
|
|
537
|
+
assert (
|
|
538
|
+
node is self.root_node
|
|
539
|
+
), f"This request holds the node from another tree"
|
|
536
540
|
node = node.parent
|
|
537
541
|
return delta
|
|
538
542
|
|
|
@@ -3,8 +3,9 @@ import atexit
|
|
|
3
3
|
import json
|
|
4
4
|
import logging
|
|
5
5
|
import threading
|
|
6
|
+
from collections import OrderedDict
|
|
6
7
|
from pathlib import Path
|
|
7
|
-
from typing import Dict, List, Optional,
|
|
8
|
+
from typing import Dict, List, Optional, Tuple
|
|
8
9
|
|
|
9
10
|
import orjson
|
|
10
11
|
import requests
|
|
@@ -136,7 +137,7 @@ class GlobalMetadataState:
|
|
|
136
137
|
num_pages = data["num_pages"]
|
|
137
138
|
rank_meta = RankMetadata(num_pages)
|
|
138
139
|
rank_meta.free_pages = data["free_pages"]
|
|
139
|
-
rank_meta.key_to_index =
|
|
140
|
+
rank_meta.key_to_index = OrderedDict(data["key_to_index"])
|
|
140
141
|
self.ranks[rank_id] = rank_meta
|
|
141
142
|
logging.info(
|
|
142
143
|
f"Successfully loaded metadata for {len(self.ranks)} ranks."
|
|
@@ -104,7 +104,7 @@ class MooncakeStoreConfig:
|
|
|
104
104
|
device_name=os.getenv("MOONCAKE_DEVICE", ""),
|
|
105
105
|
master_server_address=os.getenv("MOONCAKE_MASTER"),
|
|
106
106
|
master_metrics_port=int(
|
|
107
|
-
os.getenv("MOONCAKE_MASTER_METRICS_PORT",
|
|
107
|
+
os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_MASTER_METRICS_PORT)
|
|
108
108
|
),
|
|
109
109
|
check_server=bool(os.getenv("MOONCAKE_CHECK_SERVER", DEFAULT_CHECK_SERVER)),
|
|
110
110
|
)
|
sglang/srt/metrics/collector.py
CHANGED
|
@@ -811,6 +811,34 @@ class TokenizerMetricsCollector:
|
|
|
811
811
|
buckets=bucket_e2e_request_latency,
|
|
812
812
|
)
|
|
813
813
|
|
|
814
|
+
# Retraction count histogram
|
|
815
|
+
self.num_retractions = Histogram(
|
|
816
|
+
name="sglang:num_retractions",
|
|
817
|
+
documentation="Histogram of retraction counts per request.",
|
|
818
|
+
labelnames=labels.keys(),
|
|
819
|
+
buckets=[
|
|
820
|
+
0,
|
|
821
|
+
1,
|
|
822
|
+
2,
|
|
823
|
+
3,
|
|
824
|
+
4,
|
|
825
|
+
5,
|
|
826
|
+
6,
|
|
827
|
+
7,
|
|
828
|
+
8,
|
|
829
|
+
9,
|
|
830
|
+
10,
|
|
831
|
+
15,
|
|
832
|
+
20,
|
|
833
|
+
25,
|
|
834
|
+
30,
|
|
835
|
+
40,
|
|
836
|
+
50,
|
|
837
|
+
75,
|
|
838
|
+
100,
|
|
839
|
+
],
|
|
840
|
+
)
|
|
841
|
+
|
|
814
842
|
def observe_one_finished_request(
|
|
815
843
|
self,
|
|
816
844
|
labels: Dict[str, str],
|
|
@@ -819,6 +847,7 @@ class TokenizerMetricsCollector:
|
|
|
819
847
|
cached_tokens: int,
|
|
820
848
|
e2e_latency: float,
|
|
821
849
|
has_grammar: bool,
|
|
850
|
+
retraction_count: int,
|
|
822
851
|
):
|
|
823
852
|
self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
|
|
824
853
|
self.generation_tokens_total.labels(**labels).inc(generation_tokens)
|
|
@@ -833,6 +862,7 @@ class TokenizerMetricsCollector:
|
|
|
833
862
|
self.generation_tokens_histogram.labels(**labels).observe(
|
|
834
863
|
float(generation_tokens)
|
|
835
864
|
)
|
|
865
|
+
self.num_retractions.labels(**labels).observe(retraction_count)
|
|
836
866
|
|
|
837
867
|
def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
|
|
838
868
|
self.histogram_time_to_first_token.labels(**labels).observe(value)
|
|
@@ -840,13 +870,13 @@ class TokenizerMetricsCollector:
|
|
|
840
870
|
def check_time_to_first_token_straggler(self, value: float) -> bool:
|
|
841
871
|
his = self.histogram_time_to_first_token.labels(**self.labels)
|
|
842
872
|
total_observations = sum(bucket._value for bucket in his._buckets)
|
|
843
|
-
if total_observations <
|
|
873
|
+
if total_observations < 100:
|
|
844
874
|
return False
|
|
845
|
-
|
|
875
|
+
p99_threshold = total_observations * 0.99
|
|
846
876
|
cumulative_count = 0
|
|
847
877
|
for i, bucket in enumerate(his._buckets):
|
|
848
878
|
cumulative_count += bucket._value
|
|
849
|
-
if cumulative_count >
|
|
879
|
+
if cumulative_count > p99_threshold:
|
|
850
880
|
return value >= his._upper_bounds[i]
|
|
851
881
|
return False
|
|
852
882
|
|
|
@@ -969,3 +999,16 @@ class StorageMetricsCollector:
|
|
|
969
999
|
self._log_histogram(self.histogram_prefetch_bandwidth, v)
|
|
970
1000
|
for v in storage_metrics.backup_bandwidth:
|
|
971
1001
|
self._log_histogram(self.histogram_backup_bandwidth, v)
|
|
1002
|
+
|
|
1003
|
+
|
|
1004
|
+
class ExpertDispatchCollector:
|
|
1005
|
+
def __init__(self, ep_size: int) -> None:
|
|
1006
|
+
from prometheus_client import Histogram
|
|
1007
|
+
|
|
1008
|
+
ep_size_buckets = [i for i in range(ep_size)]
|
|
1009
|
+
self.eplb_gpu_physical_count = Histogram(
|
|
1010
|
+
name="sglang:eplb_gpu_physical_count",
|
|
1011
|
+
documentation="The selected count of physical experts on each layer and GPU rank.",
|
|
1012
|
+
labelnames={"layer"},
|
|
1013
|
+
buckets=ep_size_buckets,
|
|
1014
|
+
)
|
|
@@ -21,12 +21,14 @@ import inspect
|
|
|
21
21
|
import logging
|
|
22
22
|
import os
|
|
23
23
|
from contextlib import contextmanager
|
|
24
|
+
from functools import partial
|
|
24
25
|
from typing import TYPE_CHECKING, Callable, Optional, Union
|
|
25
26
|
|
|
26
27
|
import torch
|
|
27
28
|
import tqdm
|
|
28
29
|
from torch.profiler import ProfilerActivity, profile
|
|
29
30
|
|
|
31
|
+
from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
|
|
30
32
|
from sglang.srt.custom_op import CustomOp
|
|
31
33
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
|
32
34
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
|
@@ -64,6 +66,7 @@ from sglang.srt.utils import (
|
|
|
64
66
|
require_mlp_tp_gather,
|
|
65
67
|
)
|
|
66
68
|
from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
|
|
69
|
+
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
67
70
|
|
|
68
71
|
try:
|
|
69
72
|
from kt_kernel import AMXMoEWrapper
|
|
@@ -320,11 +323,11 @@ class CudaGraphRunner:
|
|
|
320
323
|
self.pp_proxy_tensors = {
|
|
321
324
|
"hidden_states": torch.zeros(
|
|
322
325
|
(self.max_bs, self.model_runner.model_config.hidden_size),
|
|
323
|
-
dtype=
|
|
326
|
+
dtype=self.model_runner.model_config.dtype,
|
|
324
327
|
),
|
|
325
328
|
"residual": torch.zeros(
|
|
326
329
|
(self.max_bs, self.model_runner.model_config.hidden_size),
|
|
327
|
-
dtype=
|
|
330
|
+
dtype=self.model_runner.model_config.dtype,
|
|
328
331
|
),
|
|
329
332
|
}
|
|
330
333
|
|
|
@@ -518,7 +521,16 @@ class CudaGraphRunner:
|
|
|
518
521
|
logger.info(log_message)
|
|
519
522
|
|
|
520
523
|
def _capture_graph(self, graph, pool, stream, run_once_fn):
|
|
521
|
-
|
|
524
|
+
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
|
525
|
+
enable=self.model_runner.server_args.enable_memory_saver
|
|
526
|
+
and get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH")
|
|
527
|
+
)
|
|
528
|
+
graph_fn = (
|
|
529
|
+
partial(memory_saver_adapter.cuda_graph, tag=GPU_MEMORY_TYPE_CUDA_GRAPH)
|
|
530
|
+
if memory_saver_adapter.enabled
|
|
531
|
+
else self.device_module.graph
|
|
532
|
+
)
|
|
533
|
+
with graph_fn(cuda_graph=graph, pool=pool, stream=stream):
|
|
522
534
|
out = run_once_fn()
|
|
523
535
|
return out
|
|
524
536
|
|