sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py

@@ -4,13 +4,14 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import is_cuda, print_info_once
 
 _is_cuda = is_cuda()
@@ -308,6 +309,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,22 +360,26 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
-
-
-        self.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        self.tp_size = attn_tp_size
+        self.tp_rank = attn_tp_rank
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads,
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads,
+            num_dummy_heads + num_heads, self.tp_size
         )
 
         self.q_size = self.num_attention_heads_per_partition * self.head_size
@@ -392,6 +398,7 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )
 
+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +408,9 @@ class VisionAttention(nn.Module):
 
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -419,6 +429,8 @@ class VisionAttention(nn.Module):
                 total_num_kv_heads=num_dummy_heads + num_heads,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
@@ -427,6 +439,8 @@ class VisionAttention(nn.Module):
                 output_size=3 * self.dummy_dim,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
@@ -434,6 +448,8 @@ class VisionAttention(nn.Module):
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
             prefix=add_prefix("proj", prefix),
         )
 
@@ -473,13 +489,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
        assert x.dim() == 3, x.shape
-
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
-
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +524,25 @@ class VisionAttention(nn.Module):
             ]
 
         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
 
-
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
-
-
+                q = q.view(original_shape)
+                k = k.view(original_shape)
 
         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]

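The new `customized_position_embedding_applier` hook lets a vision model plug in its own position-embedding logic: the callable receives `(q, k, position_embeddings, x_shape)` and must return the transformed `(q, k)` pair. Below is a minimal sketch of such a callable, assuming `position_embeddings` is a broadcastable `(cos, sin)` pair; the function names are illustrative, not part of the package.

```python
from typing import Any, Tuple

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotate-half helper used by rotary embeddings.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def my_rope_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,  # assumed here to be a broadcastable (cos, sin) pair
    x_shape: Any,              # original (bsz, seq, hidden) shape; unused in this sketch
) -> Tuple[torch.Tensor, torch.Tensor]:
    cos, sin = position_embeddings
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k


# A model would then construct the layer roughly as:
# VisionAttention(..., customized_position_embedding_applier=my_rope_applier)
```
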
sglang/srt/layers/communicator.py

@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather_into_tensor,
     attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
+    dp_reduce_scatter_tensor,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -108,7 +109,7 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if global_server_args_dict["
+                if not global_server_args_dict["moe_a2a_backend"].is_standard()
                 else ScatterMode.FULL
             )
         else:
@@ -149,10 +150,13 @@ class LayerCommunicator:
         layer_scatter_modes: LayerScatterModes,
         input_layernorm: torch.nn.Module,
         post_attention_layernorm: torch.nn.Module,
+        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
+        allow_reduce_scatter: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
+        self.allow_reduce_scatter = allow_reduce_scatter
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@ class LayerCommunicator:
             residual=residual,
             forward_batch=forward_batch,
             context=self._context,
+            allow_reduce_scatter=self.allow_reduce_scatter,
+        )
+
+    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+        return (
+            self.allow_reduce_scatter
+            and self._communicate_summable_tensor_pair_fn
+            is CommunicateSummableTensorPairFn._scatter_hidden_states
+            and forward_batch.dp_padding_mode.is_max_len()
         )
 
 
@@ -404,14 +417,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
+
+            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+            use_layer_norm_before_gather = context.attn_tp_size == 1
+            if use_layer_norm_before_gather:
+                residual.copy_(hidden_states)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer,
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-
-            if
-            hidden_states
+
+            if not use_layer_norm_before_gather:
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
         else:
             # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
             # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
@@ -514,6 +537,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         return hidden_states, residual
 
@@ -523,15 +547,17 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        allow_reduce_scatter: bool = False,
     ):
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # important: forward batch.gathered_buffer is used both after scatter and after gather.
-        # be careful about this!
         hidden_states, global_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-
+        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+            # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
+            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+        else:
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
         return hidden_states, residual
 
     @staticmethod
@@ -540,6 +566,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
    ):
         hidden_states += residual
         residual = None

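The `use_layer_norm_before_gather` fast path above relies on layernorm being a purely row-wise operation, so normalizing each rank's local rows before `dp_gather_partial` gives the same result as gathering first. A quick single-process check of that property (toy sizes, not sglang code):

```python
import torch

ln = torch.nn.LayerNorm(16)

# Two ranks' local hidden states (disjoint rows of the global batch).
local_a = torch.randn(3, 16)
local_b = torch.randn(5, 16)

# Normalize-then-concatenate equals concatenate-then-normalize, row by row.
gathered_then_norm = ln(torch.cat([local_a, local_b], dim=0))
norm_then_gathered = torch.cat([ln(local_a), ln(local_b)], dim=0)
assert torch.allclose(gathered_then_norm, norm_then_gathered, atol=1e-6)
```
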
sglang/srt/layers/dp_attention.py

@@ -12,6 +12,7 @@ import triton.language as tl
 
 from sglang.srt.distributed import (
     GroupCoordinator,
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
     )
 
 
+def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+        get_tp_group().reduce_scatter_tensor(output, input)
+    else:
+        scattered_local_tokens = input.tensor_split(
+            get_tensor_model_parallel_world_size()
+        )[get_tensor_model_parallel_rank()]
+        get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
+        get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
+
+
 def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
     return get_attention_tp_group().reduce_scatter_tensor(output, input)
 

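`dp_reduce_scatter_tensor` replaces the old all-reduce-then-`dp_scatter` sequence: each rank reduce-scatters the gathered buffer into `tp_size` slices and, when `attn_tp_size > 1`, all-gathers the slices of its attention-TP group so every rank ends up with its fully reduced DP shard. A single-process sketch of that arithmetic, assuming consecutive TP ranks form one attention-DP group and the padded token count divides evenly (which is what `dp_padding_mode.is_max_len()` provides):

```python
import torch

# Toy sizes: 4 TP ranks split into 2 attention-DP groups of 2 ranks each.
tp_size, attn_dp_size = 4, 2
attn_tp_size = tp_size // attn_dp_size
tokens, hidden = 8, 16  # padded global token count, divisible by tp_size

# The partial sums each TP rank would hold before any reduction.
partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]

# Old path: all-reduce across TP, then scatter one shard per attention-DP rank.
reduced = torch.stack(partials).sum(dim=0)
reference_shards = reduced.tensor_split(attn_dp_size, dim=0)

# New path: reduce-scatter into tp_size slices, then all-gather the slices
# that belong to the same attention-DP group.
rs_slices = [
    torch.stack([p.tensor_split(tp_size, dim=0)[r] for p in partials]).sum(dim=0)
    for r in range(tp_size)
]
for dp_rank in range(attn_dp_size):
    group = rs_slices[dp_rank * attn_tp_size : (dp_rank + 1) * attn_tp_size]
    assert torch.allclose(torch.cat(group, dim=0), reference_shards[dp_rank], atol=1e-5)
```
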
sglang/srt/layers/linear.py

@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1187,11 +1191,6 @@ class RowParallelLinear(LinearBase):
                 else self.weight_loader
             ),
         )
-        if not reduce_results and (bias and not skip_bias_add):
-            raise ValueError(
-                "When not reduce the results, adding bias to the "
-                "results can lead to incorrect results"
-            )
 
         if bias:
             self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
@@ -1278,7 +1277,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)
 
-    def forward(self, input_,
+    def forward(self, input_, skip_all_reduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1292,8 +1291,10 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-
-
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
+        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel

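Together with `LayerCommunicator.should_use_reduce_scatter`, the new `skip_all_reduce` argument lets a model skip the all-reduce inside its down-projection and leave the reduction to the communicator's reduce-scatter. A rough sketch of how a model block might wire this up (class and attribute names are illustrative, not the actual sglang model code):

```python
import torch


class MLPBlockSketch(torch.nn.Module):
    def __init__(self, gate_up_proj, down_proj, layer_communicator):
        super().__init__()
        self.gate_up_proj = gate_up_proj              # e.g. a MergedColumnParallelLinear
        self.down_proj = down_proj                    # a RowParallelLinear
        self.layer_communicator = layer_communicator  # built with allow_reduce_scatter=True

    def forward(self, hidden_states, forward_batch):
        x, _ = self.gate_up_proj(hidden_states)
        gate, up = x.chunk(2, dim=-1)
        x = torch.nn.functional.silu(gate) * up
        # Skip the all-reduce here when the communicator will reduce-scatter
        # the block's output afterwards (see CommunicateSummableTensorPairFn above).
        skip = self.layer_communicator.should_use_reduce_scatter(forward_batch)
        out, _ = self.down_proj(x, skip_all_reduce=skip)
        return out
```
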
sglang/srt/layers/logits_processor.py

@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None
 
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)
 
-
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)

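The `next_token_logits_buffer` path avoids allocating a fresh float32 tensor via `.float()` on every step: the caller hands over a preallocated buffer and `copy_` performs the dtype cast in place. A small illustration of the pattern (sizes are made up):

```python
import torch

batch, padded_vocab, vocab = 4, 128 * 1024, 128 * 1000
logits = torch.randn(batch, padded_vocab, dtype=torch.bfloat16)

# Old behaviour: slice and allocate a new float32 tensor every call.
fresh = logits[:, :vocab].float()

# New behaviour: reuse a preallocated float32 buffer; copy_ casts the dtype.
buffer = torch.empty(batch, vocab, dtype=torch.float)
buffer.copy_(logits[:, :vocab])
assert torch.equal(fresh, buffer)
```
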
sglang/srt/layers/moe/cutlass_moe.py

@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(
 
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_token_group_quant_fp8_hopper_moe_mn_major,
            sglang_per_token_group_quant_fp8,
         )
 
@@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8(
     n = w2_q.size(1)
 
     topk = topk_ids.size(1)
-
-    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-    device = a_q.device
+    device = a.device
 
     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,8 +152,16 @@ def cutlass_fused_experts_fp8(
         k,
     )
 
-
-
+    if is_sm100_supported():
+        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+    else:
+        rep_a = shuffle_rows(a, a_map, (m * topk, k))
+        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
+            rep_a, expert_offsets, problem_sizes1, 128
+        )
+    w1_scale = w1_scale.contiguous()
 
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -185,7 +193,13 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)
 
-
+    if is_sm100_supported():
+        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    else:
+        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
+            intermediate, expert_offsets, problem_sizes2, 128
+        )
+    w2_scale = w2_scale.contiguous()
 
     fp8_blockwise_scaled_grouped_mm(
         c2,

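The `cutlass_fused_experts_fp8` changes branch on `is_sm100_supported()`: on SM100-class GPUs the activations are quantized before being shuffled, while on earlier architectures such as Hopper the rows are shuffled first and quantized with the new MN-major group-quant kernel. A hedged sketch of what such a capability gate typically looks like (an assumption about the helper, not its actual implementation):

```python
import torch


def is_sm100_like(device_index: int = 0) -> bool:
    # Assumed check: CUDA compute capability major version 10 corresponds to
    # SM100-class GPUs; SM90 (Hopper) and older return False.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_index)
    return major >= 10
```
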