sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
@@ -4,13 +4,14 @@ import dataclasses
  import functools
  import math
  from functools import lru_cache, partial
- from typing import Any, Optional, Tuple, Union
+ from typing import Any, Callable, Optional, Tuple, Union

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange

+ from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
  from sglang.srt.utils import is_cuda, print_info_once

  _is_cuda = is_cuda()
@@ -308,6 +309,7 @@ class VisionFlash3Attention(nn.Module):
  cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
  seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
  max_seqlen = seq_lens.max().item()
+
  output = flash_attn_varlen_func(
  q,
  k,
@@ -358,22 +360,26 @@ class VisionAttention(nn.Module):
  qkv_bias: bool = True,
  qk_normalization: bool = False,
  layer_norm_eps: float = 1e-06,
+ customized_position_embedding_applier: Callable[
+ [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+ ] = None,
  **kwargs,
  ):
  super().__init__()
- world_size = parallel_state.get_tensor_model_parallel_world_size()
- self.tp_size = world_size
- self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+ attn_tp_rank = get_attention_tp_rank()
+ attn_tp_size = get_attention_tp_size()
+ self.tp_size = attn_tp_size
+ self.tp_rank = attn_tp_rank
  self.dropout = dropout
  self.head_size = embed_dim // num_heads
  self.hidden_size_per_attention_head = dist_utils.divide(
  projection_size, num_heads
  )
  self.num_attention_heads_per_partition = dist_utils.divide(
- num_dummy_heads + num_heads, world_size
+ num_dummy_heads + num_heads, self.tp_size
  )
  self.num_attention_kv_heads_per_partition = dist_utils.divide(
- num_dummy_heads + num_heads, world_size
+ num_dummy_heads + num_heads, self.tp_size
  )

  self.q_size = self.num_attention_heads_per_partition * self.head_size
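The constructor now partitions heads over the attention-TP group instead of the global TP group. The sketch below is a stand-alone illustration of the assumed relationship between the two groups; the sizes and the contiguous-rank mapping are illustrative assumptions, not taken from this diff.

    # Illustrative sketch (assumed convention): with data-parallel attention,
    # the global TP group is partitioned into contiguous attention-TP groups,
    # one per attention-DP rank, so VisionAttention shards its heads over the
    # smaller attention-TP group.
    tp_size, dp_size = 8, 4              # hypothetical deployment
    attn_tp_size = tp_size // dp_size    # 2 ranks share the attention heads
    for tp_rank in range(tp_size):
        attn_dp_rank = tp_rank // attn_tp_size
        attn_tp_rank = tp_rank % attn_tp_size
        print(f"tp_rank={tp_rank} -> attn_dp_rank={attn_dp_rank}, attn_tp_rank={attn_tp_rank}")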
@@ -392,6 +398,7 @@ class VisionAttention(nn.Module):
  self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
  )

+ # priority: server_args > passed qkv_backend > sdpa
  if global_server_args_dict["mm_attention_backend"] is None:
  if qkv_backend is None:
  qkv_backend = "sdpa"
@@ -401,6 +408,9 @@

  print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

+ self.customized_position_embedding_applier = (
+ customized_position_embedding_applier
+ )
  self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
  head_dim=self.head_size,
  num_heads=self.num_attention_heads_per_partition,
@@ -419,6 +429,8 @@
  total_num_kv_heads=num_dummy_heads + num_heads,
  bias=qkv_bias,
  quant_config=quant_config,
+ tp_rank=self.tp_rank,
+ tp_size=self.tp_size,
  prefix=add_prefix("qkv_proj", prefix),
  )
  else:
@@ -427,6 +439,8 @@
  output_size=3 * self.dummy_dim,
  bias=qkv_bias,
  quant_config=quant_config,
+ tp_rank=self.tp_rank,
+ tp_size=self.tp_size,
  prefix=add_prefix("qkv_proj", prefix),
  )
  self.proj = RowParallelLinear(
@@ -434,6 +448,8 @@
  output_size=embed_dim,
  bias=proj_bias,
  quant_config=quant_config,
+ tp_rank=self.tp_rank,
+ tp_size=self.tp_size,
  prefix=add_prefix("proj", prefix),
  )

@@ -473,13 +489,13 @@
  if x.dim() == 2:
  x = x.unsqueeze(0)
  assert x.dim() == 3, x.shape
- bsz, s, _ = x.shape
+ x_shape = x.shape
+ bsz, s, _ = x_shape
  head = self.num_attention_heads_per_partition
  kv_head = self.num_attention_kv_heads_per_partition
  if self.use_qkv_parallel:
  # [b, s, embed_dim] --> [b, s, embed_dim]
  qkv, _ = self.qkv_proj(x)
-
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

  # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +524,25 @@
  ]

  if position_embeddings is not None:
- cos, sin = position_embeddings
  original_shape = q.shape
- # [total_tokens, head, head_size]
- q = q.view(-1, head, self.head_size)
- k = k.view(-1, head, self.head_size)

- q, k = apply_rotary_pos_emb(q, k, cos, sin)
+ if self.customized_position_embedding_applier is not None:
+ q, k = self.customized_position_embedding_applier(
+ q, k, position_embeddings, x_shape
+ )
+ q = q.view(original_shape)
+ k = k.view(original_shape)
+ else:
+ cos, sin = position_embeddings
+
+ # [total_tokens, head, head_size]
+ q = q.view(-1, head, self.head_size)
+ k = k.view(-1, head, self.head_size)
+
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)

- q = q.view(original_shape)
- k = k.view(original_shape)
+ q = q.view(original_shape)
+ k = k.view(original_shape)

  if q.dim() == 4:
  # [b, s, head, head_size] --> [b * s, head, head_size]
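The new customized_position_embedding_applier hook lets a model supply its own position-embedding routine with the signature (q, k, position_embeddings, x_shape) -> (q, k). Below is a minimal, hypothetical example of such a callable; the function name and the rotate-half math are illustrative, not part of sglang.

    # Hypothetical example of a customized_position_embedding_applier.
    from typing import Any, Tuple

    import torch


    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        t1, t2 = t.chunk(2, dim=-1)
        return torch.cat((-t2, t1), dim=-1)


    def example_rope_applier(
        q: torch.Tensor,
        k: torch.Tensor,
        position_embeddings: Any,   # here: a (cos, sin) pair prepared by the model
        x_shape: Any,               # original [b, s, hidden] shape of the layer input
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos, sin = position_embeddings
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        return q, k


    # Identity rotation (cos=1, sin=0), just to show the call shape:
    q = torch.randn(16, 8, 64)
    k = torch.randn(16, 8, 64)
    cos, sin = torch.ones(16, 1, 64), torch.zeros(16, 1, 64)
    q2, k2 = example_rope_applier(q, k, (cos, sin), (2, 8, 512))
    # VisionAttention(..., customized_position_embedding_applier=example_rope_applier)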
sglang/srt/layers/communicator.py
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
  attn_tp_all_gather_into_tensor,
  attn_tp_reduce_scatter_tensor,
  dp_gather_partial,
+ dp_reduce_scatter_tensor,
  dp_scatter,
  get_attention_dp_size,
  get_attention_tp_rank,
@@ -108,7 +109,7 @@ class LayerScatterModes:
  if context.is_layer_sparse:
  return (
  ScatterMode.SCATTERED
- if global_server_args_dict["enable_deepep_moe"]
+ if not global_server_args_dict["moe_a2a_backend"].is_standard()
  else ScatterMode.FULL
  )
  else:
@@ -149,10 +150,13 @@ class LayerCommunicator:
  layer_scatter_modes: LayerScatterModes,
  input_layernorm: torch.nn.Module,
  post_attention_layernorm: torch.nn.Module,
+ # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
+ allow_reduce_scatter: bool = False,
  ):
  self.layer_scatter_modes = layer_scatter_modes
  self.input_layernorm = input_layernorm
  self.post_attention_layernorm = post_attention_layernorm
+ self.allow_reduce_scatter = allow_reduce_scatter

  self._context = CommunicateContext.init_new()
  self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@
  residual=residual,
  forward_batch=forward_batch,
  context=self._context,
+ allow_reduce_scatter=self.allow_reduce_scatter,
+ )
+
+ def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+ return (
+ self.allow_reduce_scatter
+ and self._communicate_summable_tensor_pair_fn
+ is CommunicateSummableTensorPairFn._scatter_hidden_states
+ and forward_batch.dp_padding_mode.is_max_len()
  )


@@ -404,14 +417,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
  if context.attn_dp_size != 1:
  if context.attn_tp_rank == 0:
  hidden_states += residual
+
+ # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+ use_layer_norm_before_gather = context.attn_tp_size == 1
+ if use_layer_norm_before_gather:
+ residual.copy_(hidden_states)
+ if hidden_states.shape[0] != 0:
+ hidden_states = layernorm(hidden_states)
+
  hidden_states, local_hidden_states = (
  forward_batch.gathered_buffer,
  hidden_states,
  )
  dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
- dp_scatter(residual, hidden_states, forward_batch)
- if hidden_states.shape[0] != 0:
- hidden_states = layernorm(hidden_states)
+
+ if not use_layer_norm_before_gather:
+ dp_scatter(residual, hidden_states, forward_batch)
+ if hidden_states.shape[0] != 0:
+ hidden_states = layernorm(hidden_states)
  else:
  # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
  # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
@@ -514,6 +537,7 @@ class CommunicateSummableTensorPairFn:
  residual: torch.Tensor,
  forward_batch: ForwardBatch,
  context: CommunicateContext,
+ **kwargs,
  ):
  return hidden_states, residual

@@ -523,15 +547,17 @@
  residual: torch.Tensor,
  forward_batch: ForwardBatch,
  context: CommunicateContext,
+ allow_reduce_scatter: bool = False,
  ):
- # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
- # important: forward batch.gathered_buffer is used both after scatter and after gather.
- # be careful about this!
  hidden_states, global_hidden_states = (
  forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
  hidden_states,
  )
- dp_scatter(hidden_states, global_hidden_states, forward_batch)
+ if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+ # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
+ dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+ else:
+ dp_scatter(hidden_states, global_hidden_states, forward_batch)
  return hidden_states, residual

  @staticmethod
@@ -540,6 +566,7 @@
  residual: torch.Tensor,
  forward_batch: ForwardBatch,
  context: CommunicateContext,
+ **kwargs,
  ):
  hidden_states += residual
  residual = None
sglang/srt/layers/dp_attention.py
@@ -12,6 +12,7 @@ import triton.language as tl

  from sglang.srt.distributed import (
  GroupCoordinator,
+ get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  get_tp_group,
  tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
  )


+ def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+ if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+ get_tp_group().reduce_scatter_tensor(output, input)
+ else:
+ scattered_local_tokens = input.tensor_split(
+ get_tensor_model_parallel_world_size()
+ )[get_tensor_model_parallel_rank()]
+ get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
+ get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
+
+
  def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
  return get_attention_tp_group().reduce_scatter_tensor(output, input)
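The fallback branch of dp_reduce_scatter_tensor reduce-scatters over the full TP group and then all-gathers within the attention-TP group. The single-process sketch below checks that this matches an all-reduce followed by a per-attention-DP-rank scatter, under the assumption of contiguous attention-TP groups and illustrative group sizes.

    # Single-process sketch (assumptions: 4 TP ranks, attention-TP size 2,
    # attention-TP groups made of contiguous TP ranks).
    import torch

    tp_size, attn_tp_size = 4, 2
    dp_size = tp_size // attn_tp_size

    partials = [torch.randn(8, 16) for _ in range(tp_size)]  # per-rank partial sums
    reduced = torch.stack(partials).sum(dim=0)               # what an all-reduce would yield

    # Fallback path, emulated for TP ranks 0..1 (attention-TP group 0 = attention-DP rank 0):
    tp_shards = reduced.tensor_split(tp_size)                     # reduce-scatter result per TP rank
    dp0_via_reduce_scatter = torch.cat(tp_shards[:attn_tp_size])  # all-gather inside attn-TP group 0

    # Reference path: all-reduce, then one slice per attention-DP rank
    dp0_via_scatter = reduced.tensor_split(dp_size)[0]

    assert torch.allclose(dp0_via_reduce_scatter, dp0_via_scatter)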
sglang/srt/layers/linear.py
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
  divide,
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
+ parallel_state,
  split_tensor_along_last_dim,
  tensor_model_parallel_all_gather,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+ use_symmetric_memory,
+ )
  from sglang.srt.layers.parameter import (
  BasevLLMParameter,
  BlockQuantScaleParameter,
@@ -1187,11 +1191,6 @@ class RowParallelLinear(LinearBase):
  else self.weight_loader
  ),
  )
- if not reduce_results and (bias and not skip_bias_add):
- raise ValueError(
- "When not reduce the results, adding bias to the "
- "results can lead to incorrect results"
- )

  if bias:
  self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
@@ -1278,7 +1277,7 @@
  # It does not support additional parameters.
  param.load_row_parallel_weight(loaded_weight)

- def forward(self, input_, can_fuse_mlp_allreduce=False):
+ def forward(self, input_, skip_all_reduce=False):
  if self.input_is_parallel:
  input_parallel = input_
  else:
@@ -1292,8 +1291,10 @@
  # Only fuse bias add into GEMM for rank 0 (this ensures that
  # bias will not get added more than once in TP>1 case)
  bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
- output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
- if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+ output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+ sm.tag(output_parallel)
+ if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
  output = tensor_model_parallel_all_reduce(output_parallel)
  else:
  output = output_parallel
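RowParallelLinear now allocates the GEMM output inside the symmetric-memory context so that the following all-reduce can use NCCL's window-registered buffers. Below is a minimal sketch of the same pattern for a generic row-parallel matmul, assuming an initialized sglang TP group and that use_symmetric_memory degrades to a no-op when the allocator is disabled; it reuses only the names visible in this diff.

    # Sketch of the allocation pattern above; not the library's implementation.
    import torch

    from sglang.srt.distributed import parallel_state, tensor_model_parallel_all_reduce
    from sglang.srt.distributed.device_communicators.pynccl_allocator import (
        use_symmetric_memory,
    )


    def row_parallel_matmul_allreduce(x: torch.Tensor, weight_shard: torch.Tensor) -> torch.Tensor:
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            partial = x @ weight_shard.t()   # per-rank partial result, allocated inside the context
            sm.tag(partial)                  # mark it as symmetric-memory backed
        return tensor_model_parallel_all_reduce(partial)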
sglang/srt/layers/logits_processor.py
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
  class LogitsMetadata:
  forward_mode: ForwardMode
  capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+ next_token_logits_buffer: Optional[torch.Tensor] = None

  extend_return_logprob: bool = False
  extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@
  return cls(
  forward_mode=forward_batch.forward_mode,
  capture_hidden_mode=forward_batch.capture_hidden_mode,
+ next_token_logits_buffer=forward_batch.next_token_logits_buffer,
  extend_return_logprob=extend_return_logprob,
  extend_return_top_logprob=extend_return_top_logprob,
  extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
  )
  dp_scatter(logits, global_logits, logits_metadata)

- logits = logits[:, : self.config.vocab_size].float()
+ if logits_metadata.next_token_logits_buffer is not None:
+ logits_buffer = logits_metadata.next_token_logits_buffer
+ assert logits_buffer.dtype == torch.float
+ logits_buffer.copy_(logits[:, : self.config.vocab_size])
+ logits = logits_buffer
+ else:
+ logits = logits[:, : self.config.vocab_size].float()

  if self.final_logit_softcapping:
  fused_softcap(logits, self.final_logit_softcapping)
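The new path copies the vocabulary slice into a preallocated float32 buffer instead of materializing a fresh tensor with .float() on every step. A stand-alone sketch of that buffer reuse, with illustrative shapes:

    import torch

    vocab_size, padded_vocab, batch = 32000, 32256, 8
    logits = torch.randn(batch, padded_vocab, dtype=torch.bfloat16)
    next_token_logits_buffer = torch.empty(batch, vocab_size, dtype=torch.float)

    next_token_logits_buffer.copy_(logits[:, :vocab_size])   # in-place cast, no new allocation
    reference = logits[:, :vocab_size].float()               # old path: fresh tensor each step
    assert torch.allclose(next_token_logits_buffer, reference)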
sglang/srt/layers/moe/cutlass_moe.py
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
  import torch

  from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+ from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.utils import is_cuda

  _is_cuda = is_cuda()
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(

  if is_cuda:
  from sglang.srt.layers.quantization.fp8_kernel import (
+ per_token_group_quant_fp8_hopper_moe_mn_major,
  sglang_per_token_group_quant_fp8,
  )

@@ -133,9 +135,7 @@
  n = w2_q.size(1)

  topk = topk_ids.size(1)
-
- a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
- device = a_q.device
+ device = a.device

  a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
  c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,8 +152,16 @@
  k,
  )

- rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
- rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+ if is_sm100_supported():
+ a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+ rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+ rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+ else:
+ rep_a = shuffle_rows(a, a_map, (m * topk, k))
+ rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
+ rep_a, expert_offsets, problem_sizes1, 128
+ )
+ w1_scale = w1_scale.contiguous()

  c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
  c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -185,7 +193,13 @@
  intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
  silu_and_mul(c1, intermediate)

- intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+ if is_sm100_supported():
+ intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+ else:
+ intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
+ intermediate, expert_offsets, problem_sizes2, 128
+ )
+ w2_scale = w2_scale.contiguous()

  fp8_blockwise_scaled_grouped_mm(
  c2,
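Both branches above perform per-token-group FP8 quantization with group size 128; only the kernel (SM100 vs. the Hopper MN-major variant) and the resulting scale layout differ. The sketch below is a plain-PyTorch semantic reference for that quantization step, assuming the common amax/448 scale convention and a PyTorch build with float8 dtypes; it is not the actual kernel.

    # Reference sketch of per-token-group FP8 quantization (group size 128).
    import torch

    def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
        m, k = x.shape
        assert k % group_size == 0
        fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448 for e4m3fn
        groups = x.view(m, k // group_size, group_size).float()
        scales = groups.abs().amax(dim=-1, keepdim=True) / fp8_max   # one scale per group
        q = (groups / scales.clamp(min=1e-12)).to(torch.float8_e4m3fn)
        return q.view(m, k), scales.squeeze(-1)

    x = torch.randn(4, 256)
    x_q, x_scales = per_token_group_quant_fp8_ref(x)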