sglang-0.4.7-py3-none-any.whl → sglang-0.4.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -17,12 +17,14 @@ from __future__ import annotations
 
 import bisect
 import inspect
+import logging
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 import tqdm
+from torch.profiler import ProfilerActivity, profile
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -40,11 +42,18 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
+    empty_context,
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
@@ -147,10 +156,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )
 
     gpu_mem = get_device_memory_capacity()
-    if gpu_mem is not None
+    if gpu_mem is not None:
+        if gpu_mem > 90 * 1024:  # H200, H20
+            capture_bs += list(range(160, 257, 8))
+        if gpu_mem > 160 * 1000:  # B200, MI300
+            capture_bs += list(range(256, 513, 16))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
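The hunk above widens the set of captured CUDA-graph batch sizes on large-memory GPUs. Below is a minimal standalone sketch of that selection logic; the threshold values and ranges mirror the added lines, while the function name and the assumption that `gpu_mem` is device memory in MB (as the thresholds suggest) are illustrative, not sglang API.

```python
# Illustrative sketch (not part of the diff) of the capture-batch-size extension above.
# `gpu_mem` is assumed to be the device memory in MB, matching the thresholds in the hunk.
def extend_capture_bs(capture_bs, gpu_mem):
    if gpu_mem is not None:
        if gpu_mem > 90 * 1024:  # H200, H20
            capture_bs += list(range(160, 257, 8))
        if gpu_mem > 160 * 1000:  # B200, MI300
            capture_bs += list(range(256, 513, 16))
    return sorted(set(capture_bs))


print(extend_capture_bs(list(range(8, 161, 8)), gpu_mem=141 * 1024)[-3:])  # [240, 248, 256]
```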
@@ -201,12 +211,17 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.
-        self.
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
@@ -226,16 +241,20 @@ class CudaGraphRunner:
             self.model_runner.server_args.speculative_num_draft_tokens
         )
 
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
@@ -286,18 +305,30 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None
 
-        if self.
-        # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.
+                    self.max_num_token,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )
 
         # Capture
         try:
@@ -309,20 +340,23 @@ class CudaGraphRunner:
         )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.
-            if self.disable_padding
-            else total_global_tokens <= self.max_bs
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
         else:
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
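For context, the batch-size check introduced in can_run above reduces to a small predicate: with padding enabled, a batch can be padded up to the largest captured size, while --disable-cuda-graph-padding restricts replay to exactly-captured sizes. The helper below is an illustrative simplification, not sglang code.

```python
# Simplified stand-in for the is_bs_supported logic in can_run above (illustrative only).
def is_bs_supported(cuda_graph_bs, captured_sizes, disable_padding):
    if disable_padding:
        return cuda_graph_bs in captured_sizes  # must have an exact capture
    return cuda_graph_bs <= max(captured_sizes)  # can be padded up to the largest capture


captured = {1, 2, 4, 8, 16, 32}
print(is_bs_supported(12, captured, disable_padding=False))  # True: padded up to 16
print(is_bs_supported(12, captured, disable_padding=True))   # False: 12 was never captured
```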
@@ -333,50 +367,91 @@ class CudaGraphRunner:
             else True
         )
 
+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
         is_tbo_supported = (
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )
 
-        return
+        return (
+            is_bs_supported
+            and is_encoder_lens_supported
+            and is_tbo_supported
+            and capture_hidden_mode_matches
+        )
 
-    def capture(self):
+    def capture(self) -> None:
+        profile_context = empty_context()
+        if self.enable_profile_cuda_graph:
+            profile_context = profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=True,
             )
-        # Reverse the order to enable better memory sharing across cuda graphs.
-        capture_range = (
-            tqdm.tqdm(list(reversed(self.capture_bs)))
-            if get_tensor_model_parallel_rank() == 0
-            else reversed(self.capture_bs)
-        )
-        for bs in capture_range:
-            if get_tensor_model_parallel_rank() == 0:
-                avail_mem = get_available_gpu_memory(
-                    self.model_runner.device,
-                    self.model_runner.gpu_id,
-                    empty_cache=False,
-                )
-                capture_range.set_description(
-                    f"Capturing batches ({avail_mem=:.2f} GB)"
-                )
 
-        self.
+        with graph_capture() as graph_capture_context:
+            with profile_context as prof:
+                self.stream = graph_capture_context.stream
+                avail_mem = get_available_gpu_memory(
+                    self.model_runner.device,
+                    self.model_runner.gpu_id,
+                    empty_cache=False,
+                )
+                # Reverse the order to enable better memory sharing across cuda graphs.
+                capture_range = (
+                    tqdm.tqdm(list(reversed(self.capture_bs)))
+                    if get_tensor_model_parallel_rank() == 0
+                    else reversed(self.capture_bs)
+                )
+                for i, bs in enumerate(capture_range):
+                    if get_tensor_model_parallel_rank() == 0:
+                        avail_mem = get_available_gpu_memory(
+                            self.model_runner.device,
+                            self.model_runner.gpu_id,
+                            empty_cache=False,
+                        )
+                        capture_range.set_description(
+                            f"Capturing batches ({avail_mem=:.2f} GB)"
+                        )
+
+                    with patch_model(
+                        self.model_runner.model,
+                        bs in self.compile_bs,
+                        num_tokens=bs * self.num_tokens_per_bs,
+                        tp_group=self.model_runner.tp_group,
+                    ) as forward:
+                        (
+                            graph,
+                            output_buffers,
+                        ) = self.capture_one_batch_size(bs, forward)
+                        self.graphs[bs] = graph
+                        self.output_buffers[bs] = output_buffers
+
+                    # Save gemlite cache after each capture
+                    save_gemlite_cache()
+
+        if self.enable_profile_cuda_graph:
+            log_message = (
+                "Sorted by CUDA Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cuda_time_total", row_limit=10
+                )
+                + "\n\nSorted by CPU Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cpu_time_total", row_limit=10
+                )
+            )
+            logger.info(log_message)
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
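The rewritten capture() above optionally wraps the whole capture loop in torch.profiler. The sketch below shows the same enter-a-profiler-or-a-no-op pattern in isolation; it substitutes contextlib.nullcontext for sglang's empty_context helper, profiles CPU activity only so it runs on any machine, and uses a dummy workload, so everything other than the torch.profiler API itself is illustrative.

```python
# Illustrative pattern: profile a code region only when profiling is requested.
from contextlib import nullcontext

import torch
from torch.profiler import ProfilerActivity, profile


def run_with_optional_profile(enable_profile: bool):
    ctx = (
        profile(activities=[ProfilerActivity.CPU], record_shapes=True)
        if enable_profile
        else nullcontext()  # stand-in for sglang's empty_context()
    )
    with ctx as prof:
        torch.randn(256, 256) @ torch.randn(256, 256)  # dummy workload in place of graph capture
    if enable_profile:
        print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=5))


run_with_optional_profile(enable_profile=True)
```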
@@ -402,11 +477,11 @@ class CudaGraphRunner:
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )
 
-        if self.
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
-                        num_tokens // self.dp_size + (i <
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                         for i in range(self.dp_size)
                     ],
                     dtype=torch.int32,
@@ -415,6 +490,16 @@ class CudaGraphRunner:
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -443,7 +528,7 @@ class CudaGraphRunner:
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
@@ -509,21 +594,34 @@ class CudaGraphRunner:
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
+
+        # If the required capture_hidden_mode changes, we need to recapture the graph
+
+        # These are the different factors that can influence the capture_hidden_mode
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required
+        # (If we have FULL, we can emulate LAST or NULL)
+        # (If we have LAST, we can emulate NULL)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
             self.capture()
 
     def replay_prepare(
@@ -537,15 +635,18 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs
@@ -557,7 +658,7 @@ class CudaGraphRunner:
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if pp_proxy_tensors:
@@ -569,27 +670,28 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
-            self.encoder_lens,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
+            self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
 
         # Store fields
@@ -637,11 +739,7 @@ class CudaGraphRunner:
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                 positions=None,
                 retrive_index=None,
                 retrive_next_token=None,
sglang/srt/model_executor/forward_batch_info.py

@@ -31,6 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
+from functools import total_ordering
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DECODE or self == ForwardMode.IDLE
 
 
+@total_ordering
 class CaptureHiddenMode(IntEnum):
     # Do not capture anything.
-    NULL = auto()
-    # Capture hidden states of all tokens.
-    FULL = auto()
+    NULL = 0
     # Capture a hidden state of the last token.
-    LAST = auto()
+    LAST = 1
+    # Capture hidden states of all tokens.
+    FULL = 2
 
     def need_capture(self):
         return self != CaptureHiddenMode.NULL
@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
     def is_last(self):
         return self == CaptureHiddenMode.LAST
 
+    def __lt__(self, other):
+        return self.value < other.value
+
 
 @dataclass
 class ForwardBatch:
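The two CaptureHiddenMode hunks above give the enum fixed, increasing values and make it orderable, which is what lets the runner combine several requested modes with max() (FULL can emulate LAST, LAST can emulate NULL). A self-contained illustration of that mechanism, mirroring the diff:

```python
from enum import IntEnum
from functools import total_ordering


@total_ordering
class CaptureHiddenMode(IntEnum):
    NULL = 0  # capture nothing
    LAST = 1  # capture only the last token's hidden state
    FULL = 2  # capture hidden states for all tokens

    def __lt__(self, other):
        return self.value < other.value


# Pick the most demanding mode among several requesters, as recapture_if_needed does.
required = max(CaptureHiddenMode.NULL, CaptureHiddenMode.LAST, CaptureHiddenMode.FULL)
print(required)  # CaptureHiddenMode.FULL
```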
@@ -219,6 +224,9 @@ class ForwardBatch:
     # For input embeddings
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
 
@@ -295,6 +303,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )
         device = model_runner.device
@@ -311,17 +320,30 @@ class ForwardBatch:
 
         # For DP attention
        if batch.global_num_tokens is not None:
+
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            ret.global_num_tokens_for_logprob_cpu =
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            sum_len = sum(
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
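The DP-attention hunk above scales each rank's token count by the number of speculative draft tokens before sizing the gathered buffer. A small illustration of that accounting (the function name is made up for this example):

```python
# Illustrative only: scale per-rank token counts by the speculative draft-token factor.
def scale_global_num_tokens(global_num_tokens, spec_num_draft_tokens=None):
    factor = spec_num_draft_tokens if spec_num_draft_tokens is not None else 1
    scaled = [x * factor for x in global_num_tokens]
    return scaled, sum(scaled)  # sum(scaled) is the gathered_buffer length


print(scale_global_num_tokens([3, 5, 2]))                           # ([3, 5, 2], 10)
print(scale_global_num_tokens([3, 5, 2], spec_num_draft_tokens=4))  # ([12, 20, 8], 40)
```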
@@ -351,8 +373,8 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
+            ret.extend_num_tokens = batch.extend_num_tokens
             if support_triton(model_runner.server_args.attention_backend):
-                ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,