sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
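For reference, picking up this post release is an ordinary pip upgrade; a minimal sketch, assuming the wheel is published on the default PyPI index (equivalent to running `pip install --upgrade "sglang==0.4.10.post2"` in a shell):

import subprocess
import sys

# Upgrade the installed wheel to the post release compared in this diff.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "--upgrade", "sglang==0.4.10.post2"]
)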
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
ADDED
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+
+import os
+import unittest
+from typing import List, Optional
+from unittest.mock import MagicMock
+
+import torch
+
+from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
+
+
+class TestNixlUnified(unittest.TestCase):
+    """Unified test suite for all NIXL components."""
+
+    def setUp(self):
+        """Set up test environment."""
+        # Create test directories
+        self.test_dir = "/tmp/test_nixl_unified"
+        os.makedirs(self.test_dir, exist_ok=True)
+
+        # Mock NIXL agent for registration tests
+        self.mock_agent = MagicMock()
+        self.mock_agent.get_reg_descs.return_value = "mock_reg_descs"
+        self.mock_agent.register_memory.return_value = "mock_registered_memory"
+
+        # Create instances
+        self.file_manager = NixlFileManager(self.test_dir)
+        self.registration = NixlRegistration(self.mock_agent)
+        try:
+            self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
+        except ImportError:
+            self.skipTest("NIXL not available, skipping NIXL storage tests")
+
+    def tearDown(self):
+        """Clean up test directories."""
+        if os.path.exists(self.test_dir):
+            import shutil
+
+            shutil.rmtree(self.test_dir)
+
+    def delete_test_file(self, file_path: str) -> bool:
+        """Helper method to delete a test file.
+
+        Args:
+            file_path: Path to the file to delete
+
+        Returns:
+            bool: True if file was deleted or didn't exist, False on error
+        """
+        try:
+            if os.path.exists(file_path):
+                os.remove(file_path)
+            return True
+        except Exception as e:
+            return False
+
+    def verify_tensors_equal(self, expected: torch.Tensor, actual: torch.Tensor):
+        """Helper to verify tensor equality."""
+        self.assertIsNotNone(actual, "Retrieved tensor is None")
+        self.assertTrue(
+            torch.allclose(expected, actual, atol=1e-6),
+            f"Tensors not equal:\nExpected: {expected}\nActual: {actual}",
+        )
+
+    def verify_tensor_lists_equal(
+        self, expected: List[torch.Tensor], actual: List[torch.Tensor]
+    ):
+        """Helper to verify lists of tensors are equal."""
+        self.assertEqual(len(expected), len(actual), "Lists have different lengths")
+        for exp, act in zip(expected, actual):
+            self.verify_tensors_equal(exp, act)
+
+    # ============================================================================
+    # HiCache Integration Tests
+    # ============================================================================
+
+    def test_single_set_get(self):
+        """Test single tensor set/get operations."""
+        key = "test_key"
+        value = torch.randn(10, 10, device="cpu")
+        dst_tensor = torch.zeros_like(value, device="cpu")
+
+        # Test set
+        self.assertTrue(self.hicache.set(key, value))
+        self.assertTrue(self.hicache.exists(key))
+
+        # Test get
+        retrieved = self.hicache.get(key, dst_tensor)
+        self.verify_tensors_equal(value, retrieved)
+
+    def test_batch_set_get(self):
+        """Test batch tensor set/get operations."""
+        keys = ["key1", "key2", "key3"]
+        values = [
+            torch.randn(5, 5, device="cpu"),
+            torch.randn(3, 3, device="cpu"),
+            torch.randn(7, 7, device="cpu"),
+        ]
+        dst_tensors = [torch.zeros_like(v, device="cpu") for v in values]
+
+        # Test batch set
+        self.assertTrue(self.hicache.batch_set(keys, values))
+        self.assertTrue(all(self.hicache.exists(key) for key in keys))
+
+        # Test batch get
+        retrieved = self.hicache.batch_get(keys, dst_tensors)
+        self.verify_tensor_lists_equal(values, retrieved)
+
+    def test_mixed_operations(self):
+        """Test mixing single and batch operations."""
+        # Test interleaved set/get operations
+        key1, key2 = "key1", "key2"
+        value1 = torch.randn(4, 4, device="cpu")
+        value2 = torch.randn(6, 6, device="cpu")
+        dst1 = torch.zeros_like(value1)
+        dst2 = torch.zeros_like(value2)
+
+        # Single set/get
+        self.assertTrue(self.hicache.set(key1, value1))
+        retrieved1 = self.hicache.get(key1, dst1)
+        self.verify_tensors_equal(value1, retrieved1)
+
+        # Batch set/get
+        self.assertTrue(self.hicache.batch_set([key2], [value2]))
+        retrieved2 = self.hicache.batch_get([key2], [dst2])
+        self.verify_tensors_equal(value2, retrieved2[0])
+
+    def test_data_integrity(self):
+        """Test data integrity across operations."""
+        # Test with various tensor types and sizes
+        test_cases = [
+            ("float32", torch.randn(10, 10, dtype=torch.float32)),
+            ("float64", torch.randn(5, 5, dtype=torch.float64)),
+            ("int32", torch.randint(-100, 100, (8, 8), dtype=torch.int32)),
+            ("int64", torch.randint(-100, 100, (6, 6), dtype=torch.int64)),
+            ("bool", torch.randint(0, 2, (4, 4)).bool()),
+        ]
+
+        for name, tensor in test_cases:
+            with self.subTest(tensor_type=name):
+                key = f"test_{name}"
+                dst_tensor = torch.zeros_like(tensor)
+
+                # Set and immediately get
+                self.assertTrue(self.hicache.set(key, tensor))
+                retrieved1 = self.hicache.get(key, dst_tensor)
+                self.verify_tensors_equal(tensor, retrieved1)
+
+                # Get again to verify persistence
+                dst_tensor.zero_()
+                retrieved2 = self.hicache.get(key, dst_tensor)
+                self.verify_tensors_equal(tensor, retrieved2)
+
+    def test_basic_file_operations(self):
+        """Test basic file operations."""
+        test_file = os.path.join(self.test_dir, "test_file.bin")
+        self.file_manager.create_file(test_file)
+        self.assertTrue(os.path.exists(test_file))
+        self.assertEqual(os.path.getsize(test_file), 0)  # Empty file
+
+        # Test file deletion
+        self.assertTrue(self.delete_test_file(test_file))
+        self.assertFalse(os.path.exists(test_file))
+
+    def test_create_nixl_tuples(self):
+        """Test creation of NIXL tuples."""
+        test_file = os.path.join(self.test_dir, "test_file.bin")
+        self.file_manager.create_file(test_file)
+
+        # Test tuple creation
+        tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
+        self.assertIsNotNone(tuples)
+        self.assertTrue(len(tuples) > 0)
+
+    def test_error_handling(self):
+        """Test error handling in file operations."""
+        # Test non-existent file
+        self.assertTrue(
+            self.delete_test_file("nonexistent_file.bin")
+        )  # Returns True if file doesn't exist
+
+        # Test invalid file path
+        self.assertFalse(self.file_manager.create_file(""))  # Empty path should fail
+
+    def test_register_buffers(self):
+        """Test registration of memory buffers."""
+        # Create test tensor
+        tensor = torch.randn(10, 10)
+
+        # Test buffer registration
+        self.assertIsNotNone(self.registration.register_buffers(tensor))
+
+        # Test batch registration
+        tensors = [torch.randn(5, 5) for _ in range(3)]
+        self.assertIsNotNone(self.registration.register_buffers(tensors))
+
+    def test_register_files_with_tuples(self):
+        """Test registration of files using NIXL tuples."""
+        files = [os.path.join(self.test_dir, f"test_file_{i}.bin") for i in range(3)]
+        for file in files:
+            self.file_manager.create_file(file)
+
+        # Create tuples and register
+        tuples = self.file_manager.files_to_nixl_tuples(files, False)
+        self.registration.register_files(tuples)
+
+        # Verify tuples
+        self.assertEqual(len(tuples), len(files))
+        for t, f in zip(tuples, files):
+            self.assertEqual(t[3], f)  # Check file path
+
+
+if __name__ == "__main__":
+    unittest.main()
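The added file above is a self-contained unittest module; a minimal sketch of running it against an installed wheel, assuming the dotted module path follows the layout in the file list and that the tests skip themselves when the NIXL plugin is absent:

import unittest

# Hypothetical local run of the new unified NIXL test suite.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "sglang.srt.mem_cache.storage.nixl.test_hicache_nixl_storage"
)
unittest.TextTestRunner(verbosity=2).run(suite)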
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import gc
 import inspect
 import logging
 import os
@@ -28,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -75,6 +79,24 @@ def model_capture_mode():
     is_capture_mode = False
 
 
+@contextmanager
+def freeze_gc(enable_cudagraph_gc: bool):
+    """
+    Optimize garbage collection during CUDA graph capture.
+    Clean up, then freeze all remaining objects from being included
+    in future collections if GC is disabled during capture.
+    """
+    gc.collect()
+    should_freeze = not enable_cudagraph_gc
+    if should_freeze:
+        gc.freeze()
+    try:
+        yield
+    finally:
+        if should_freeze:
+            gc.unfreeze()
+
+
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
@@ -353,6 +375,11 @@ class CudaGraphRunner:
             dtype=torch.bool,
             device="cuda",
         )
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_num_token, self.model_runner.model_config.vocab_size),
+            dtype=torch.float,
+            device="cuda",
+        )
 
         # Capture
         try:
@@ -423,7 +450,12 @@ class CudaGraphRunner:
                 record_shapes=True,
             )
 
-
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        with freeze_gc(
+            self.model_runner.server_args.enable_cudagraph_gc
+        ), graph_capture() as graph_capture_context:
             with profile_context as prof:
                 self.stream = graph_capture_context.stream
                 avail_mem = get_available_gpu_memory(
@@ -493,6 +525,7 @@ class CudaGraphRunner:
         else:
            encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
         # pipeline parallelism
@@ -555,6 +588,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
@@ -619,11 +653,15 @@ class CudaGraphRunner:
 
             run_once()
 
-
-
+        if get_global_graph_memory_pool() is None:
+            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+        # Set graph pool id globally to be able to use symmetric memory
+        set_graph_pool_id(get_global_graph_memory_pool())
+        with torch.cuda.graph(
+            graph, pool=get_global_graph_memory_pool(), stream=stream
+        ):
             out = run_once()
 
-        global_graph_memory_pool = graph.pool()
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
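The freeze_gc change above follows a general CPython pattern: collect once, then gc.freeze() so the collector stops traversing long-lived objects while CUDA graphs are captured, and gc.unfreeze() afterwards. A standalone sketch of that idea, independent of sglang's CudaGraphRunner (the helper name and warm-up step here are illustrative, not sglang APIs):

import gc

import torch


def capture_with_frozen_gc(run_once):
    # Collect garbage once, then freeze survivors so the GC does not scan them
    # during capture; always unfreeze, even if capture fails.
    gc.collect()
    gc.freeze()
    try:
        run_once()  # warm-up pass outside the graph
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            out = run_once()
    finally:
        gc.unfreeze()
    return graph, out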
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -38,6 +38,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
     DPPaddingMode,
     get_attention_dp_rank,
@@ -188,6 +189,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None
 
     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
@@ -644,12 +646,17 @@ class ForwardBatch:
             device=model_runner.device,
         )
 
-        bs = self.batch_size
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        if self.forward_mode.is_decode():
+            setattr(self, "raw_bs", self.batch_size)
+            self.batch_size = num_tokens
+
+        bs = self.batch_size
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -657,6 +664,9 @@ class ForwardBatch:
         seq_len_fill_value = (
             model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+        self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+            bs - self.seq_lens.shape[0]
+        )
         self.seq_lens = self._pad_tensor_to_size(
             self.seq_lens, bs, value=seq_len_fill_value
         )
@@ -700,7 +710,7 @@ class ForwardBatch:
 
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
-        bs = self.batch_size
+        bs = getattr(self, "raw_bs", self.batch_size)
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
@@ -839,7 +849,7 @@ class ForwardBatch:
 
 
 def enable_num_token_non_padded(server_args):
-    return
+    return get_moe_expert_parallel_world_size() > 1
 
 
 class PPProxyTensors:
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -217,6 +218,10 @@ class ModelRunner:
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
             }
+            | {
+                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
+                "deepep_mode": DeepEPMode(server_args.deepep_mode),
+            }
         )
 
         # CPU offload
@@ -436,6 +441,7 @@ class ModelRunner:
             "triton",
             "flashmla",
             "cutlass_mla",
+            "trtllm_mla",
             "ascend",
         ]:
             logger.info(
@@ -671,7 +677,7 @@ class ModelRunner:
             self.sliding_window_size = self.model.get_attention_sliding_window_size()
         elif self.model_config.attention_chunk_size is not None:
             self.sliding_window_size = self.model_config.attention_chunk_size
-
+            logger.info(
                 f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
             )
 
@@ -1437,6 +1443,12 @@ class ModelRunner:
             )
 
             return CutlassMLABackend(self)
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not self.use_mla_backend:
+                raise ValueError("trtllm_mla backend can only be used with MLA models.")
+            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+            return TRTLLMMLABackend(self)
         elif self.server_args.attention_backend == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
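With trtllm_mla registered as an attention backend in ModelRunner, it can be selected the same way as the existing MLA backends (on the command line this corresponds to --attention-backend trtllm_mla). A hedged sketch using the offline Engine API: the model choice is illustrative, engine keyword arguments map onto ServerArgs fields, the backend requires an MLA model, and the TRT-LLM MLA kernels must be available.

import sglang as sgl

# Illustrative only: attention_backend="trtllm_mla" selects the backend added
# in this release; any DeepSeek-style MLA checkpoint would do here.
llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",
    attention_backend="trtllm_mla",
)
print(llm.generate("Hello", {"max_new_tokens": 16}))
llm.shutdown()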
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -229,6 +229,8 @@ def get_quant_config(
                 f"Unsupported quantization config"
                 f" found for {model_config.quantization} in {f}."
             )
+    elif model_config.quantization == "w8a8_int8":
+        config["packed_modules_mapping"] = packed_modules_mapping
 
     return quant_cls.from_config(config)
 
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -29,10 +29,14 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -59,9 +63,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import (
     DeepEPMoE,
     get_moe_impl_class,
-
+    should_use_flashinfer_trtllm_moe,
 )
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -96,7 +99,6 @@ from sglang.srt.two_batch_overlap import (
 )
 from sglang.srt.utils import (
     BumpAllocator,
-    DeepEPMode,
     LazyValue,
     add_prefix,
     bind_or_assign,
@@ -252,8 +254,7 @@ class MoEGate(nn.Module):
         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
-            and
-            and hidden_states.shape[0] < 4
+            and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
             and self.weight.shape[0] == 256
             and _device_sm >= 90
@@ -317,7 +318,7 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.gate.e_score_correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
-            if not
+            if not should_use_flashinfer_trtllm_moe()
             else None
         )
 
@@ -334,15 +335,14 @@ class DeepseekV2MoE(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
-                dict(deepep_mode=
-                if global_server_args_dict["
+                dict(deepep_mode=global_server_args_dict["deepep_mode"])
+                if global_server_args_dict["moe_a2a_backend"].is_deepep()
                 else {}
             ),
             # Additional args for FusedMoE
             **(
                 dict(
                     enable_flashinfer_cutlass_moe=True,
-                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
                 if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
@@ -352,11 +352,10 @@ class DeepseekV2MoE(nn.Module):
                 renormalize=config.norm_topk_prob,
                 use_grouped_topk=True,
                 num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
                 topk_group=config.topk_group,
                 correction_bias=self.gate.e_score_correction_bias,
             )
-            if
+            if should_use_flashinfer_trtllm_moe()
             else {}
         ),
     )
@@ -376,7 +375,7 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("shared_experts", prefix),
             **(
                 dict(tp_rank=0, tp_size=1)
-                if global_server_args_dict["
+                if global_server_args_dict["moe_a2a_backend"].is_deepep()
                 else {}
             ),
         )
@@ -406,9 +405,9 @@ class DeepseekV2MoE(nn.Module):
 
         self.top_k = config.num_experts_per_tok
 
-        if global_server_args_dict["
+        if global_server_args_dict["moe_a2a_backend"].is_deepep():
             # TODO: we will support tp < ep in the future
-            self.ep_size =
+            self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
                 + global_server_args_dict["ep_num_redundant_experts"]
@@ -430,12 +429,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=
+                deepep_mode=global_server_args_dict["deepep_mode"],
                 async_finish=True,
                 return_recv_hook=True,
             )
 
-        self._enable_deepep_moe = global_server_args_dict["
+        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
 
     def get_moe_weights(self):
         return [
@@ -485,7 +484,11 @@ class DeepseekV2MoE(nn.Module):
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
         if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -511,7 +514,11 @@ class DeepseekV2MoE(nn.Module):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
         if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -1259,6 +1266,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.current_attention_backend == "fa3"
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
+            or self.current_attention_backend == "trtllm_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -2105,11 +2113,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             or self.config.n_shared_experts != 1
         ):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif (
-
-            or global_server_args_dict["enable_ep_moe"]
-        ):
-            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True