sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
CHANGED
@@ -16,7 +16,6 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 import functools
-import json
 import logging
 import math
 import os
@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.
+from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.elementwise import (
+    experts_combine_triton,
+    fused_dual_residual_rmsnorm,
+    fused_rmsnorm,
+    gelu_and_mul_triton,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import
+from sglang.srt.layers.rotary_embedding import (
+    RotaryEmbedding,
+    _yarn_find_correction_range,
+    _yarn_get_mscale,
+    get_rope,
+)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
 
 logger = logging.getLogger(__name__)
 
 
+# Dump tensors for debugging
 debug_tensor_dump_output_folder = None
+debug_tensor_dump_prefill_only = False
+# Skip all the other tensor dumps, only dump the target logits
+debug_tensor_dump_only_target_logprobs = False
 debug_tensor_dump_inject = False
+debug_tensor_dump_layers = None
+debug_tensor_dump_test = False
+
+
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+        use_presharded_weights: bool = False,
+        split_gate_up: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+        self.layer_id = layer_id
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x, _ = gelu_and_mul_triton(gate_up)
+        x, _ = self.down_proj(x)
+        return x
 
 
 class Grok1MoE(nn.Module):
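
The hunk above introduces a dense Grok1MLP whose forward pass calls gelu_and_mul_triton on the merged gate/up projection. As a reading aid only, here is a minimal eager-mode sketch of what that fused activation is assumed to compute (tanh-approximate GELU on the gate half, multiplied elementwise by the up half); the name gelu_and_mul_reference and the exact half ordering are illustrative assumptions, not sglang APIs.

import torch
import torch.nn.functional as F


def gelu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up has shape (..., 2 * intermediate_size): assume the first half is the
    # gate and the second half is the up projection.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.gelu(gate, approximate="tanh") * up


x = torch.randn(4, 16)         # (tokens, 2 * intermediate_size) with intermediate_size = 8
y = gelu_and_mul_reference(x)  # (tokens, intermediate_size)
print(y.shape)                 # torch.Size([4, 8])
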
@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
-        reduce_results=True,
+        reduce_results: bool = True,
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -135,7 +194,6 @@ class Grok1MoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             activation="gelu",
             **kwargs,
         )
@@ -146,6 +204,135 @@ class Grok1MoE(nn.Module):
         return self.experts(hidden_states, topk_output)
 
 
+def _yarn_linear_ramp_mask(
+    low: float, high: float, dim: int, dtype: torch.dtype
+) -> torch.Tensor:
+    if low == high:
+        low -= 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def get_rope_scaling(config):
+    rope_type = getattr(config, "rope_type", None)
+    if rope_type:
+        original_max_position_embeddings = getattr(
+            config, "original_max_position_embeddings", None
+        )
+        scaling_factor = getattr(config, "scaling_factor", None)
+        extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
+        attn_factor = getattr(config, "attn_factor", 1.0)
+        beta_fast = getattr(config, "beta_fast", 32)
+        beta_slow = getattr(config, "beta_slow", 1)
+        rope_scaling = {
+            "extra_method": rope_type,
+            "max_position_embeddings": original_max_position_embeddings,
+            "scaling_factor": scaling_factor,
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow,
+            "dtype": torch.float,
+        }
+        return rope_scaling
+    else:
+        return None
+
+
+class ScalingRotaryEmbedding(RotaryEmbedding):
+    """Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extra_method: str = "yarn_log",
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extra_method = extra_method
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base ** (
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+        )
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            self.rotary_dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1
+            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+        ) * self.extrapolation_factor
+        if self.extra_method in ["original"]:
+            inv_freq = inv_freq_extrapolation
+        elif self.extra_method in ["yarn", "yarn_linear"]:
+            inv_freq = (
+                inv_freq_interpolation * (1 - inv_freq_mask)
+                + inv_freq_extrapolation * inv_freq_mask
+            )
+        elif self.extra_method == "yarn_log":
+            inv_freq = torch.exp(
+                torch.log(inv_freq_extrapolation) * inv_freq_mask
+                + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
+            )
+        elif self.extra_method == "theta_scale":
+            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+            theta_scale_exponent = self.base ** (
+                math.log(
+                    self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
+                )
+                / math.log(self.max_position_embeddings / (2 * math.pi))
+            )
+            inv_freq = torch.tensor(
+                1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
+                dtype=torch.float32,
+            )
+        else:
+            raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(
+            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
+        )
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        # cos = freqs.cos() * self.mscale
+        # sin = freqs.sin() * self.mscale
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class Grok1Attention(nn.Module):
     def __init__(
         self,
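
The hunk above adds ScalingRotaryEmbedding, whose _compute_inv_freq blends the original ("extrapolated") RoPE inverse frequencies with position-interpolated ones using a linear ramp mask, in log space for the default yarn_log mode. Below is a self-contained sketch of that blend under the simplifying assumption extrapolation_factor == 1; the helper name yarn_log_inv_freq and the sample low/high correction bounds are made up for illustration and are not sglang APIs.

import torch


def yarn_log_inv_freq(
    rotary_dim: int, base: float, scaling_factor: float, low: float, high: float
) -> torch.Tensor:
    # Inverse frequencies of plain RoPE ("extrapolation") and of position interpolation.
    pos_freqs = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    # Linear ramp over the rotary_dim // 2 frequency indices: 0 below `low`, 1 above `high`.
    ramp = torch.clamp(
        (torch.arange(rotary_dim // 2, dtype=torch.float) - low) / (high - low), 0, 1
    )
    # Mask is 1 on the high-frequency dims (keep the original frequencies) and
    # 0 on the low-frequency dims (use the interpolated frequencies).
    inv_freq_mask = 1 - ramp

    # Geometric (log-space) blend, matching the "yarn_log" branch in the diff above.
    return torch.exp(
        torch.log(inv_freq_extrapolation) * inv_freq_mask
        + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
    )


print(yarn_log_inv_freq(rotary_dim=128, base=10000.0, scaling_factor=8.0, low=4.0, high=28.0)[:4])
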
@@ -158,7 +345,9 @@ class Grok1Attention(nn.Module):
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         load_presharded_attn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -184,7 +373,9 @@ class Grok1Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        rope_scaling = get_rope_scaling(config)
         self.load_presharded_attn = load_presharded_attn
+        self.alt_stream = alt_stream or torch.cuda.Stream()
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -196,6 +387,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             load_presharded_attn=self.load_presharded_attn,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -206,6 +398,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             use_presharded_weights=self.load_presharded_attn,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -215,7 +408,37 @@ class Grok1Attention(nn.Module):
             is_neox_style=True,
         )
 
+        self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
+
+        if rope_scaling is not None:
+            self.rotary_emb = ScalingRotaryEmbedding(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                base=int(self.rope_theta),
+                is_neox_style=True,
+                **rope_scaling,
+            )
+            pos_encoding_mode = "NONE"
+        else:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                max_position=max_position,
+                base=int(self.rope_theta),
+                is_neox_style=True,
+            )
+            pos_encoding_mode = "NONE"
+
         logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+        logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")
 
         self.attn = RadixAttention(
             self.num_heads,
@@ -225,7 +448,11 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
             quant_config=quant_config,
+            pos_encoding_mode=pos_encoding_mode,
+            logit_capping_method=logit_capping_method,
+            prefix=add_prefix("attn", prefix),
         )
+        self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)
 
     def forward(
         self,
@@ -257,6 +484,8 @@ class Grok1Attention(nn.Module):
             )
 
         qkv, _ = self.qkv_proj(hidden_states)
+        dispose_tensor(hidden_states)
+
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
 
@@ -289,6 +518,7 @@ class Grok1Attention(nn.Module):
             )
 
         attn_output = self.attn(q, k, v, forward_batch)
+        del q, k, v, qkv
 
         if debug_tensor_dump_output_folder:
             dump_to_file(
@@ -313,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
         load_presharded_moe: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_moe: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
+        self.residual_moe = getattr(config, "residual_moe", False)
        self.layer_id = layer_id
+        self.alt_stream = alt_stream or torch.cuda.Stream()
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            max_position=
+            max_position=(
+                config.context_len
+                if hasattr(config, "context_len")
+                else config.max_position_embeddings
+            ),
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
             reduce_results=False,
+            alt_stream=self.alt_stream,
             load_presharded_attn=load_presharded_attn,
+            prefix=add_prefix("attn", prefix),
         )
-
-
-
-
-
-
-
-            config,
-
-            getattr(
-
-
-
-
-
-
+
+        split_gate_up = not getattr(config, "merge_gate_up", True)
+        if self.num_experts > 0:
+            self.block_sparse_moe = Grok1MoE(
+                config=config,
+                layer_id=layer_id,
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=getattr(
+                    config,
+                    "moe_intermediate_size",
+                    getattr(config, "intermediate_size", None),
+                ),
+                quant_config=quant_config,
+                reduce_results=not self.residual_moe,
+                use_presharded_weights=load_presharded_moe,
+                inplace=False,  # not self.residual_moe,
+                no_combine=False,  # self.residual_moe,  # just a suggestion to not combine topk
+                prefix=add_prefix("block_sparse_moe", prefix),
+            )
+            if self.residual_moe:
+                self.mlp = Grok1MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.intermediate_size,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    use_presharded_weights=load_presharded_mlp,
+                    layer_id=layer_id,
+                    split_gate_up=split_gate_up,
+                )
+        else:
+            raise NotImplementedError()
 
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.
+        if self.num_experts > 0:
+            if self.residual_moe:
+                # NOTE: self.block_sparse_moe modifies the input in-place,
+                # so we have to call it later. Be aware of any possible related errors.
+                if get_tensor_model_parallel_world_size() > 1:
+                    self.ffn = lambda x: tensor_model_parallel_all_reduce(
+                        self.moe_with_rmoe(x)
+                    )
+                else:
+                    self.ffn = self.moe_with_rmoe
+            else:
+                self.ffn = self.block_sparse_moe
+        else:
+            raise NotImplementedError()
 
     def forward(
         self,
@@ -365,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor] = None,
         deferred_norm: Optional[RMSNorm] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
+
+        hidden_states_original = hidden_states
+        residual_original = residual
+
         # Self Attention
         if deferred_norm is not None:
             assert residual is not None
@@ -387,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
                 hidden_states,
             )
 
+        if residual_original is not None:
+            dispose_tensor(residual_original)
+
+        dispose_flag = False
+        if residual is not hidden_states_original:
+            dispose_flag = True
+            dispose_tensor(hidden_states_original)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -404,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
             self.post_attn_norm.variance_epsilon,
         )
 
+        if not dispose_flag:
+            dispose_tensor(hidden_states_original)
+
         # Fully Connected
         hidden_states = self.ffn(hidden_states)
         return hidden_states, residual, self.post_moe_norm  # defer layernorm
 
+    def moe_with_rmoe(self, x):
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        mlp_result = self.mlp(x)
+        with torch.cuda.stream(self.alt_stream):
+            # moe should not be inplace because of stream race condition
+            moe_result = self.block_sparse_moe(x)
+        current_stream.wait_stream(self.alt_stream)
+        return (mlp_result + moe_result) / 1.4142135623730951
+
 
 class Grok1Model(nn.Module):
     def __init__(
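
The hunk above gives Grok1DecoderLayer a moe_with_rmoe path that runs the dense MLP on the current CUDA stream and the sparse MoE on an alternate stream, then averages the two results with a 1/sqrt(2) factor. The toy module below sketches only that stream-overlap-and-combine pattern; DualStreamCombine and its nn.Linear stand-ins are hypothetical and are not the sglang classes.

import math
import torch
import torch.nn as nn


class DualStreamCombine(nn.Module):
    def __init__(self, dense: nn.Module, sparse: nn.Module):
        super().__init__()
        self.dense = dense
        self.sparse = sparse
        self.alt_stream = torch.cuda.Stream()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        current = torch.cuda.current_stream()
        # Make sure the alternate stream sees all prior writes to x before using it.
        self.alt_stream.wait_stream(current)
        dense_out = self.dense(x)              # runs on the current stream
        with torch.cuda.stream(self.alt_stream):
            sparse_out = self.sparse(x)        # runs concurrently on the alternate stream
        current.wait_stream(self.alt_stream)   # join before combining the results
        # Same 1/sqrt(2) scaling as moe_with_rmoe in the diff above.
        return (dense_out + sparse_out) / math.sqrt(2)


if torch.cuda.is_available():
    d = 64
    block = DualStreamCombine(nn.Linear(d, d), nn.Linear(d, d)).cuda()
    out = block(torch.randn(8, d, device="cuda"))
    print(out.shape)  # torch.Size([8, 64])
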
@@ -418,6 +713,8 @@ class Grok1Model(nn.Module):
         load_presharded_embedding: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        replicate_embedding: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -428,7 +725,11 @@ class Grok1Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             use_presharded_weights=load_presharded_embedding,
+            enable_tp=not replicate_embedding,
+            prefix=add_prefix("embed_tokens", prefix),
         )
+
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 Grok1DecoderLayer(
@@ -438,6 +739,7 @@ class Grok1Model(nn.Module):
                     load_presharded_moe=load_presharded_moe,
                     load_presharded_attn=load_presharded_attn,
                     load_presharded_mlp=load_presharded_mlp,
+                    alt_stream=self.alt_stream,
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -507,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -515,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
         # Get presharded weights.
         self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
         self.load_presharded_moe = (
-
+            getattr(config, "load_presharded_moe", True)
+            and self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
         )
         self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
@@ -530,14 +834,16 @@ class Grok1ForCausalLM(nn.Module):
             or self.load_presharded_embedding
         )
 
-        if self.is_weights_presharded:
-            setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-
         default_replicate_lm_head = False
         self.replicate_lm_head = getattr(
             config, "replicate_lm_head", default_replicate_lm_head
         )
 
+        if self.is_weights_presharded:
+            setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+
+        self.replicate_embedding = getattr(config, "replicate_embedding", False)
+
         self.model = Grok1Model(
             config,
             quant_config=quant_config,
@@ -545,6 +851,8 @@ class Grok1ForCausalLM(nn.Module):
             load_presharded_embedding=self.load_presharded_embedding,
             load_presharded_attn=self.load_presharded_attn,
             load_presharded_mlp=self.load_presharded_mlp,
+            replicate_embedding=self.replicate_embedding,
+            prefix=add_prefix("model", prefix),
         )
 
         lm_head_params_dtype = None
@@ -554,6 +862,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.vocab_size,
                 bias=False,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
@@ -562,6 +871,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.hidden_size,
                 use_presharded_weights=self.load_presharded_embedding,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)
 
@@ -578,6 +888,7 @@ class Grok1ForCausalLM(nn.Module):
             f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
             f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
         )
+        self.loaded_param_names = set()
 
     def forward(
         self,
@@ -597,11 +908,13 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        num_experts: Optional[int] = None,
         ignore_parent_name: bool = False,
+        check_hit_names: bool = True,
+        model_config: PretrainedConfig | None = None,
     ) -> dict[str, torch.Tensor]:
-        if
-
+        if model_config is None:
+            model_config = self.config
+
         stacked_params_mapping = []
         stacked_params_mapping += [
             # (param_name, shard_name, shard_id)
@@ -617,6 +930,7 @@ class Grok1ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
+        num_experts = model_config.num_local_experts
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
@@ -631,23 +945,26 @@ class Grok1ForCausalLM(nn.Module):
         def load_weight_wrapper(
             name: str, loaded_weight: torch.Tensor, *args, **kwargs
         ):
-            if ignore_parent_name:
-                name = name.split(".")[-1]
-
-            if name not in params_dict:
-                return
-
             # Fuse constant multipliers into the weights
             if "lm_head" in name:
                 loaded_weight = (
                     loaded_weight.to(torch.float32)
-                    *
+                    * model_config.output_multiplier_scale
                 )
 
+            original_name = name
+            if ignore_parent_name:
+                name = name.split(".")[-1]
+
+            if name not in params_dict:
+                logger.info(f"Skipping {name=} in load_weights_wrapper")
+                return
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight, *args, **kwargs)
             hit_names.add(name)
+            self.loaded_param_names.add(original_name)
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -686,19 +1003,22 @@ class Grok1ForCausalLM(nn.Module):
 
                 load_weight_wrapper(name=name, loaded_weight=loaded_weight)
 
-        if
-
-
-
-
-
-            if len(missing_exclude_scales) > 0:
-                raise ValueError(
-                    f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+        if check_hit_names:
+            if len(hit_names) > 5:
+                missing = all_names - hit_names
+                missing_exclude_scales = {x for x in missing if "scale" not in x}
+                logger.info(
+                    f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
                 )
+                if len(missing_exclude_scales) > 0:
+                    raise ValueError(
+                        f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+                    )
 
-
-
+            elif len(hit_names) == 0:
+                raise ValueError(
+                    f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
+                )
 
         return hit_names
 
@@ -709,7 +1029,11 @@ class Grok1ForCausalLM(nn.Module):
             "moe_intermediate_size",
             getattr(cfg, "intermediate_size", None),
         )
-
+        residual_moe = getattr(cfg, "residual_moe", False)
+        if cfg.num_local_experts > 0:
+            num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
+        else:
+            num_experts = 1
 
         wq = (
             cfg.num_hidden_layers