sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/ernie4.py
CHANGED
@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):

 class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        expert_params_mapping =
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
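Note: several models in this release now build their expert-parameter mapping through FusedMoE.make_expert_params_mapping. The following is a hypothetical sketch (not sglang's implementation) of what such a mapping can look like and how a load_weights loop might consume it; only the (param_name, weight_name, expert_id, shard_id) tuple layout is taken from the comments in this diff.

# Hypothetical sketch only: models the (param_name, weight_name, expert_id, shard_id)
# tuple layout used for expert weight loading; all names below are illustrative.
from typing import List, Tuple

def make_expert_params_mapping_sketch(
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
) -> List[Tuple[str, str, int, str]]:
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name in (
            ("w1", ckpt_gate_proj_name),   # gate projection feeds the fused w13 weight
            ("w3", ckpt_up_proj_name),     # up projection feeds the fused w13 weight
            ("w2", ckpt_down_proj_name),   # down projection is stored separately
        ):
            param_name = "experts.w13_weight" if shard_id in ("w1", "w3") else "experts.w2_weight"
            mapping.append((param_name, f"experts.{expert_id}.{ckpt_name}.", expert_id, shard_id))
    return mapping

# Example: 2 experts -> 6 mapping entries that a load_weights loop can match against
# checkpoint tensor names before dispatching to each parameter's weight_loader.
for entry in make_expert_params_mapping_sketch("gate_proj", "down_proj", "up_proj", num_experts=2):
    print(entry)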
sglang/srt/models/glm4_moe.py
CHANGED
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
@@ -39,7 +40,7 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -50,9 +51,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
@@ -75,10 +77,7 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2Model,
     DeepseekV2MoE,
 )
-from sglang.srt.two_batch_overlap import
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -413,19 +412,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-
-
-
-
-
-
-
-
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         self.experts = get_moe_impl_class()(
@@ -440,31 +435,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
         )

         self.shared_experts_is_int8 = False
@@ -495,7 +465,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):

         self.top_k = config.num_experts_per_tok

-        if
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -519,12 +489,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )

-        self._enable_deepep_moe =
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

     def forward_normal_dual_stream(
         self,
@@ -540,12 +510,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-
-
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            else:
-                kwargs["router_logits"] = router_logits
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -586,12 +552,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-
-
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["router_logits"] = router_logits
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -634,7 +596,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,
@@ -744,7 +705,7 @@ class Glm4MoeModel(DeepseekV2Model):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -759,10 +720,11 @@ class Glm4MoeModel(DeepseekV2Model):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
+        self.pp_group = get_pp_group()
+        self.start_layer = 0
+        self.end_layer = config.num_hidden_layers
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size = get_local_attention_dp_size()
-

 class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):

@@ -777,6 +739,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.model = Glm4MoeModel(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -789,7 +752,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()

         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -953,7 +915,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping =
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
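Note: the MoE blocks above now construct TopK unconditionally and call the experts module as experts(hidden_states, topk_output) instead of passing keyword arguments. A toy, self-contained sketch of that call pattern follows; it does not use the sglang classes, and the naive per-expert loop is purely illustrative.

# Toy sketch of the gate -> TopK -> experts(hidden_states, topk_output) call pattern.
import torch

class ToySparseMoeBlock(torch.nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = torch.nn.Linear(hidden_size, num_experts, bias=False)
        self.top_k = top_k
        self.experts = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(num_experts)]
        )

    def topk(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        # Stand-in for sglang's TopK module: returns (normalized weights, expert ids).
        weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), self.top_k, dim=-1)
        return weights / weights.sum(dim=-1, keepdim=True), ids

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        topk_weights, topk_ids = self.topk(hidden_states, router_logits)
        out = torch.zeros_like(hidden_states)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = topk_ids[:, k] == e
                if mask.any():
                    out[mask] += topk_weights[mask, k : k + 1] * expert(hidden_states[mask])
        return out

x = torch.randn(4, 8)
print(ToySparseMoeBlock(hidden_size=8, num_experts=4, top_k=2)(x).shape)  # torch.Size([4, 8])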
sglang/srt/models/glm4_moe_nextn.py
CHANGED
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )

sglang/srt/models/glm4v.py
CHANGED
@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -91,6 +92,7 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             norm_layer=norm_layer,
             quant_config=quant_config,
             prefix=prefix,
+            num_dummy_heads=config.num_dummy_heads,
         )
         self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -469,7 +471,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         nn.Module.__init__(self)

         self.config = config
-
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.model = Glm4Model(
             config,
             quant_config,
@@ -537,6 +539,51 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         video_embeds = torch.split(video_embeds, split_sizes)
         return torch.cat(video_embeds)

+    def _update_hf_config(self):
+        """update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
+        tp_size = get_attention_tp_size()
+        num_heads = self.config.vision_config.num_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % tp_size != 0:
+            num_dummy_heads = (
+                (num_heads + tp_size - 1) // tp_size
+            ) * tp_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        elif "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -583,6 +630,10 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                     raise

                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                if "visual" in name:
+                    loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                        self.config, name, loaded_weight
+                    )
                 weight_loader(param, loaded_weight)

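Note: the dummy-head logic above rounds the ViT head count up to a multiple of the attention TP size and zero-pads the attention weights to match. A simplified sketch of that arithmetic with assumed toy shapes follows (the real code pads q, k and v separately via unflatten and also handles biases, proj and q/k norm weights).

# Simplified sketch of the dummy-head padding arithmetic; shapes are assumptions.
import torch

def compute_num_dummy_heads(num_heads: int, tp_size: int) -> int:
    # Round num_heads up to the next multiple of tp_size, as in _update_hf_config above.
    return ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads

def pad_heads(weight: torch.Tensor, head_dim: int, num_dummy_heads: int) -> torch.Tensor:
    # weight: (num_heads * head_dim, in_features); append zero rows for the dummy heads.
    if num_dummy_heads == 0:
        return weight
    pad = weight.new_zeros(num_dummy_heads * head_dim, weight.shape[-1])
    return torch.cat([weight, pad], dim=0)

num_heads, head_dim, tp_size = 12, 64, 8
extra = compute_num_dummy_heads(num_heads, tp_size)   # 4 dummy heads -> 16 heads total
wq = torch.randn(num_heads * head_dim, 1024)
print(pad_heads(wq, head_dim, extra).shape)           # torch.Size([1024, 1024])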
sglang/srt/models/glm4v_moe.py
CHANGED
@@ -8,19 +8,12 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig

 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
-    tensor_model_parallel_all_reduce,
 )
 from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.layers.
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -48,8 +41,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):

         config.moe_layer_freq = 1
         self.config = config
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.dp_size = get_local_attention_dp_size()
         self.quant_config = quant_config
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
@@ -232,7 +225,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping =
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -394,6 +387,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
                         )
+                        if "visual" in name:
+                            loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                                self.config, name, loaded_weight
+                            )
                         weight_loader(param, loaded_weight)

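Note: both GLM-4V variants apply the same load-time hook, padding vision-tower weights before handing them to the regular weight loader. A minimal sketch of that pattern with a stand-in pad function follows; it is not sglang's vision_utils, and the 2 dummy heads and head_dim of 64 are assumed values.

# Minimal sketch of the "pad vision weights before weight_loader" hook pattern.
import torch

def pad_dummy_heads_stub(name: str, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for vision_utils.pad_vit_attn_dummy_heads: only attention weights are padded.
    if "attn" not in name:
        return w
    return torch.cat([w, w.new_zeros(128, *w.shape[1:])], dim=0)  # 2 dummy heads * head_dim 64

def load_weights_sketch(weights):
    loaded = {}
    for name, loaded_weight in weights:
        if "visual" in name:                  # same predicate as in the diffs above
            loaded_weight = pad_dummy_heads_stub(name, loaded_weight)
        loaded[name] = loaded_weight          # stand-in for weight_loader(param, loaded_weight)
    return loaded

weights = [("visual.blocks.0.attn.qkv_proj.weight", torch.randn(3 * 8 * 64, 512))]
print(load_weights_sketch(weights)["visual.blocks.0.attn.qkv_proj.weight"].shape)  # (1664, 512)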
sglang/srt/models/gpt_oss.py
CHANGED
@@ -16,6 +16,7 @@
 """Inference-only GptOss model compatible with HuggingFace weights."""

 import logging
+import math
 from collections.abc import Iterable
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -40,7 +41,7 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -49,9 +50,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -109,16 +111,13 @@ class GptOssSparseMoeBlock(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
         self.activation = config.hidden_act
-        self.
-        self.
+        self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
+        self.gemm1_clamp_limit = config.swiglu_limit

-
-
-
-
-            top_k=config.num_experts_per_tok,
-            renormalize=True,
-        )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=True,
+        )

         self.top_k = config.num_experts_per_tok
         experts_type = get_moe_impl_class()
@@ -128,11 +127,9 @@ class GptOssSparseMoeBlock(nn.Module):
             quant_config.get_name() if quant_config is not None else None
         )
         extra_kwargs = {
-            "enable_flashinfer_cutlass_moe": global_server_args_dict[
-                "enable_flashinfer_cutlass_moe"
-            ],
             # for moe gate_up_proj and down_proj and their bias loading
-            "use_weight_loader_fused": quant_config_name
+            "use_weight_loader_fused": quant_config_name
+            != "mxfp4"
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts
@@ -143,15 +140,10 @@ class GptOssSparseMoeBlock(nn.Module):
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             activation=self.activation,
-
-
+            gemm1_alpha=self.gemm1_alpha,
+            gemm1_clamp_limit=self.gemm1_clamp_limit,
             with_bias=True,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
             **extra_kwargs,
         )

@@ -170,7 +162,7 @@ class GptOssSparseMoeBlock(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
-        if not
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")
@@ -188,17 +180,10 @@ class GptOssSparseMoeBlock(nn.Module):
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)

-        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
-
-
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["topk_output"] = (self.top_k, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)

         if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -293,8 +278,12 @@ class GptOssAttention(nn.Module):
             prefix=add_prefix("qkv_proj", prefix),
         )

+        # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
+        # others can use bfloat16
+        attn_backend = global_server_args_dict.get("attention_backend")
+        sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=
+            torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
         )

         self.o_proj = RowParallelLinear(
@@ -431,7 +420,6 @@ class GptOssDecoderLayer(nn.Module):

         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -466,44 +454,11 @@ class GptOssDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            is_last_layer=(
+                self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )

-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -527,8 +482,9 @@ class GptOssDecoderLayer(nn.Module):
         )

         should_allreduce_fusion = (
-            self.
-
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
         )

         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
@@ -561,7 +517,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
@@ -833,18 +789,27 @@ class GptOssForCausalLM(nn.Module):
         moe_ep_size = get_moe_expert_parallel_world_size()

         intermediate_size = self.config.intermediate_size
+        assert (
+            intermediate_size % mxfp4_block == 0
+        ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-
+
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
+
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

         # Calculate common slicing bounds for current rank
         assert self.config.num_local_experts % moe_ep_size == 0
         moe_num_global_experts = self.config.num_local_experts
         moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+
         moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
         moe_tp_rank_end = min(
             (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
         )
+
         moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
         moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts

@@ -1055,7 +1020,7 @@ class GptOssForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        expert_params_mapping =
+        expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
            ckpt_gate_up_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
@@ -1136,7 +1101,7 @@ class GptOssForCausalLM(nn.Module):
             if name in params_dict.keys():
                 param = params_dict[name]
                 if "sinks" in name:
-                    start =
+                    start = get_attention_tp_rank() * param.numel()
                     param.data.copy_(
                         loaded_weight[start : start + param.numel()]
                     )