sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import inspect
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
+    PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
+    rank0_log,
 )
 
 if TYPE_CHECKING:
@@ -135,7 +138,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
 
     gpu_mem = get_device_memory_capacity()
     # Batch size of each rank will not become so large when DP is on
-    if gpu_mem is not None and gpu_mem >
+    if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -188,10 +191,11 @@ class CudaGraphRunner:
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -220,6 +224,9 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()
 
+        if self.model_runner.server_args.lora_paths is not None:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -231,6 +238,19 @@ class CudaGraphRunner:
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
+            # pipeline parallelism
+            if self.pp_size > 1:
+                self.pp_proxy_tensors = {
+                    "hidden_states": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                    "residual": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                }
+
             # Speculative_inference
             if (
                 model_runner.spec_algorithm.is_eagle3()
@@ -381,6 +401,12 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -403,6 +429,13 @@ class CudaGraphRunner:
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+        if self.model_runner.server_args.lora_paths is not None:
+            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
+            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
+            # values if lora is enabled.
+            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        else:
+            lora_paths = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -424,8 +457,12 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
+            lora_paths=lora_paths,
         )
 
+        if lora_paths is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -442,8 +479,20 @@ class CudaGraphRunner:
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
 
-
-
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -476,7 +525,11 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay_prepare(
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -505,6 +558,11 @@ class CudaGraphRunner:
         self.seq_lens_cpu.fill_(1)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
+        if pp_proxy_tensors:
+            for key in self.pp_proxy_tensors.keys():
+                dim = pp_proxy_tensors[key].shape[0]
+                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
@@ -533,10 +591,13 @@ class CudaGraphRunner:
         self.bs = bs
 
     def replay(
-        self,
-
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
         if not skip_attn_backend_init:
-            self.replay_prepare(forward_batch)
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
             # In speculative decoding, these two fields are still needed.
             self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
@@ -544,17 +605,19 @@ class CudaGraphRunner:
 
         # Replay
         self.graphs[self.bs].replay()
-
-
-
-
-
-
-
-
-
-
-
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
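The capture and replay changes above keep the pipeline-parallel inputs in static buffers, as CUDA graphs require: `replay_prepare` copies the incoming proxy tensors into a prefix of buffers preallocated at the maximum batch size, and `replay` slices the outputs back to the live size. Below is a minimal, CPU-only sketch of that fixed-buffer pattern; the buffer names match the diff, but the sizes and the `stage_inputs` helper are illustrative, not sglang APIs.

```python
import torch

# Static buffers sized for the largest captured batch, mirroring
# CudaGraphRunner.pp_proxy_tensors ("hidden_states" / "residual").
max_bs, hidden_size = 160, 4096
buffers = {
    "hidden_states": torch.zeros(max_bs, hidden_size, dtype=torch.bfloat16),
    "residual": torch.zeros(max_bs, hidden_size, dtype=torch.bfloat16),
}

def stage_inputs(live: dict) -> None:
    # Copy each incoming tensor into the front of its static buffer,
    # the same prefix-copy replay_prepare() performs before graph replay.
    for key, value in live.items():
        buffers[key][: value.shape[0]].copy_(value)

bs = 12  # actual batch for this step
live = {k: torch.randn(bs, hidden_size, dtype=torch.bfloat16) for k in buffers}
stage_inputs(live)
# After replay, only the first `bs` rows of each output buffer are meaningful,
# which is why replay() slices its outputs before returning them.
trimmed = {k: v[:bs] for k, v in buffers.items()}
```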
sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
 import triton
@@ -585,6 +585,36 @@ class ForwardBatch:
             self.prepare_chunked_kv_indices(device)
 
 
+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
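For reference, a small usage sketch of the `PPProxyTensors` container added above: string keys return individual tensors, while slices return a new `PPProxyTensors` over every tensor, which is how per-stage activations get trimmed back to the live token count. It assumes sglang 0.4.6.post2 is installed; the shapes below are illustrative.

```python
import torch
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

# Proxy tensors carried between pipeline stages: a dict of named tensors.
proxy = PPProxyTensors(
    {
        "hidden_states": torch.zeros(8, 4096, dtype=torch.bfloat16),
        "residual": torch.zeros(8, 4096, dtype=torch.bfloat16),
    }
)

hidden = proxy["hidden_states"]   # __getitem__ with a str -> the tensor itself
first_four = proxy[:4]            # __getitem__ with a slice -> PPProxyTensors over (4, 4096) views
proxy["residual"] = torch.zeros(8, 4096, dtype=torch.bfloat16)  # __setitem__ replaces an entry
print(len(proxy), first_four)     # __len__ counts entries; __repr__ shows the dict
```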
sglang/srt/model_executor/model_runner.py

@@ -13,8 +13,10 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
+import collections
 import datetime
 import gc
+import inspect
 import json
 import logging
 import os
@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
     DefaultModelLoader,
@@ -110,6 +112,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        pp_rank: int,
+        pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
@@ -123,6 +127,8 @@ class ModelRunner:
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -148,24 +154,24 @@ class ModelRunner:
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
-                "
-                "
-                "
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "deepep_mode": server_args.deepep_mode,
+                "device": server_args.device,
+                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+                "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
-                "deepep_mode": server_args.deepep_mode,
-                "device": server_args.device,
-                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
-                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
-                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
-                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+                "torchao_config": server_args.torchao_config,
+                "sampling_backend": server_args.sampling_backend,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -183,6 +189,11 @@ class ModelRunner:
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
 
+        # temporary cached values
+        self.support_pp = (
+            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
+        )
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -193,6 +204,12 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(
+            self.model, "end_layer", self.model_config.num_hidden_layers
+        )
+        self.num_effective_layers = self.end_layer - self.start_layer
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -271,6 +288,7 @@ class ModelRunner:
                 "fa3",
                 "triton",
                 "flashmla",
+                "cutlass_mla",
             ]:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -358,18 +376,22 @@ class ModelRunner:
         # Only initialize the distributed environment on the target model worker.
         init_distributed_environment(
             backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
+            world_size=self.tp_size * self.pp_size,
+            rank=self.tp_size * self.pp_rank + self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=dist_init_method,
             timeout=self.server_args.dist_timeout,
         )
-        initialize_model_parallel(
+        initialize_model_parallel(
+            tensor_model_parallel_size=self.tp_size,
+            pipeline_model_parallel_size=self.pp_size,
+        )
         initialize_dp_attention(
             enable_dp_attention=self.server_args.enable_dp_attention,
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            pp_size=self.server_args.pp_size,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
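The distributed-init change above widens the process group from pure tensor parallelism to a tp_size × pp_size grid, with the global rank computed as `tp_size * pp_rank + tp_rank`. A small worked example of the resulting rank layout, with illustrative sizes:

```python
# With tp_size=4 and pp_size=2, world_size becomes 4 * 2 = 8 processes.
tp_size, pp_size = 4, 2

for pp_rank in range(pp_size):
    for tp_rank in range(tp_size):
        global_rank = tp_size * pp_rank + tp_rank
        print(f"pp_rank={pp_rank} tp_rank={tp_rank} -> global rank {global_rank}")
# Ranks 0-3 form pipeline stage 0; ranks 4-7 form pipeline stage 1.
```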
@@ -691,16 +713,23 @@ class ModelRunner:
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if self.use_mla_backend:
+            num_layers = (
+                self.model_config.num_hidden_layers
+                if not self.is_draft_worker
+                else self.model_config.hf_config.num_nextn_predict_layers
+            )
+            # FIXME: pipeline parallelism is not compatible with mla backend
+            assert self.pp_size == 1
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
-                *
+                * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                * self.
+                * self.num_effective_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
@@ -808,9 +837,15 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),  # PP is not compatible with mla backend
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -819,10 +854,12 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -831,9 +868,11 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
 
         if self.token_to_kv_pool_allocator is None:
@@ -917,8 +956,10 @@ class ModelRunner:
 
             self.attn_backend = FlashMLABackend(self)
         elif self.server_args.attention_backend == "fa3":
-            assert
-
+            assert (
+                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
+            ) or torch.cuda.get_device_capability()[0] == 9, (
+                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
                 "Please use `--attention-backend flashinfer`."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
@@ -926,6 +967,12 @@ class ModelRunner:
             )
 
             self.attn_backend = FlashAttentionBackend(self)
+        elif self.server_args.attention_backend == "cutlass_mla":
+            from sglang.srt.layers.attention.cutlass_mla_backend import (
+                CutlassMLABackend,
+            )
+
+            self.attn_backend = CutlassMLABackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -938,7 +985,7 @@ class ModelRunner:
         with open(self.server_args.ds_channel_config_path, "r") as f:
             channel_config = json.load(f)
 
-        for i in range(self.
+        for i in range(self.start_layer, self.end_layer):
             key = "model.layers." + str(i) + ".self_attn" + selected_channel
             self.sorted_channels.append(
                 torch.tensor(channel_config[key])[
@@ -968,7 +1015,7 @@ class ModelRunner:
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
-            f"
+            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     def apply_torch_tp(self):
@@ -978,64 +1025,82 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
-    def forward_decode(
+    def forward_decode(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
         self.attn_backend.init_forward_metadata(forward_batch)
+        # FIXME: add pp_proxy_tensors arg to all models
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
        )
 
     def forward_extend(
-        self,
-
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.attn_backend.init_forward_metadata(forward_batch)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Only embedding models have get_embedding parameter
-            return self.model.forward(
-                forward_batch.input_ids,
-                forward_batch.positions,
-                forward_batch,
-                get_embedding=True,
-            )
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+        if forward_batch.input_embeds is not None:
+            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
+        if not self.is_generation:
+            kwargs["get_embedding"] = True
+        return self.model.forward(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
+        )
 
-    def forward_idle(
+    def forward_idle(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids,
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )
 
     def forward(
-        self,
-
-
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
-        )
+        )
+        if can_run_cuda_graph:
             return self.cuda_graph_runner.replay(
-                forward_batch,
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
             )
 
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch)
+            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(
-                forward_batch,
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch)
+            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
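The forward-path changes above only pass `pp_proxy_tensors` to models whose `forward()` signature declares it, using the `support_pp` flag cached once at startup via `inspect.signature`. A minimal sketch of that probe-and-dispatch pattern; `DummyModel` and its argument values are stand-ins for illustration, not sglang classes.

```python
import inspect

class DummyModel:
    # A stand-in model whose forward() happens to accept pp_proxy_tensors.
    def forward(self, input_ids, positions, forward_batch, pp_proxy_tensors=None):
        return {"pp_proxy_tensors": pp_proxy_tensors}

model = DummyModel()

# Probe the signature once, as ModelRunner does after initialize().
support_pp = "pp_proxy_tensors" in inspect.signature(model.forward).parameters

# Forward only the kwargs the model understands, so models without
# pipeline-parallel support keep their original call shape.
kwargs = {}
if support_pp:
    kwargs["pp_proxy_tensors"] = None  # a PPProxyTensors on non-first pipeline stages
result = model.forward("input_ids", "positions", "forward_batch", **kwargs)
print(result)
```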