sglang-0.5.0rc1-py3-none-any.whl → sglang-0.5.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/disaggregation/decode.py +0 -1
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/http_server.py +64 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -0
- sglang/srt/entrypoints/openai/serving_chat.py +1 -0
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/layers/attention/flashinfer_backend.py +3 -0
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
- sglang/srt/layers/communicator.py +7 -7
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +5 -32
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +52 -30
- sglang/srt/layers/quantization/mxfp4.py +16 -2
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/managers/cache_controller.py +4 -1
- sglang/srt/managers/io_struct.py +14 -0
- sglang/srt/managers/schedule_batch.py +18 -39
- sglang/srt/managers/scheduler.py +3 -4
- sglang/srt/managers/tokenizer_manager.py +28 -18
- sglang/srt/mem_cache/allocator.py +8 -157
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +8 -21
- sglang/srt/model_executor/forward_batch_info.py +8 -10
- sglang/srt/model_executor/model_runner.py +57 -53
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +5 -3
- sglang/srt/models/glm4_moe.py +2 -2
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/gpt_oss.py +7 -2
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -5
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +33 -7
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/two_batch_overlap.py +4 -8
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -267,7 +267,6 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
-        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     _maybe_prepare_mlp_sync_batch(batch, model_runner)
sglang/srt/configs/model_config.py
CHANGED
@@ -642,6 +642,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
         or "InternLM2ForRewardModel" in model_architectures
         or "Qwen2ForRewardModel" in model_architectures
         or "Qwen2ForSequenceClassification" in model_architectures
+        or "Qwen3ForSequenceClassification" in model_architectures
         or "CLIPModel" in model_architectures
         or "BertModel" in model_architectures
         or "Contriever" in model_architectures
sglang/srt/entrypoints/engine.py
CHANGED
@@ -647,7 +647,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.11.
+            "0.2.11.post3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.
+            "0.3.5",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -88,6 +88,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
+    UpdateWeightVersionReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.template_manager import TemplateManager
@@ -342,10 +343,19 @@ async def get_model_info():
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
         "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
+        "weight_version": _global_state.tokenizer_manager.server_args.weight_version,
     }
     return result


+@app.get("/get_weight_version")
+async def get_weight_version():
+    """Get the current weight version."""
+    return {
+        "weight_version": _global_state.tokenizer_manager.server_args.weight_version
+    }
+
+
 @app.get("/get_server_info")
 async def get_server_info():
     # Returns internal states per DP.
@@ -537,6 +547,12 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
     success, message, num_paused_requests = (
         await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
     )
+
+    # Update weight version if provided and weights update was successful
+    if success and obj.weight_version is not None:
+        _update_weight_version_if_provided(obj.weight_version)
+        message += f" Weight version updated to {obj.weight_version}."
+
     content = {
         "success": success,
         "message": message,
@@ -583,6 +599,12 @@ async def update_weights_from_tensor(
     success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
         obj, request
     )
+
+    # Update weight version if provided and weights update was successful
+    if success and obj.weight_version is not None:
+        _update_weight_version_if_provided(obj.weight_version)
+        message += f" Weight version updated to {obj.weight_version}."
+
     content = {"success": success, "message": message}
     return ORJSONResponse(
         content, status_code=200 if success else HTTPStatus.BAD_REQUEST
@@ -599,6 +621,12 @@ async def update_weights_from_distributed(
             obj, request
         )
     )
+
+    # Update weight version if provided and weights update was successful
+    if success and obj.weight_version is not None:
+        _update_weight_version_if_provided(obj.weight_version)
+        message += f" Weight version updated to {obj.weight_version}."
+
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(content, status_code=200)
@@ -606,6 +634,36 @@ async def update_weights_from_distributed(
     return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


+@app.post("/update_weight_version")
+async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
+    """Update the weight version. This operation requires no active requests."""
+    if obj.abort_all_requests:
+        _global_state.tokenizer_manager.abort_request(abort_all=True)
+
+    # Use a simple approach without the complex lock mechanism for now
+    # since weight_version update is a simple operation that doesn't affect model weights
+    try:
+        # Update the weight version in server args (the single source of truth)
+        _global_state.tokenizer_manager.server_args.weight_version = obj.new_version
+
+        return ORJSONResponse(
+            {
+                "success": True,
+                "message": f"Weight version updated to {obj.new_version}",
+                "new_version": obj.new_version,
+            },
+            status_code=HTTPStatus.OK,
+        )
+    except Exception as e:
+        return ORJSONResponse(
+            {
+                "success": False,
+                "message": f"Failed to update weight version: {str(e)}",
+            },
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 @app.api_route("/get_weights_by_name", methods=["GET", "POST"])
 async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     """Get model parameter by name."""
@@ -966,6 +1024,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
     return ORJSONResponse({"predictions": ret})


+def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
+    """Update weight version if provided."""
+    if weight_version is not None:
+        _global_state.tokenizer_manager.server_args.weight_version = weight_version
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
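The new weight-version surface in http_server.py is small: a GET to read the version, a POST to set it, and an optional `weight_version` field on the existing weight-update requests. A minimal client sketch (not part of the diff), assuming a server on localhost:30000; host, port, and payload values are illustrative, while the field names `new_version` and `abort_all_requests` come from the handler shown above:

```python
import requests

base = "http://localhost:30000"

# Read the version currently reported by the server.
print(requests.get(f"{base}/get_weight_version").json())

# Bump the version without reloading any weights.
resp = requests.post(
    f"{base}/update_weight_version",
    json={"new_version": "2025-08-12", "abort_all_requests": False},
)
print(resp.json())  # keys per the handler above: success, message, new_version
```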
sglang/srt/entrypoints/openai/protocol.py
CHANGED
@@ -240,6 +240,7 @@ class CompletionResponse(BaseModel):
     model: str
     choices: List[CompletionResponseChoice]
     usage: UsageInfo
+    metadata: Optional[Dict[str, Any]] = None


 class CompletionResponseStreamChoice(BaseModel):
@@ -517,6 +518,7 @@ class ChatCompletionResponse(BaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    metadata: Optional[Dict[str, Any]] = None


 class DeltaMessage(BaseModel):
sglang/srt/entrypoints/openai/serving_completions.py
CHANGED
@@ -373,6 +373,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             created=created,
             choices=choices,
             usage=usage,
+            metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
         )

     def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
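Both OpenAI-compatible response models gain an optional `metadata` dict, and the completions path fills it with the serving weight version. A sketch (not from the diff) of reading it back from the `/v1/completions` route, assuming a local server; the URL, model name, and prompt are illustrative:

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={"model": "default", "prompt": "Hello", "max_tokens": 8},
).json()

# New in 0.5.0rc2: completion responses may carry a metadata dict.
print(resp.get("metadata"))  # e.g. {"weight_version": "..."}
```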
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -122,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -870,6 +871,8 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
+            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
+            # and forward_batch.extend_seq_lens_cpu
             paged_kernel_lens = prefix_lens
             paged_kernel_lens_sum = paged_kernel_lens.sum().item()
         else:
sglang/srt/layers/attention/flashinfer_mla_backend.py
CHANGED
@@ -81,6 +81,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
         self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

+        # Parse args
         self.skip_prefill = skip_prefill
-
         max_bs = model_runner.req_to_token_pool.size
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

+        # Check arguments
         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"
-        self.sliding_window_size = model_runner.sliding_window_size

+        # Initialize buffers
         # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
         if kv_indptr_buf is None:
             self.kv_indptr = torch.zeros(
@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
             # When provided a buffer, create a clone for the second buffer
             self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

-        self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
                 (max_bs + 1,), dtype=torch.int64, device=model_runner.device
             )

-
-        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
-
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
-        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
-
-        self.static_kv_splits = get_bool_env_var(
-            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
-        )
-        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
-
+        # Initialize forward metadata
         self.forward_metadata: ForwardMetadata = None

-        self.max_context_len = model_runner.model_config.context_len
-
-        self.device = model_runner.device
-        self.device_core_count = get_device_core_count(model_runner.gpu_id)
-
     def get_num_kv_splits(
         self,
         num_kv_splits: torch.Tensor,
@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             attn_logits = None
             attn_lse = None
-            max_extend_len =
+            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
             num_kv_splits = None

         self.forward_metadata = ForwardMetadata(
sglang/srt/layers/attention/trtllm_mha_backend.py
CHANGED
@@ -23,10 +23,12 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo

 # Constants
-DEFAULT_WORKSPACE_SIZE_MB =
+DEFAULT_WORKSPACE_SIZE_MB = (
+    512  # Memory workspace size in MB, todo(Yingyi): read from config
+)

 # Reuse this workspace buffer across all TRTLLM MHA wrappers
-
+global_zero_init_workspace_buffer = None


 @dataclass
@@ -73,14 +75,14 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
         # Allocate buffers
-        global
-        if
-
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
                 self.workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer =
+        self.workspace_buffer = global_zero_init_workspace_buffer

         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -39,6 +39,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128

+global_zero_init_workspace_buffer = None
+

 @dataclass
 class TRTLLMMLADecodeMetadata:
@@ -83,9 +85,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-
-
-
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_zero_init_workspace_buffer

         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
sglang/srt/layers/communicator.py
CHANGED
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
         context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
-
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-
-                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
-                ),
+                get_local_dp_buffer(),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             residual = hidden_states
             hidden_states = layernorm(hidden_states)
             hidden_states, local_hidden_states = (
-
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-
+            get_local_dp_buffer(),
             hidden_states,
         )
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
             hidden_states += residual
             residual = None
         hidden_states, local_hidden_states = (
-
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
sglang/srt/layers/dp_attention.py
CHANGED
@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )

+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch

-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False


-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):

     # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
     MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()

     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN

     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN

     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
         return cls.SUM_LEN

     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN


+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+
+def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(
+        global_dp_buffer_len, local_dp_buffer_len
+    )
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
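The wrapper above centralizes allocation of the DP-attention gather buffers: `initialize_dp_attention` records hidden size, dtype, and device once (see the next hunk), a caller records the current global/local buffer lengths via `set_dp_buffer_len`, and code such as communicator.py then asks for a fresh buffer instead of slicing `forward_batch.gathered_buffer`. A rough usage sketch of the new module-level helpers (assumes an already-initialized sglang process; the lengths below are illustrative):

```python
from sglang.srt.layers.dp_attention import (
    get_global_dp_buffer,
    get_local_dp_buffer,
    set_dp_buffer_len,
)

# Record how many tokens the gathered (global) and per-rank (local) buffers must hold.
set_dp_buffer_len(global_dp_buffer_len=4096, local_dp_buffer_len=512)

# Each call returns a fresh (num_tokens, hidden_size) tensor allocated with the
# dtype/device recorded by initialize_dp_attention().
global_buf = get_global_dp_buffer()
local_buf = get_local_dp_buffer()
```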
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(


 def initialize_dp_attention(
-
-
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -135,38 +212,48 @@ def initialize_dp_attention(
         group_name="attention_tp",
     )

+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device("cuda"),
+    )

-def get_attention_tp_group():
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
+
+
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP


-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK


-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE


-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK


-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE


-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK


-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE

@@ -292,6 +379,10 @@ def _dp_gather_via_all_gather(
     forward_batch: ForwardBatch,
     is_partial: bool,
 ):
+    if get_attention_tp_size() == 1:
+        get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
+        return
+
     if not is_partial:
         if get_attention_tp_rank() != 0:
             local_tokens.fill_(0)