sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""
 
+from enum import Enum
 from typing import Optional
 
 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type=AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
         self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         if self.quant_method is not None:
             self.quant_method.create_weights(self)
+        self.attn_type = attn_type
 
     def forward(
         self,
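The new `attn_type` flag defaults to `AttentionType.DECODER`, so existing decoder call sites are unchanged, while encoder-style models (presumably the newly added `sglang/srt/models/bert.py`) can opt into `ENCODER_ONLY`. A minimal sketch of constructing a layer with the flag; the arguments other than `attn_type` (num_heads, head_dim, scaling, num_kv_heads, layer_id) are assumed from the existing RadixAttention signature and are not part of this hunk:

# Hedged sketch: all keyword arguments other than attn_type are assumptions,
# not shown in this diff.
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention

attn = RadixAttention(
    num_heads=12,
    head_dim=64,
    scaling=64**-0.5,
    num_kv_heads=12,
    layer_id=0,
    attn_type=AttentionType.ENCODER_ONLY,  # default is AttentionType.DECODER
)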
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
+
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
sglang/srt/layers/sampler.py CHANGED
@@ -93,28 +93,23 @@ class Sampler(nn.Module):
             ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids = min_p_sampling_from_probs(
-                    probs,
+                    probs, sampling_info.min_ps
                 )
             else:
-
+                # Check Nan will throw exception, only check when crash_on_warnings is True
+                check_nan = self.use_nan_detection and crash_on_warnings()
+                batch_next_token_ids = top_k_top_p_sampling_from_probs(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
+                    check_nan=check_nan,
                 )
 
-            if self.use_nan_detection and not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
             batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
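The flashinfer path no longer pre-draws `uniform_samples`; the kernel draws its own randomness, and NaN checking moves into the call via `check_nan`. A hedged sketch of the new call shape, using only the arguments that appear in the hunk; it assumes a CUDA device and a flashinfer build whose sampling kernels no longer take `uniform_samples` (the tensors are placeholders, not sglang internals):

# Hedged sketch of the updated top-k/top-p sampling call.
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs

probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)
top_ks = torch.full((4,), 20, dtype=torch.int32, device="cuda")
top_ps = torch.full((4,), 0.9, device="cuda")

next_token_ids = top_k_top_p_sampling_from_probs(
    probs,
    top_ks,
    top_ps,
    filter_apply_order="joint",
    check_nan=False,  # sglang enables this only when NaN detection and crash_on_warnings() are on
)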
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
 
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
@@ -115,3 +115,19 @@ class BaseLoRABackend:
 
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
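`get_backend_from_name` now lives alongside `BaseLoRABackend` (the old `sglang/srt/lora/backend/__init__.py` re-exports were removed, per the file list) and returns a backend class, not an instance. A brief usage sketch; the backend constructors are not shown in this diff, so instantiation is omitted:

# Hedged sketch: resolve a LoRA backend class by name.
from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # -> TritonLoRABackend (class)

try:
    get_backend_from_name("unknown")
except ValueError as err:
    print(err)  # Invalid backend: unknown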
sglang/srt/lora/backend/flashinfer_backend.py CHANGED
@@ -2,7 +2,7 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
sglang/srt/lora/layers.py CHANGED
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/lora/lora.py CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
sglang/srt/lora/lora_manager.py CHANGED
@@ -22,7 +22,7 @@ import torch
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
sglang/srt/managers/io_struct.py CHANGED

sglang/srt/managers/mm_utils.py CHANGED
@@ -1,7 +1,8 @@
 """
-
+Multi-modality utils
 """
 
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
sglang/srt/managers/multimodal_processors/base_processor.py CHANGED
@@ -8,8 +8,6 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
-from decord import VideoReader, cpu
-from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
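Importing `decord` inside the method means the dependency is only required when video frames are actually counted, which keeps the module importable on platforms where decord has no wheel. A minimal standalone sketch of the same lazy-import pattern; `count_video_frames` is a hypothetical helper, not an sglang API:

def count_video_frames(video_path: str) -> int:
    # Hypothetical helper illustrating the lazy-import pattern above: decord
    # is imported only when a video is actually read.
    from decord import VideoReader, cpu

    return len(VideoReader(video_path, ctx=cpu(0)))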
sglang/srt/managers/schedule_batch.py CHANGED
@@ -67,7 +67,6 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
@@ -77,12 +76,11 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
 
@@ -1481,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
-            or global_server_args_dict["
+            or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
sglang/srt/managers/scheduler.py CHANGED
@@ -391,6 +391,7 @@ class Scheduler(
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
+        self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
         # Init metrics stats
@@ -484,7 +485,7 @@ class Scheduler(
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                tp_cache_group=self.
+                tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
             )
@@ -553,7 +554,7 @@ class Scheduler(
 
             # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
             )
@@ -568,7 +569,7 @@ class Scheduler(
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -597,7 +598,7 @@ class Scheduler(
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 transfer_backend=self.transfer_backend,
                 scheduler=self,
             )
@@ -664,70 +665,6 @@ class Scheduler(
 
             self.last_batch = batch
 
-    @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
-        """A normal scheduler loop for prefill worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
-            )
-            self.process_prefill_chunk()
-            batch = self.get_new_batch_prefill()
-            self.cur_batch = batch
-
-            if batch:
-                result = self.run_batch(batch)
-                self.process_batch_result_disagg_prefill(batch, result)
-
-            if len(self.disagg_prefill_inflight_queue) > 0:
-                self.process_disagg_prefill_inflight_queue()
-
-            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
-            # Otherwise, it hangs under high concurrency
-            self.running_batch.batch_is_full = False
-
-    @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
-        """A normal scheduler loop for decode worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            # polling and allocating kv cache
-            self.process_decode_queue()
-            batch = self.get_next_disagg_decode_batch_to_run()
-            self.cur_batch = batch
-
-            if batch:
-                # Generate fake extend output.
-                if batch.forward_mode.is_extend():
-                    # Note: Logprobs should be handled on the prefill engine.
-                    self.stream_output(
-                        batch.reqs, [False for _ in range(len(batch.reqs))]
-                    )
-                else:
-                    result = self.run_batch(batch)
-                    self.process_batch_result(batch, result)
-
-            if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
-                + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -1869,6 +1806,7 @@ class Scheduler(
                 recv_req.activities,
                 recv_req.with_stack,
                 recv_req.record_shapes,
+                recv_req.profile_id,
             )
         else:
             return self.stop_profile()
@@ -1880,6 +1818,7 @@ class Scheduler(
         activities: Optional[List[str]],
         with_stack: Optional[bool],
         record_shapes: Optional[bool],
+        profile_id: Optional[str],
     ) -> None:
         if self.profiler_activities:
             return ProfileReqOutput(
@@ -1894,9 +1833,11 @@ class Scheduler(
 
         self.torch_profiler_output_dir = output_dir
         self.profiler_activities = activities
+        self.profiler_id = profile_id
         logger.info(
-            "Profiling starts. Traces will be saved to: %s",
+            "Profiling starts. Traces will be saved to: %s (with id %s)",
             self.torch_profiler_output_dir,
+            self.profiler_id,
         )
 
         activity_map = {
@@ -1938,14 +1879,14 @@ class Scheduler(
             self.torch_profiler.export_chrome_trace(
                 os.path.join(
                     self.torch_profiler_output_dir,
-
+                    self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                 )
             )
 
         if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
                 self.torch_profiler_output_dir,
-
+                self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)
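Profiler outputs are now named from the request-supplied `profile_id` plus the TP rank, which makes it easy to correlate traces from one profiling run across ranks. A small sketch of the resulting paths; the directory, id, and rank values are examples, not sglang defaults:

# Hedged sketch of the file names produced after this change.
import os

output_dir, profile_id, tp_rank = "/tmp/profiles", "req-1234", 0

trace_path = os.path.join(output_dir, profile_id + f"-TP-{tp_rank}" + ".trace.json.gz")
memory_path = os.path.join(output_dir, profile_id + f"-TP-{tp_rank}-memory" + ".pickle")

print(trace_path)   # /tmp/profiles/req-1234-TP-0.trace.json.gz
print(memory_path)  # /tmp/profiles/req-1234-TP-0-memory.pickle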
sglang/srt/mem_cache/hiradix_cache.py CHANGED
@@ -92,7 +92,7 @@ class HiRadixCache(RadixCache):
             self.ongoing_write_through[node.id] = node
             self.inc_lock_ref(node)
         else:
-            return
+            return 0
 
         return len(host_indices)
 
@@ -153,6 +153,7 @@ class HiRadixCache(RadixCache):
             if x.host_value is None:
                 if self.cache_controller.write_policy == "write_back":
                     num_evicted += self.write_backup(x)
+                    pending_nodes.append(x)
                 elif self.cache_controller.write_policy == "write_through_selective":
                     num_evicted += self._evict_write_through_selective(x)
                 else:
@@ -177,6 +178,9 @@ class HiRadixCache(RadixCache):
         while len(self.ongoing_write_through) > 0:
             self.writing_check()
             time.sleep(0.1)
+        for node in pending_nodes:
+            assert node.host_value is not None
+            self._evict_write_through(node)
 
     def _evict_write_through(self, node: TreeNode):
         # evict a node already written to host
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
             self.get_key_buffer(i).nbytes for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes
-
+            self.get_key_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     # Todo: different memory layout
@@ -414,6 +418,7 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
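`kv_item_lens` now reports bytes per page rather than bytes per token (and `MLATokenToKVPool` starts storing `page_size`), presumably so the downstream KV transfer path works at page granularity. A toy calculation under made-up dimensions; the exact buffer layout is not part of this diff:

# Hedged sketch: per-page KV item size for one layer's key buffer, with
# made-up dimensions; the real buffer shapes live in memory_pool.py.
import torch

page_size, num_kv_heads, head_dim = 16, 8, 128
key_buffer = torch.empty(1024, num_kv_heads, head_dim, dtype=torch.float16)

per_token_bytes = key_buffer[0].nbytes        # one token's K for this layer
per_page_bytes = per_token_bytes * page_size  # what kv_item_lens now reports
print(per_token_bytes, per_page_bytes)        # 2048 32768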
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import get_available_gpu_memory, is_hip
 
-_is_hip = is_hip()
-
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+_is_hip = is_hip()
+
 
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
sglang/srt/model_executor/model_runner.py CHANGED
@@ -73,6 +73,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
+    get_bool_env_var,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -127,10 +128,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.use_mla_backend = (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        )
+        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
         # Model-specific adjustment
@@ -139,18 +137,12 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
 
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
         # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
@@ -160,13 +152,12 @@ class ModelRunner:
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "moe_dense_tp_size": server_args.moe_dense_tp_size,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
                 "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "use_mla_backend": self.use_mla_backend,
             }
@@ -229,15 +220,7 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if server_args.enable_flashinfer_mla:
-            # TODO: remove this branch after enable_flashinfer_mla is deprecated
-            logger.info("MLA optimization is turned on. Use flashinfer backend.")
-            server_args.attention_backend = "flashinfer"
-        elif server_args.enable_flashmla:
-            # TODO: remove this branch after enable_flashmla is deprecated
-            logger.info("MLA optimization is turned on. Use flashmla decode.")
-            server_args.attention_backend = "flashmla"
-        elif server_args.attention_backend is None:
+        if server_args.attention_backend is None:
             # By default, use flashinfer for non-mla attention and triton for mla attention
             if not self.use_mla_backend:
                 if (
@@ -263,7 +246,12 @@ class ModelRunner:
         elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.attention_backend in [
+                if server_args.attention_backend in [
+                    "flashinfer",
+                    "fa3",
+                    "triton",
+                    "flashmla",
+                ]:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
@@ -320,7 +308,6 @@ class ModelRunner:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
         if not self.use_mla_backend:
-            logger.info("Disable chunked prefix cache for non-MLA backend.")
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
             logger.info("Disable chunked prefix cache when page size > 1.")
@@ -387,10 +374,16 @@ class ModelRunner:
         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
-
-
-
-
+                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
+                    logger.warning(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
+                else:
+                    raise ValueError(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
 
         logger.info(
             f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
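The TP memory-imbalance check can now be downgraded from a hard ValueError to a warning through the SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK environment variable. A hedged example of setting it before starting the server; the accepted truthy values depend on get_bool_env_var, which is not shown here, so "1" is an assumption:

# Hedged sketch: opt out of the hard memory-imbalance check on uneven GPUs.
import os

os.environ["SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"] = "1"  # assumed truthy value

# ...then launch as usual, e.g. (illustrative command):
#   python -m sglang.launch_server --model-path <model> --tp 2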