sglang-0.5.4-py3-none-any.whl → sglang-0.5.4.post2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
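For reference, a minimal sketch (assuming `sglang` is already installed in the local Python environment) of how to check which of the two compared versions is present before reading the diff below; it only reads package metadata via the standard library and does not import or start sglang itself:

```python
# Hypothetical helper, not part of the sglang package or of this diff.
from importlib.metadata import PackageNotFoundError, version

try:
    installed = version("sglang")  # e.g. "0.5.4" or "0.5.4.post2"
except PackageNotFoundError:
    installed = None

print(f"installed sglang version: {installed}")
if installed not in ("0.5.4", "0.5.4.post2"):
    print("note: the local install differs from the two versions compared in this diff")
```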
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe.py
CHANGED
@@ -15,7 +15,7 @@
 """Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""

 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -27,10 +27,16 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -48,7 +54,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -56,23 +65,17 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.deepseek_v2 import (
-    DeepseekV2DecoderLayer,
-    DeepseekV2ForCausalLM,
-    DeepseekV2Model,
-    DeepseekV2MoE,
-)
 from sglang.srt.server_args import get_global_server_args
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
-    BumpAllocator,
-    LazyValue,
     add_prefix,
     cpu_has_amx_support,
     get_bool_env_var,
@@ -80,8 +83,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-
-    use_intel_amx_backend,
+    make_layers,
 )

 _is_hip = is_hip()
@@ -92,11 +94,6 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()

-if _is_cuda:
-    from sgl_kernel import dsv3_router_gemm
-elif _is_cpu and _is_cpu_amx_available:
-    pass
-
 logger = logging.getLogger(__name__)


@@ -136,8 +133,7 @@ class Glm4MoeMLP(nn.Module):
         )
         if hidden_act != "silu":
             raise ValueError(
-                f"Unsupported activation: {hidden_act}. "
-                "Only silu is supported for now."
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()

@@ -146,7 +142,6 @@ class Glm4MoeMLP(nn.Module):
         x,
         forward_batch=None,
         should_allreduce_fusion=False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
@@ -326,47 +321,21 @@ class Glm4MoeGate(nn.Module):
         self,
         config,
         prefix: str = "",
-        is_nextn: bool = False,
     ):
         super().__init__()
-        self.is_nextn = is_nextn
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         self.e_score_correction_bias = nn.Parameter(
             torch.empty((config.n_routed_experts), dtype=torch.float32)
         )
-        if _is_cpu and _is_cpu_amx_available:
-            self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
-
-            return torch.ops.sgl_kernel.weight_packed_linear(
-                hidden_states,
-                self.weight,
-                None, # bias
-                True, # is_vnni
-            )
-
-        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
-        if (
-            _is_cuda
-            and not self.is_nextn
-            and hidden_states.shape[0] < 4
-            and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
-            and _device_sm >= 90
-        ):
-            logits = dsv3_router_gemm(hidden_states, self.weight).to(
-                hidden_states.dtype
-            )
-        else:
-            logits = F.linear(hidden_states, self.weight, None)
-
+        logits = F.linear(hidden_states, self.weight, None)
         return logits


-class Glm4MoeSparseMoeBlock(
+class Glm4MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -374,18 +343,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
-        is_nextn: bool = False,
     ):
         nn.Module.__init__(self)
+        self.top_k = config.num_experts_per_tok
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.ep_size = get_moe_expert_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.num_fused_shared_experts = (
-            0
-            if get_global_server_args().disable_shared_experts_fusion
-            else config.n_shared_experts
-        )
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
@@ -402,39 +365,31 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 "Only silu is supported for now."
             )

-        self.gate = Glm4MoeGate(
-            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
-        )
+        self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix))

         self.topk = TopK(
-            top_k=
+            top_k=self.top_k,
             renormalize=config.norm_topk_prob,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
         )

         self.experts = get_moe_impl_class(quant_config)(
-            num_experts=config.n_routed_experts
-
-
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            num_experts=config.n_routed_experts,
+            top_k=self.top_k,
+            layer_id=self.layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            layer_id=self.layer_id,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
         )

-
-
-        # self.shared_experts_weight_block_size = None
-        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
+        # shared expert
+        if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = Glm4MoeMLP(
                 hidden_size=config.hidden_size,
@@ -443,21 +398,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_experts", prefix),
-                **(
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    or get_moe_a2a_backend().is_mooncake()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                    else {}
+                ),
             )
-            is_packed_weight = hasattr(
-                self.shared_experts.gate_up_proj.quant_method, "quant_config"
-            )
-            self.shared_experts_is_int8 = (
-                not is_packed_weight
-                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
-            )
-            self.shared_experts_is_fp8 = (
-                not is_packed_weight
-                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
-            )
-
-        self.top_k = config.num_experts_per_tok

         if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
             # TODO: we will support tp < ep in the future
@@ -479,12 +427,46 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )

+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        if not self._enable_a2a_moe:
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and hidden_states.shape[0] > 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
+                )
+            else:
+                return self.forward_normal(
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
+                )
+        else:
+            return self.forward_deepep(hidden_states, forward_batch)
+
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
@@ -498,28 +480,21 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
+
         current_stream.wait_stream(self.alt_stream)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)

-
-
-
-
-
-
-
-
-
-                final_hidden_states
-        else:
-            final_hidden_states += shared_output
-        if (
-            self.tp_size > 1
-            and not should_allreduce_fusion
-            and not use_reduce_scatter
-        ):
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states
-            )
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
@@ -527,39 +502,69 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        if
-        self.
-
-
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+        else:
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)

-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
-        if
-
-
-
-
-
+        if shared_output is not None:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_output = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
         else:
-
-
-
-
-
-
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_output=topk_output,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+
         return final_hidden_states

+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            shared_output = self.shared_experts(hidden_states)
+        return shared_output

-
+
+class Glm4MoeDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -582,6 +587,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
         self.layer_id = layer_id
+
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -597,15 +603,15 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
             use_qk_norm=config.use_qk_norm,
+            alt_stream=alt_stream,
         )

         self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
         is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

-        num_layers = 1 if is_nextn else config.num_hidden_layers
         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -616,6 +622,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -641,7 +648,16 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
-            allow_reduce_scatter=
+            allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
+        )
+
+    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+        return is_nextn or (
+            self.config.n_routed_experts is not None
+            and layer_id >= self.config.first_k_dense_replace
         )

     def forward(
@@ -650,8 +666,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
@@ -676,44 +690,119 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         return hidden_states, residual


-class Glm4MoeModel(
+class Glm4MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    )
-
-        self.
+    ):
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
         self.vocab_size = config.vocab_size
-        self.
+        self.embed_dim = config.hidden_size
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not is_dp_attention_enabled(),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()

-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not is_dp_attention_enabled(),
-        )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
-        self.layers =
-
-
-
-
-
-
-
-
-
-
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Glm4MoeDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=self.alt_stream,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
         )
-        self.pp_group
-
-
-
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens

-
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]

+        normal_start_layer = self.start_layer
+        normal_end_layer = self.end_layer
+        if forward_batch.can_run_tbo:
+            if (
+                self.first_k_dense_replace > normal_start_layer
+                and self.first_k_dense_replace < normal_end_layer
+            ):
+                normal_end_layer = self.first_k_dense_replace
+            elif self.first_k_dense_replace < normal_start_layer:
+                normal_end_layer = normal_start_layer = 0
+
+        for i in range(normal_start_layer, normal_end_layer):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                )
+
+        if normal_end_layer != self.end_layer:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers[normal_end_layer : self.end_layer],
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_end_layer - 1
+                ].layer_scatter_modes.layer_output_mode,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if not forward_batch.forward_mode.is_idle():
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+            return hidden_states
+
+
+class Glm4MoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -721,12 +810,10 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
-        config.moe_layer_freq = 1
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.pp_group = get_pp_group()
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.model = Glm4MoeModel(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -739,49 +826,41 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         )
         self.logits_processor = LogitsProcessor(config)

-
-
-                layer_id: layer.mlp.get_moe_weights()
-                for layer_id, layer in enumerate(self.model.layers)
-                if isinstance(layer.mlp, DeepseekV2MoE)
-            }
-        )
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False

-    def
-        self
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.
-
-            logger,
-            f"{disable_reason} Shared experts fusion optimization is disabled.",
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
+        )
+
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
-
+        else:
+            return hidden_states

-
+    @property
+    def start_layer(self):
+        return self.model.start_layer

-
-
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
-
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -803,117 +882,14 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                    or self.quant_config.get_name() == "compressed_tensors"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in moe_layers:
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    # online fp8 quantization does not load weight_scale
-                    if shared_expert_weight_name not in weights_dict:
-                        continue
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]

-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
+            num_experts=self.config.n_routed_experts,
        )

-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
             nextn_spec_weight_names = [
@@ -969,22 +945,36 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                 # name will be updated to mlp.experts[0].gate_up_proj, which
                 # will then be updated below in expert_params_mapping
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if
+                if "mlp.experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -996,65 +986,43 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    if
-
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
+                    if name not in params_dict:
+                        continue

-
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                        else:
-                            if (
-                                "k_scale" in name or "v_scale" in name
-                            ) and name not in params_dict:
-                                # modelopt attn kv scale is named differently
-                                if any(scale in name for scale in ["k_scale", "v_scale"]):
-                                    name = name.replace("_proj", "attn_mqa")
-                                else:
-                                    logger.warning(
-                                        f"Unknown scale found in checkpoint: {name}"
-                                    )
+                    if name in params_dict.keys():
                         param = params_dict[name]
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
                         )
                         weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )


 EntryClass = [Glm4MoeForCausalLM]