sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +12 -1
  3. sglang/srt/conversation.py +35 -1
  4. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  5. sglang/srt/entrypoints/http_server_engine.py +1 -1
  6. sglang/srt/layers/communicator.py +3 -1
  7. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  8. sglang/srt/layers/layernorm.py +2 -2
  9. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  10. sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
  11. sglang/srt/layers/moe/ep_moe/layer.py +140 -2
  12. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  13. sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
  14. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  15. sglang/srt/layers/quantization/__init__.py +2 -0
  16. sglang/srt/layers/quantization/fp8.py +28 -7
  17. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  18. sglang/srt/layers/quantization/w4afp8.py +264 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  20. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  21. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  22. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  23. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  24. sglang/srt/managers/cache_controller.py +41 -195
  25. sglang/srt/managers/io_struct.py +8 -1
  26. sglang/srt/managers/mm_utils.py +4 -2
  27. sglang/srt/managers/schedule_batch.py +1 -1
  28. sglang/srt/managers/scheduler.py +17 -5
  29. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  30. sglang/srt/mem_cache/memory_pool.py +113 -63
  31. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  32. sglang/srt/mem_cache/radix_cache.py +8 -4
  33. sglang/srt/models/deepseek_v2.py +16 -2
  34. sglang/srt/models/mllama4.py +360 -79
  35. sglang/srt/multimodal/mm_utils.py +2 -2
  36. sglang/srt/multimodal/processors/mllama4.py +62 -60
  37. sglang/srt/server_args.py +15 -0
  38. sglang/srt/two_batch_overlap.py +3 -0
  39. sglang/srt/utils.py +37 -17
  40. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  41. sglang/utils.py +5 -5
  42. sglang/version.py +1 -1
  43. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
  44. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
  45. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  46. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  47. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/mem_cache/memory_pool.py
+++ b/sglang/srt/mem_cache/memory_pool.py
@@ -34,10 +34,11 @@ import torch
 import torch.distributed as dist
 import triton
 import triton.language as tl
+from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)

@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    def transfer(self, indices, flat_data):
+    @abc.abstractmethod
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
         raise NotImplementedError()

-    def transfer_per_layer(self, indices, flat_data, layer_id):
+    @abc.abstractmethod
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
         raise NotImplementedError()

     def register_layer_transfer_counter(self, layer_transfer_counter):
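Note: with this interface, a host-backed (HiCache) pool is filled one layer at a time and written back for all layers in a single call. Below is a minimal caller sketch; the controller-side names (device_pool, host_pool, the index tensors, the io_backend string) are illustrative assumptions, not code from this release.

def load_sequence_from_host(device_pool, host_pool, host_indices, device_indices, layer_num, io_backend):
    # Layer-by-layer load: each layer becomes usable as soon as its copy lands,
    # which is what the layer_transfer_counter wait in get_key_buffer synchronizes on.
    for layer_id in range(layer_num):
        device_pool.load_from_host_per_layer(
            host_pool, host_indices, device_indices, layer_id, io_backend
        )

def backup_sequence_to_host(device_pool, host_pool, host_indices, device_indices, io_backend):
    # Write-back is a single call; the device pool iterates over its own layers internally.
    device_pool.backup_to_host_all_layer(
        host_pool, host_indices, device_indices, io_backend
    )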
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
-
+        self.token_stride = self.head_num * self.head_dim
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
             dtype=torch.uint64,
@@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr()
+            self._get_key_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).data_ptr()
+            self._get_value_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes
+            self._get_key_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).nbytes
+            self._get_value_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes * self.page_size
+            self._get_key_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i)[0].nbytes * self.page_size
+            self._get_value_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
                 self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-    # Todo: different memory layout
-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        flatten = torch.stack(
-            [
-                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
-                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
-            ]
+    def load_from_host_per_layer(
+        self,
+        host_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        transfer_kv_per_layer(
+            src_k=host_pool.k_buffer[layer_id],
+            dst_k=self.k_buffer[layer_id],
+            src_v=host_pool.v_buffer[layer_id],
+            dst_v=self.v_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
         )
-        return flatten
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        for i in range(self.layer_num):
-            self.k_buffer[i][indices] = k_data[i]
-            self.v_buffer[i][indices] = v_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id - self.start_layer][indices] = k_data
-        self.v_buffer[layer_id - self.start_layer][indices] = v_data

-    def get_key_buffer(self, layer_id: int):
-        if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer(
+                src_k=self.k_buffer[layer_id],
+                dst_k=host_pool.k_buffer[layer_id],
+                src_v=self.v_buffer[layer_id],
+                dst_v=host_pool.v_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

+    def _get_key_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.k_buffer[layer_id - self.start_layer]

-    def get_value_buffer(self, layer_id: int):
+    def get_key_buffer(self, layer_id: int):
+        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
+        # it is supposed to be used only by attention backend not for information purpose
+        # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

+        return self._get_key_buffer(layer_id)
+
+    def _get_value_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]

+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_value_buffer(layer_id)
+
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

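For intuition, the per-layer transfer above gathers the rows selected by src_indices from one layer's K and V buffers and scatters them to dst_indices in the destination buffers. A rough PyTorch-level equivalent is sketched below; it is only an approximation, since the real transfer_kv_per_layer in sgl_kernel.kvcacheio is a fused kernel that also honors the io_backend, page_size, and item_size arguments.

import torch

def transfer_kv_per_layer_naive(src_k, dst_k, src_v, dst_v, src_indices, dst_indices):
    # Gather the selected token rows from the source layer and scatter them
    # into the destination layer (direction depends on which pool is src/dst).
    dst_k[dst_indices] = src_k[src_indices].to(dst_k.device, non_blocking=True)
    dst_v[dst_indices] = src_v[src_indices].to(dst_v.device, non_blocking=True)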
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]

+        self.token_stride = kv_lora_rank + qk_rope_head_dim
         self.layer_transfer_counter = None

         kv_size = self.get_kv_size_bytes()
@@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        for i in range(self.layer_num):
-            self.kv_buffer[i][indices] = flat_data[i]
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer_mla(
+            src=host_pool.kv_buffer[layer_id],
+            dst=self.kv_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )

-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer_mla(
+                src=self.kv_buffer[layer_id],
+                dst=host_pool.kv_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def get_flat_data(self, indices):
-        pass
-
-    def transfer(self, indices, flat_data):
-        pass
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )

-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        pass
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )


 @triton.jit
--- a/sglang/srt/mem_cache/memory_pool_host.py
+++ b/sglang/srt/mem_cache/memory_pool_host.py
@@ -8,7 +8,6 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import debug_timing

 logger = logging.getLogger(__name__)

@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )

-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]

-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]


 class MLATokenToKVPoolHost(HostKVCache):
  class MLATokenToKVPoolHost(HostKVCache):
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
337
275
  device=self.device,
338
276
  pin_memory=self.pin_memory,
339
277
  )
340
-
341
- @debug_timing
342
- def transfer(self, indices, flat_data):
343
- # backup prepared data from device to host
344
- self.kv_buffer[:, indices] = flat_data.to(
345
- device=self.device, non_blocking=False
346
- )
347
-
348
- def get_flat_data(self, indices):
349
- return self.kv_buffer[:, indices]
350
-
351
- def get_flat_data_by_layer(self, indices, layer_id):
352
- return self.kv_buffer[layer_id - self.start_layer, indices]
353
-
354
- def assign_flat_data(self, indices, flat_data):
355
- self.kv_buffer[:, indices] = flat_data
356
-
357
- def write_page_all_layers(self, host_indices, device_indices, device_pool):
358
- device_indices_cpu = device_indices[:: self.page_size].cpu()
359
- for i in range(len(device_indices_cpu)):
360
- h_index = host_indices[i * self.page_size]
361
- d_index = device_indices_cpu[i]
362
- for j in range(self.layer_num):
363
- self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
364
- device_pool.kv_buffer[j][d_index : d_index + self.page_size],
365
- non_blocking=True,
366
- )
367
-
368
- def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
369
- device_indices_cpu = device_indices[:: self.page_size].cpu()
370
- for i in range(len(device_indices_cpu)):
371
- h_index = host_indices[i * self.page_size]
372
- d_index = device_indices_cpu[i]
373
- device_pool.kv_buffer[layer_id - self.start_layer][
374
- d_index : d_index + self.page_size
375
- ].copy_(
376
- self.kv_buffer[
377
- layer_id - self.start_layer, h_index : h_index + self.page_size
378
- ],
379
- non_blocking=True,
380
- )
--- a/sglang/srt/mem_cache/radix_cache.py
+++ b/sglang/srt/mem_cache/radix_cache.py
@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):

         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(
@@ -226,10 +228,12 @@

         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
         page_aligned_token_ids = token_ids[:page_aligned_len]

         # Radix Cache takes one ref in memory pool
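Both call sites now normalize the page-aligned indices to int64 while copying: clone() preserved whatever integer dtype the indices arrived with, whereas .to(dtype=..., copy=True) casts and copies in one step, so downstream consumers of the radix tree see a consistent index dtype. A minimal illustration (the int32 starting dtype is just an example):

import torch

kv_indices = torch.arange(8, dtype=torch.int32)

a = kv_indices.clone()                           # copy, dtype stays int32
b = kv_indices.to(dtype=torch.int64, copy=True)  # copy and cast to int64

assert a.dtype == torch.int32 and b.dtype == torch.int64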
--- a/sglang/srt/models/deepseek_v2.py
+++ b/sglang/srt/models/deepseek_v2.py
@@ -210,8 +210,10 @@ class MoEGate(nn.Module):
         self,
         config,
         prefix: str = "",
+        is_nextn: bool = False,
     ):
         super().__init__()
+        self.is_nextn = is_nextn
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
@@ -233,8 +235,10 @@
                 True,  # is_vnni
             )

+        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
+            and not self.is_nextn
             and hidden_states.shape[0] < 4
             and hidden_states.shape[1] == 7168
             and self.weight.shape[0] == 256
@@ -258,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        is_nextn: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -284,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

-        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
+        self.gate = MoEGate(
+            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+        )

         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
@@ -1776,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
+                is_nextn=is_nextn,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1930,7 +1938,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            use_attn_tp_group=True,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -2355,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        if self.quant_config and self.quant_config.get_name() == "w4afp8":
+            expert_params_mapping += (
+                get_moe_impl_class().make_expert_input_scale_params_mapping(
+                    num_experts=self.config.n_routed_experts
+                )
+            )

         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (