sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py (new file)
@@ -0,0 +1,216 @@
+ #!/usr/bin/env python3
+
+ import os
+ import unittest
+ from typing import List, Optional
+ from unittest.mock import MagicMock
+
+ import torch
+
+ from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+ from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
+
+
+ class TestNixlUnified(unittest.TestCase):
+     """Unified test suite for all NIXL components."""
+
+     def setUp(self):
+         """Set up test environment."""
+         # Create test directories
+         self.test_dir = "/tmp/test_nixl_unified"
+         os.makedirs(self.test_dir, exist_ok=True)
+
+         # Mock NIXL agent for registration tests
+         self.mock_agent = MagicMock()
+         self.mock_agent.get_reg_descs.return_value = "mock_reg_descs"
+         self.mock_agent.register_memory.return_value = "mock_registered_memory"
+
+         # Create instances
+         self.file_manager = NixlFileManager(self.test_dir)
+         self.registration = NixlRegistration(self.mock_agent)
+         try:
+             self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
+         except ImportError:
+             self.skipTest("NIXL not available, skipping NIXL storage tests")
+
+     def tearDown(self):
+         """Clean up test directories."""
+         if os.path.exists(self.test_dir):
+             import shutil
+
+             shutil.rmtree(self.test_dir)
+
+     def delete_test_file(self, file_path: str) -> bool:
+         """Helper method to delete a test file.
+
+         Args:
+             file_path: Path to the file to delete
+
+         Returns:
+             bool: True if file was deleted or didn't exist, False on error
+         """
+         try:
+             if os.path.exists(file_path):
+                 os.remove(file_path)
+             return True
+         except Exception as e:
+             return False
+
+     def verify_tensors_equal(self, expected: torch.Tensor, actual: torch.Tensor):
+         """Helper to verify tensor equality."""
+         self.assertIsNotNone(actual, "Retrieved tensor is None")
+         self.assertTrue(
+             torch.allclose(expected, actual, atol=1e-6),
+             f"Tensors not equal:\nExpected: {expected}\nActual: {actual}",
+         )
+
+     def verify_tensor_lists_equal(
+         self, expected: List[torch.Tensor], actual: List[torch.Tensor]
+     ):
+         """Helper to verify lists of tensors are equal."""
+         self.assertEqual(len(expected), len(actual), "Lists have different lengths")
+         for exp, act in zip(expected, actual):
+             self.verify_tensors_equal(exp, act)
+
+     # ============================================================================
+     # HiCache Integration Tests
+     # ============================================================================
+
+     def test_single_set_get(self):
+         """Test single tensor set/get operations."""
+         key = "test_key"
+         value = torch.randn(10, 10, device="cpu")
+         dst_tensor = torch.zeros_like(value, device="cpu")
+
+         # Test set
+         self.assertTrue(self.hicache.set(key, value))
+         self.assertTrue(self.hicache.exists(key))
+
+         # Test get
+         retrieved = self.hicache.get(key, dst_tensor)
+         self.verify_tensors_equal(value, retrieved)
+
+     def test_batch_set_get(self):
+         """Test batch tensor set/get operations."""
+         keys = ["key1", "key2", "key3"]
+         values = [
+             torch.randn(5, 5, device="cpu"),
+             torch.randn(3, 3, device="cpu"),
+             torch.randn(7, 7, device="cpu"),
+         ]
+         dst_tensors = [torch.zeros_like(v, device="cpu") for v in values]
+
+         # Test batch set
+         self.assertTrue(self.hicache.batch_set(keys, values))
+         self.assertTrue(all(self.hicache.exists(key) for key in keys))
+
+         # Test batch get
+         retrieved = self.hicache.batch_get(keys, dst_tensors)
+         self.verify_tensor_lists_equal(values, retrieved)
+
+     def test_mixed_operations(self):
+         """Test mixing single and batch operations."""
+         # Test interleaved set/get operations
+         key1, key2 = "key1", "key2"
+         value1 = torch.randn(4, 4, device="cpu")
+         value2 = torch.randn(6, 6, device="cpu")
+         dst1 = torch.zeros_like(value1)
+         dst2 = torch.zeros_like(value2)
+
+         # Single set/get
+         self.assertTrue(self.hicache.set(key1, value1))
+         retrieved1 = self.hicache.get(key1, dst1)
+         self.verify_tensors_equal(value1, retrieved1)
+
+         # Batch set/get
+         self.assertTrue(self.hicache.batch_set([key2], [value2]))
+         retrieved2 = self.hicache.batch_get([key2], [dst2])
+         self.verify_tensors_equal(value2, retrieved2[0])
+
+     def test_data_integrity(self):
+         """Test data integrity across operations."""
+         # Test with various tensor types and sizes
+         test_cases = [
+             ("float32", torch.randn(10, 10, dtype=torch.float32)),
+             ("float64", torch.randn(5, 5, dtype=torch.float64)),
+             ("int32", torch.randint(-100, 100, (8, 8), dtype=torch.int32)),
+             ("int64", torch.randint(-100, 100, (6, 6), dtype=torch.int64)),
+             ("bool", torch.randint(0, 2, (4, 4)).bool()),
+         ]
+
+         for name, tensor in test_cases:
+             with self.subTest(tensor_type=name):
+                 key = f"test_{name}"
+                 dst_tensor = torch.zeros_like(tensor)
+
+                 # Set and immediately get
+                 self.assertTrue(self.hicache.set(key, tensor))
+                 retrieved1 = self.hicache.get(key, dst_tensor)
+                 self.verify_tensors_equal(tensor, retrieved1)
+
+                 # Get again to verify persistence
+                 dst_tensor.zero_()
+                 retrieved2 = self.hicache.get(key, dst_tensor)
+                 self.verify_tensors_equal(tensor, retrieved2)
+
+     def test_basic_file_operations(self):
+         """Test basic file operations."""
+         test_file = os.path.join(self.test_dir, "test_file.bin")
+         self.file_manager.create_file(test_file)
+         self.assertTrue(os.path.exists(test_file))
+         self.assertEqual(os.path.getsize(test_file), 0)  # Empty file
+
+         # Test file deletion
+         self.assertTrue(self.delete_test_file(test_file))
+         self.assertFalse(os.path.exists(test_file))
+
+     def test_create_nixl_tuples(self):
+         """Test creation of NIXL tuples."""
+         test_file = os.path.join(self.test_dir, "test_file.bin")
+         self.file_manager.create_file(test_file)
+
+         # Test tuple creation
+         tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
+         self.assertIsNotNone(tuples)
+         self.assertTrue(len(tuples) > 0)
+
+     def test_error_handling(self):
+         """Test error handling in file operations."""
+         # Test non-existent file
+         self.assertTrue(
+             self.delete_test_file("nonexistent_file.bin")
+         )  # Returns True if file doesn't exist
+
+         # Test invalid file path
+         self.assertFalse(self.file_manager.create_file(""))  # Empty path should fail
+
+     def test_register_buffers(self):
+         """Test registration of memory buffers."""
+         # Create test tensor
+         tensor = torch.randn(10, 10)
+
+         # Test buffer registration
+         self.assertIsNotNone(self.registration.register_buffers(tensor))
+
+         # Test batch registration
+         tensors = [torch.randn(5, 5) for _ in range(3)]
+         self.assertIsNotNone(self.registration.register_buffers(tensors))
+
+     def test_register_files_with_tuples(self):
+         """Test registration of files using NIXL tuples."""
+         files = [os.path.join(self.test_dir, f"test_file_{i}.bin") for i in range(3)]
+         for file in files:
+             self.file_manager.create_file(file)
+
+         # Create tuples and register
+         tuples = self.file_manager.files_to_nixl_tuples(files, False)
+         self.registration.register_files(tuples)
+
+         # Verify tuples
+         self.assertEqual(len(tuples), len(files))
+         for t, f in zip(tuples, files):
+             self.assertEqual(t[3], f)  # Check file path
+
+
+ if __name__ == "__main__":
+     unittest.main()
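
The new test above doubles as documentation for the NIXL-backed HiCache storage interface. The sketch below is illustrative only: it reuses exactly the constructor and methods exercised by the test (`HiCacheNixl(file_path=..., plugin="POSIX")`, `set`/`get`/`exists`, `batch_set`/`batch_get`) with the import path as written in the test, assumes NIXL and its POSIX plugin are installed, and is not part of the package.

```python
# Illustrative usage sketch based on the test above; not shipped in the wheel.
import torch

from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl  # path as used in the test

cache = HiCacheNixl(file_path="/tmp/hicache_nixl_demo", plugin="POSIX")

# Single-tensor round trip: set() stores, get() fills a caller-provided buffer.
value = torch.randn(10, 10)
cache.set("demo_key", value)
assert cache.exists("demo_key")
restored = cache.get("demo_key", torch.zeros_like(value))
assert torch.allclose(value, restored)

# Batched round trip over several keys at once.
keys = ["k0", "k1"]
values = [torch.randn(4, 4), torch.randn(8, 8)]
cache.batch_set(keys, values)
restored_batch = cache.batch_get(keys, [torch.zeros_like(v) for v in values])
```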
sglang/srt/model_executor/cuda_graph_runner.py
@@ -16,6 +16,7 @@
  from __future__ import annotations

  import bisect
+ import gc
  import inspect
  import logging
  import os
@@ -28,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile

  from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+     set_graph_pool_id,
+ )
  from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
  from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -75,6 +79,24 @@ def model_capture_mode():
      is_capture_mode = False


+ @contextmanager
+ def freeze_gc(enable_cudagraph_gc: bool):
+     """
+     Optimize garbage collection during CUDA graph capture.
+     Clean up, then freeze all remaining objects from being included
+     in future collections if GC is disabled during capture.
+     """
+     gc.collect()
+     should_freeze = not enable_cudagraph_gc
+     if should_freeze:
+         gc.freeze()
+     try:
+         yield
+     finally:
+         if should_freeze:
+             gc.unfreeze()
+
+
  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
      for sub in model._modules.values():
          if isinstance(sub, CustomOp):
@@ -353,6 +375,11 @@ class CudaGraphRunner:
              dtype=torch.bool,
              device="cuda",
          )
+         self.next_token_logits_buffer = torch.zeros(
+             (self.max_num_token, self.model_runner.model_config.vocab_size),
+             dtype=torch.float,
+             device="cuda",
+         )

          # Capture
          try:
@@ -423,7 +450,12 @@
              record_shapes=True,
          )

-         with graph_capture() as graph_capture_context:
+         # Trigger CUDA graph capture for specific shapes.
+         # Capture the large shapes first so that the smaller shapes
+         # can reuse the memory pool allocated for the large shapes.
+         with freeze_gc(
+             self.model_runner.server_args.enable_cudagraph_gc
+         ), graph_capture() as graph_capture_context:
              with profile_context as prof:
                  self.stream = graph_capture_context.stream
                  avail_mem = get_available_gpu_memory(
@@ -493,6 +525,7 @@
          else:
              encoder_lens = None
          mrope_positions = self.mrope_positions[:, :bs]
+         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
          self.num_token_non_padded[...] = num_tokens

          # pipeline parallelism
@@ -555,6 +588,7 @@
              input_ids=input_ids,
              req_pool_indices=req_pool_indices,
              seq_lens=seq_lens,
+             next_token_logits_buffer=next_token_logits_buffer,
              req_to_token_pool=self.model_runner.req_to_token_pool,
              token_to_kv_pool=self.model_runner.token_to_kv_pool,
              attn_backend=self.model_runner.attn_backend,
@@ -619,11 +653,15 @@

          run_once()

-         global global_graph_memory_pool
-         with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
+         if get_global_graph_memory_pool() is None:
+             set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+         # Set graph pool id globally to be able to use symmetric memory
+         set_graph_pool_id(get_global_graph_memory_pool())
+         with torch.cuda.graph(
+             graph, pool=get_global_graph_memory_pool(), stream=stream
+         ):
              out = run_once()

-         global_graph_memory_pool = graph.pool()
          return graph, out

      def recapture_if_needed(self, forward_batch: ForwardBatch):
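
The new `freeze_gc` helper captures a general CPython pattern: collect once up front, then (unless GC is explicitly kept enabled for capture) `gc.freeze()` the survivors so that any collection triggered during graph capture has almost nothing to scan, and `gc.unfreeze()` afterwards. A minimal standalone sketch of the same pattern, independent of sglang and of CUDA, is shown below; the timing-critical region is just a placeholder.

```python
# Standalone sketch of the GC-freeze pattern used by freeze_gc() above.
import gc
from contextlib import contextmanager


@contextmanager
def freeze_gc_sketch(keep_gc_enabled: bool):
    gc.collect()                 # clear out garbage before the critical region
    freeze = not keep_gc_enabled
    if freeze:
        gc.freeze()              # move survivors to the permanent generation
    try:
        yield
    finally:
        if freeze:
            gc.unfreeze()        # make those objects collectable again


with freeze_gc_sketch(keep_gc_enabled=False):
    # The CUDA graph capture loop would run here; objects frozen above are
    # ignored by any garbage collection that happens inside this block.
    pass
```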
sglang/srt/model_executor/forward_batch_info.py
@@ -38,6 +38,7 @@ import torch
  import triton
  import triton.language as tl

+ from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
  from sglang.srt.layers.dp_attention import (
      DPPaddingMode,
      get_attention_dp_rank,
@@ -188,6 +189,7 @@
      token_ids_logprobs: Optional[List[List[int]]] = None

      # For logits and logprobs post processing
+     next_token_logits_buffer: torch.Tensor = None
      temp_scaled_logprobs: bool = False
      temperature: torch.Tensor = None
      top_p_normalized_logprobs: bool = False
@@ -644,12 +646,17 @@
              device=model_runner.device,
          )

-         bs = self.batch_size
          if len(global_num_tokens) > 1:
              num_tokens = global_num_tokens[get_attention_dp_rank()]
          else:
              num_tokens = global_num_tokens[0]

+         if self.forward_mode.is_decode():
+             setattr(self, "raw_bs", self.batch_size)
+             self.batch_size = num_tokens
+
+         bs = self.batch_size
+
          # padding
          self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
          self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -657,6 +664,9 @@
          seq_len_fill_value = (
              model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
          )
+         self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+             bs - self.seq_lens.shape[0]
+         )
          self.seq_lens = self._pad_tensor_to_size(
              self.seq_lens, bs, value=seq_len_fill_value
          )
@@ -700,7 +710,7 @@

      def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):

-         bs = self.batch_size
+         bs = getattr(self, "raw_bs", self.batch_size)

          if self.spec_info is not None:
              if self.forward_mode.is_decode():  # draft
@@ -839,7 +849,7 @@


  def enable_num_token_non_padded(server_args):
-     return server_args.enable_ep_moe or server_args.enable_deepep_moe
+     return get_moe_expert_parallel_world_size() > 1


  class PPProxyTensors:
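
The decode-padding change above stashes the real batch size as `raw_bs`, grows `batch_size` to the DP-synchronized token count, and corrects `seq_lens_sum` for the padded rows before `seq_lens` itself is padded; `post_forward_mlp_sync_batch` later reads `raw_bs` back. A small numeric sketch of that bookkeeping follows; the values and the fill value of 1 are made up for illustration (the real fill value comes from the attention backend).

```python
# Numeric sketch of the seq_lens_sum bookkeeping above; values are illustrative.
import torch

seq_lens = torch.tensor([17, 5, 42])    # real decode batch: raw_bs = 3
raw_bs = seq_lens.shape[0]
bs = 8                                  # DP-synchronized num_tokens for this rank
seq_len_fill_value = 1                  # assumed here; backend-specific in practice

# Correct the running sum first, then pad the tensor to the new batch size.
seq_lens_sum = int(seq_lens.sum()) + seq_len_fill_value * (bs - seq_lens.shape[0])
seq_lens = torch.cat(
    [seq_lens, torch.full((bs - raw_bs,), seq_len_fill_value, dtype=seq_lens.dtype)]
)
assert int(seq_lens.sum()) == seq_lens_sum == 69   # 64 real tokens + 5 padded rows
```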
sglang/srt/model_executor/model_runner.py
@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
      initialize_dp_attention,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.layers.quantization import (
      deep_gemm_wrapper,
      monkey_patch_isinstance_for_vllm_base_layer,
@@ -217,6 +218,10 @@ class ModelRunner:
                  "use_mla_backend": self.use_mla_backend,
                  "speculative_algorithm": self.spec_algorithm,
              }
+             | {
+                 "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
+                 "deepep_mode": DeepEPMode(server_args.deepep_mode),
+             }
          )

          # CPU offload
@@ -436,6 +441,7 @@
              "triton",
              "flashmla",
              "cutlass_mla",
+             "trtllm_mla",
              "ascend",
          ]:
              logger.info(
@@ -671,7 +677,7 @@
              self.sliding_window_size = self.model.get_attention_sliding_window_size()
          elif self.model_config.attention_chunk_size is not None:
              self.sliding_window_size = self.model_config.attention_chunk_size
-             print(
+             logger.info(
                  f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
              )

@@ -1437,6 +1443,12 @@
              )

              return CutlassMLABackend(self)
+         elif self.server_args.attention_backend == "trtllm_mla":
+             if not self.use_mla_backend:
+                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
+             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+             return TRTLLMMLABackend(self)
          elif self.server_args.attention_backend == "intel_amx":
              from sglang.srt.layers.attention.intel_amx_backend import (
                  IntelAMXAttnBackend,
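
In `ModelRunner`, the raw `moe_a2a_backend` and `deepep_mode` server-arg strings are now coerced to enums before landing in the global server-args dict, which is what lets call sites (see `deepseek_v2.py` below) ask `moe_a2a_backend.is_deepep()` instead of checking boolean flags. The following mini-enum is a hypothetical illustration of that pattern only; the real `MoeA2ABackend` and `DeepEPMode` live in `sglang/srt/layers/moe/utils.py` and their members are not shown in this diff.

```python
# Hypothetical illustration of the string -> enum coercion; not the real enum.
from enum import Enum


class MoeA2ABackendSketch(Enum):
    NONE = "none"
    DEEPEP = "deepep"

    def is_deepep(self) -> bool:
        return self is MoeA2ABackendSketch.DEEPEP


# The CLI hands over a plain string; the enum constructor validates it and
# gives downstream code a typed value to query.
global_server_args = {"moe_a2a_backend": MoeA2ABackendSketch("deepep")}
assert global_server_args["moe_a2a_backend"].is_deepep()
```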
sglang/srt/model_loader/weight_utils.py
@@ -229,6 +229,8 @@ def get_quant_config(
                  f"Unsupported quantization config"
                  f" found for {model_config.quantization} in {f}."
              )
+     elif model_config.quantization == "w8a8_int8":
+         config["packed_modules_mapping"] = packed_modules_mapping

      return quant_cls.from_config(config)

sglang/srt/models/deepseek_v2.py
@@ -29,10 +29,14 @@ from tqdm import tqdm
  from transformers import PretrainedConfig

  from sglang.srt.distributed import (
+     get_moe_expert_parallel_world_size,
      get_tensor_model_parallel_world_size,
      parallel_state,
      tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+     use_symmetric_memory,
+ )
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
  from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
  from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -59,9 +63,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.ep_moe.layer import (
      DeepEPMoE,
      get_moe_impl_class,
-     use_flashinfer_trtllm_moe,
+     should_use_flashinfer_trtllm_moe,
  )
- from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -96,7 +99,6 @@ from sglang.srt.two_batch_overlap import (
  )
  from sglang.srt.utils import (
      BumpAllocator,
-     DeepEPMode,
      LazyValue,
      add_prefix,
      bind_or_assign,
@@ -252,8 +254,7 @@ class MoEGate(nn.Module):
          # NOTE: For some unknown reason, router_gemm seems degrade accept length.
          if (
              _is_cuda
-             and not self.is_nextn
-             and hidden_states.shape[0] < 4
+             and hidden_states.shape[0] <= 16
              and hidden_states.shape[1] == 7168
              and self.weight.shape[0] == 256
              and _device_sm >= 90
@@ -317,7 +318,7 @@ class DeepseekV2MoE(nn.Module):
                  correction_bias=self.gate.e_score_correction_bias,
                  routed_scaling_factor=self.routed_scaling_factor,
              )
-             if not use_flashinfer_trtllm_moe
+             if not should_use_flashinfer_trtllm_moe()
              else None
          )

@@ -334,15 +335,14 @@
              routed_scaling_factor=self.routed_scaling_factor,
              prefix=add_prefix("experts", prefix),
              **(
-                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                 if global_server_args_dict["enable_deepep_moe"]
+                 dict(deepep_mode=global_server_args_dict["deepep_mode"])
+                 if global_server_args_dict["moe_a2a_backend"].is_deepep()
                  else {}
              ),
              # Additional args for FusedMoE
              **(
                  dict(
                      enable_flashinfer_cutlass_moe=True,
-                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                  )
                  if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                  else {}
@@ -352,11 +352,10 @@
                  renormalize=config.norm_topk_prob,
                  use_grouped_topk=True,
                  num_expert_group=config.n_group,
-                 num_fused_shared_experts=self.num_fused_shared_experts,
                  topk_group=config.topk_group,
                  correction_bias=self.gate.e_score_correction_bias,
              )
-             if use_flashinfer_trtllm_moe
+             if should_use_flashinfer_trtllm_moe()
              else {}
          ),
      )
@@ -376,7 +375,7 @@
              prefix=add_prefix("shared_experts", prefix),
              **(
                  dict(tp_rank=0, tp_size=1)
-                 if global_server_args_dict["enable_deepep_moe"]
+                 if global_server_args_dict["moe_a2a_backend"].is_deepep()
                  else {}
              ),
          )
@@ -406,9 +405,9 @@

          self.top_k = config.num_experts_per_tok

-         if global_server_args_dict["enable_deepep_moe"]:
+         if global_server_args_dict["moe_a2a_backend"].is_deepep():
              # TODO: we will support tp < ep in the future
-             self.ep_size = get_tensor_model_parallel_world_size()
+             self.ep_size = get_moe_expert_parallel_world_size()
              self.num_experts = (
                  config.n_routed_experts
                  + global_server_args_dict["ep_num_redundant_experts"]
@@ -430,12 +429,12 @@
                  num_local_experts=config.n_routed_experts // self.tp_size,
                  hidden_size=config.hidden_size,
                  params_dtype=config.torch_dtype,
-                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                 deepep_mode=global_server_args_dict["deepep_mode"],
                  async_finish=True,
                  return_recv_hook=True,
              )

-         self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+         self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()

      def get_moe_weights(self):
          return [
@@ -485,7 +484,11 @@
          if not _is_cuda:
              final_hidden_states *= self.routed_scaling_factor
          current_stream.wait_stream(self.alt_stream)
-         final_hidden_states += shared_output
+         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+             final_hidden_states_out = torch.empty_like(final_hidden_states)
+             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+             final_hidden_states = final_hidden_states_out
+             sm.tag(final_hidden_states)
          if self.tp_size > 1 and not can_fuse_mlp_allreduce:
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
          return final_hidden_states
@@ -511,7 +514,11 @@
          # fused in biased_grouped_topk so we can skip here
          final_hidden_states *= self.routed_scaling_factor
          if shared_output is not None:
-             final_hidden_states = final_hidden_states + shared_output
+             with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                 final_hidden_states_out = torch.empty_like(final_hidden_states)
+                 torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                 final_hidden_states = final_hidden_states_out
+                 sm.tag(final_hidden_states)
          if self.tp_size > 1 and not can_fuse_mlp_allreduce:
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
          return final_hidden_states
@@ -1259,6 +1266,7 @@ class DeepseekV2AttentionMLA(nn.Module):
              self.current_attention_backend == "fa3"
              or self.current_attention_backend == "flashinfer"
              or self.current_attention_backend == "cutlass_mla"
+             or self.current_attention_backend == "trtllm_mla"
          ):
              attn_output = self.attn_mqa(
                  q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -2105,11 +2113,8 @@ class DeepseekV2ForCausalLM(nn.Module):
              or self.config.n_shared_experts != 1
          ):
              disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-         elif (
-             global_server_args_dict["enable_deepep_moe"]
-             or global_server_args_dict["enable_ep_moe"]
-         ):
-             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+         elif get_moe_expert_parallel_world_size() > 1:
+             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."

          if disable_reason is not None:
              global_server_args_dict["disable_shared_experts_fusion"] = True
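
The two `forward_*` changes above replace the in-place `+=` of the shared-expert output with an out-of-place `torch.add` whose destination is allocated inside `use_symmetric_memory(...)` and then tagged, so the following tensor-parallel all-reduce can operate on symmetric memory. Below is a condensed sketch of that pattern using the names from the diff; whether symmetric memory is actually engaged depends on the pynccl allocator configuration, which is outside this diff.

```python
# Condensed sketch of the symmetric-memory add introduced above.
import torch

from sglang.srt.distributed import parallel_state, tensor_model_parallel_all_reduce
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def add_shared_output(final_hidden_states, shared_output, tp_size, can_fuse_mlp_allreduce=False):
    with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
        # Out-of-place add: the result buffer is allocated under the symmetric
        # memory context instead of reusing the routed-expert output in place.
        out = torch.empty_like(final_hidden_states)
        torch.add(final_hidden_states, shared_output, out=out)
        sm.tag(out)  # mark the buffer for symmetric-memory collectives
    if tp_size > 1 and not can_fuse_mlp_allreduce:
        out = tensor_model_parallel_all_reduce(out)
    return out
```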