sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +11 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +5 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +8 -4
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +144 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +17 -3
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/deepseek_v2.py +23 -17
- sglang/srt/models/glm4_moe.py +82 -19
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +80 -20
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +3 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
CHANGED
@@ -146,34 +146,3 @@ def triton_kernel_fused_experts(
     )
 
     return intermediate_cache3
-
-
-def triton_kernel_moe_forward_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    use_fp8_w8a8: bool = False,
-    per_channel_quant: bool = False,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[list[int]] = None,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="forward_cuda_triton",
-    op_func=triton_kernel_moe_forward,
-    mutates_args=[],
-    fake_impl=triton_kernel_moe_forward_fake,
-)
sglang/srt/layers/moe/token_dispatcher/__init__.py
ADDED
@@ -0,0 +1,23 @@
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPConfig,
+    DeepEPDispatcher,
+    DeepEPLLOutput,
+    DeepEPNormalOutput,
+)
+
+__all__ = [
+    "BaseDispatcher",
+    "BaseDispatcherConfig",
+    "DispatchOutput",
+    "DispatchOutputFormat",
+    "DeepEPConfig",
+    "DeepEPDispatcher",
+    "DeepEPNormalOutput",
+    "DeepEPLLOutput",
+]
sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
CHANGED
@@ -2,11 +2,22 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import
+from typing import Protocol, runtime_checkable
 
 import torch
 
 
+class MoEA2ABackend(Enum):
+    none = "none"
+    deepep = "deepep"
+
+    def is_none(self):
+        return self == MoEA2ABackend.none
+
+    def is_deepep(self):
+        return self == MoEA2ABackend.deepep
+
+
 class DispatchOutputFormat(Enum):
     standard = auto()
     deepep_normal = auto()
sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py}
RENAMED
@@ -1,5 +1,3 @@
-# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
-
 from __future__ import annotations
 
 import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import (
-    DeepEPMode,
-    get_bool_env_var,
-    get_int_env_var,
-    is_hip,
-    load_json_config,
-)
+from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
 
 try:
     from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
             num_rdma_bytes,
         )
 
-        if deepep_mode == DeepEPMode.
+        if deepep_mode == DeepEPMode.NORMAL:
             num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-        elif deepep_mode in [DeepEPMode.
+        elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
             num_qps_per_rank = num_experts // group.size()
         else:
             raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
             device="cuda"
         ).multi_processor_count
         if (
-            (deepep_mode != DeepEPMode.
+            (deepep_mode != DeepEPMode.LOW_LATENCY)
            and not global_server_args_dict["enable_two_batch_overlap"]
            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
         ):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
         async_finish: bool = False,
         return_recv_hook: bool = False,
     ):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
         resolved_deepep_mode = self.deepep_mode.resolve(
             forward_batch.is_extend_in_batch
         )
-        if resolved_deepep_mode == DeepEPMode.
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
-        elif resolved_deepep_mode == DeepEPMode.
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
sglang/srt/layers/moe/utils.py
ADDED
@@ -0,0 +1,43 @@
+from enum import Enum
+
+
+class MoeA2ABackend(Enum):
+
+    STANDARD = ("standard", "none")
+    DEEPEP = "deepep"
+
+    @classmethod
+    def _missing_(cls, value):
+        if value is None:
+            return cls.STANDARD
+        for member in cls:
+            if value in member.value:
+                return member
+        raise ValueError(f"No {cls.__name__} member for value {value}")
+
+    def is_deepep(self):
+        return self == MoeA2ABackend.DEEPEP
+
+    def is_standard(self):
+        return self == MoeA2ABackend.STANDARD
+
+
+class DeepEPMode(Enum):
+    NORMAL = "normal"
+    LOW_LATENCY = "low_latency"
+    AUTO = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
+
+    def resolve(self, is_extend_in_batch: bool):
+        if self != DeepEPMode.AUTO:
+            return self
+
+        if is_extend_in_batch:
+            return DeepEPMode.NORMAL
+        else:
+            return DeepEPMode.LOW_LATENCY
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer:
+    def process_weights_after_loading(self, layer: FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         assert layer.w13_weight_scale is not None
         shard_size = layer.intermediate_size_per_partition
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.
+        for expert_id in range(layer.num_local_experts):
             start = 0
             for shard_id in range(2):
                 dq_weight = per_tensor_dequantize(
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
CHANGED
@@ -148,7 +148,7 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
         "N": n,
         "K": k,
-        "NUM_GROUPS":
+        "NUM_GROUPS": num_groups,
         "BLOCK_M": block_m,
         "BLOCK_N": block_n,
         "BLOCK_K": block_k,
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -1039,7 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
             topk_weights, topk_ids, _ = topk_output
-            return cutlass_fused_experts_fp8(
+            output = cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
                 layer.w2_weight.transpose(1, 2),
@@ -1062,6 +1062,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
+            # TODO: Fuse into select_experts
+            if routed_scaling_factor is not None:
+                output *= routed_scaling_factor
+            return output
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -354,10 +354,6 @@ def sglang_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    if scale_ue8m0:
-        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
-        assert x.shape[-1] % (group_size * 4) == 0
-
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -464,7 +468,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
sglang/srt/lora/lora_registry.py
CHANGED
@@ -186,3 +186,10 @@ class LoRARegistry:
         self._registry[lora_ref.lora_name] = lora_ref
         self._counters[lora_ref.lora_id] = ConcurrentCounter()
         return lora_ref
+
+    @property
+    def num_registered_loras(self) -> int:
+        """
+        Returns the total number of LoRA adapters currently registered.
+        """
+        return len(self._registry)
sglang/srt/managers/cache_controller.py
CHANGED
@@ -236,6 +236,7 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
@@ -573,6 +574,9 @@ class HiCacheController:
         self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
         operation.increment(len(operation.hash_value) * self.page_size)
 
+    def is_mooncake_backend(self):
+        return self.storage_backend_type == "mooncake"
+
     def prefetch_io_aux_func(self):
         """
         Auxiliary function conducting IO operations for prefetching.
@@ -580,7 +584,7 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if
+                if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
                 else:
                     self.generic_page_transfer(operation)
@@ -615,14 +619,14 @@ class HiCacheController:
                 )
 
                 # todo, more unified interface
-                if not
+                if not self.is_mooncake_backend():
                     if not self.storage_backend.exists(last_hash):
                         break
                 hash_value.append(last_hash)
                 storage_hit_count += self.page_size
                 remaining_tokens -= self.page_size
 
-            if
+            if self.is_mooncake_backend():
                 # deferring to batch exists for mooncake store
                 exist_result = self.storage_backend.exists(hash_value)
                 storage_hit_count = (
@@ -744,7 +748,7 @@ class HiCacheController:
                 remaining_tokens -= self.page_size
             operation.hash_value = hash_value
 
-            if
+            if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
             else:
                 self.generic_page_backup(operation)
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -16,9 +16,13 @@
 import logging
 import multiprocessing as mp
 import signal
+import struct
+import sys
 import threading
 import time
 from enum import Enum, auto
+from multiprocessing import shared_memory
+from typing import Dict, List
 
 import psutil
 import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
 
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    MINIMUM_TOKENS = auto()
 
     @classmethod
     def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
 
-    def __init__(
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        dp_balance_meta: DPBalanceMeta,
+    ) -> None:
+        # for dp balance
+        self.global_balance_id = 0
+        self.balance_meta = dp_balance_meta
+
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
@@ -234,6 +250,7 @@ class DataParallelController:
                     pp_rank,
                     dp_rank,
                     writer,
+                    self.balance_meta,
                 ),
             )
             with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
 
+    def minimum_tokens_scheduler(self, req):
+        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
+        def get_next_global_balance_id() -> int:
+            INT32_MAX = 2147483647
+            current_id = self.global_balance_id
+            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
+            return current_id
+
+        req.dp_balance_id = get_next_global_balance_id()
+        with self.balance_meta.mutex:
+            # 1. local_tokens represents the tokens currently inferring on the worker,
+            # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. write the new onfly info to the shm
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
     def event_loop(self):
         while True:
             while True:
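The new MINIMUM_TOKENS policy above picks the dp worker with the smallest sum of tokens already running locally plus tokens still "onfly" (dispatched but not yet received by a scheduler). An illustrative sketch of that selection step, using hypothetical numbers:

# Hypothetical per-worker state mirroring the selection logic in minimum_tokens_scheduler.
local_tokens = [1024, 256, 768]               # tokens currently inferring on each worker
onfly_info = [{7: 128}, {8: 512, 9: 64}, {}]  # per-worker {dp_balance_id: num input tokens}

total_tokens = [
    local + sum(onfly.values()) for local, onfly in zip(local_tokens, onfly_info)
]
target_worker = total_tokens.index(min(total_tokens))
print(total_tokens, target_worker)  # [1152, 832, 768] 2 -> dispatch to worker 2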
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
     setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size)
 
     try:
-        controller = DataParallelController(
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
         pipe_writer.send(
             {
                 "status": "ready",
@@ -323,3 +370,6 @@
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # we need to destruct mp.Manager() in balance_meta
+        balance_meta.destructor()
sglang/srt/managers/io_struct.py
CHANGED
@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For dp balance
+    dp_balance_id: int = -1
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For dp balance
+    dp_balance_id: int = -1
 
 
 @dataclass
@@ -1097,7 +1102,7 @@ class UnloadLoRAAdapterReqInput:
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] =
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
-    "
+    "moe_a2a_backend",
     "deepep_mode",
-    "enable_ep_moe",
     "enable_flashinfer_cutlass_moe",
     "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
@@ -108,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
     "enable_multimodal",
+    "enable_symm_mem",
 ]
 
 # Put some global args for easy access
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -455,7 +455,9 @@ class PrefillAdder:
         if not self.is_hybrid:
             # Skip this logic for swa. The SWA has different memory management, and
             # this mechanism is underestimating the memory usage.
-            cur_rem_tokens = self.cur_rem_tokens -
+            cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(
+                req.extend_input_len
+            )
             tokens_freed = 0
             for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                 # tokens_left gives a reservative calculation as the last token is not stored