sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
CHANGED
@@ -16,7 +16,6 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 import functools
-import json
 import logging
 import math
 import os
@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.
+from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.elementwise import (
+    experts_combine_triton,
+    fused_dual_residual_rmsnorm,
+    fused_rmsnorm,
+    gelu_and_mul_triton,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import
+from sglang.srt.layers.rotary_embedding import (
+    RotaryEmbedding,
+    _yarn_find_correction_range,
+    _yarn_get_mscale,
+    get_rope,
+)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file

 logger = logging.getLogger(__name__)


+# Dump tensors for debugging
 debug_tensor_dump_output_folder = None
+debug_tensor_dump_prefill_only = False
+# Skip all the other tensor dumps, only dump the target logits
+debug_tensor_dump_only_target_logprobs = False
 debug_tensor_dump_inject = False
+debug_tensor_dump_layers = None
+debug_tensor_dump_test = False
+
+
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+        use_presharded_weights: bool = False,
+        split_gate_up: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+        self.layer_id = layer_id
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x, _ = gelu_and_mul_triton(gate_up)
+        x, _ = self.down_proj(x)
+        return x


 class Grok1MoE(nn.Module):
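Note on the Grok1MLP added above: it is a standard gated MLP in which gate_up_proj produces a concatenated [gate, up] projection, gelu_and_mul_triton applies a tanh-approximated GELU to the gate half and multiplies it with the up half, and down_proj maps the result back to the hidden size. As a sketch in math notation (the first-half-gate / second-half-up split is assumed from the usual sglang convention, not stated in this diff):

    y = W_{\mathrm{down}}\bigl(\operatorname{GELU}_{\tanh}(W_{\mathrm{gate}} x) \odot (W_{\mathrm{up}} x)\bigr)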
@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
-        reduce_results=True,
+        reduce_results: bool = True,
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -135,7 +194,6 @@ class Grok1MoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             activation="gelu",
             **kwargs,
         )
@@ -146,6 +204,135 @@ class Grok1MoE(nn.Module):
         return self.experts(hidden_states, topk_output)


+def _yarn_linear_ramp_mask(
+    low: float, high: float, dim: int, dtype: torch.dtype
+) -> torch.Tensor:
+    if low == high:
+        low -= 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def get_rope_scaling(config):
+    rope_type = getattr(config, "rope_type", None)
+    if rope_type:
+        original_max_position_embeddings = getattr(
+            config, "original_max_position_embeddings", None
+        )
+        scaling_factor = getattr(config, "scaling_factor", None)
+        extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
+        attn_factor = getattr(config, "attn_factor", 1.0)
+        beta_fast = getattr(config, "beta_fast", 32)
+        beta_slow = getattr(config, "beta_slow", 1)
+        rope_scaling = {
+            "extra_method": rope_type,
+            "max_position_embeddings": original_max_position_embeddings,
+            "scaling_factor": scaling_factor,
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow,
+            "dtype": torch.float,
+        }
+        return rope_scaling
+    else:
+        return None
+
+
+class ScalingRotaryEmbedding(RotaryEmbedding):
+    """Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extra_method: str = "yarn_log",
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extra_method = extra_method
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base ** (
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+        )
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            self.rotary_dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1
+            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+        ) * self.extrapolation_factor
+        if self.extra_method in ["original"]:
+            inv_freq = inv_freq_extrapolation
+        elif self.extra_method in ["yarn", "yarn_linear"]:
+            inv_freq = (
+                inv_freq_interpolation * (1 - inv_freq_mask)
+                + inv_freq_extrapolation * inv_freq_mask
+            )
+        elif self.extra_method == "yarn_log":
+            inv_freq = torch.exp(
+                torch.log(inv_freq_extrapolation) * inv_freq_mask
+                + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
+            )
+        elif self.extra_method == "theta_scale":
+            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+            theta_scale_exponent = self.base ** (
+                math.log(
+                    self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
+                )
+                / math.log(self.max_position_embeddings / (2 * math.pi))
+            )
+            inv_freq = torch.tensor(
+                1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
+                dtype=torch.float32,
+            )
+        else:
+            raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(
+            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
+        )
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        # cos = freqs.cos() * self.mscale
+        # sin = freqs.sin() * self.mscale
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class Grok1Attention(nn.Module):
     def __init__(
         self,
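Note on ScalingRotaryEmbedding added above: _compute_inv_freq blends an unscaled ("extrapolation") frequency 1/θ_i with a position-interpolated frequency 1/(s·θ_i), where θ_i = base^{2i/d_rot}, s is scaling_factor, and m_i = (1 − ramp_i)·extrapolation_factor is the per-dimension mask from _yarn_linear_ramp_mask. In the notation of this sketch (mine, derived from the code in the hunk):

    f_i^{\mathrm{yarn}} = \frac{1 - m_i}{s\,\theta_i} + \frac{m_i}{\theta_i},
    \qquad
    f_i^{\mathrm{yarn\_log}} = \exp\!\Bigl(m_i \ln\tfrac{1}{\theta_i} + (1 - m_i)\ln\tfrac{1}{s\,\theta_i}\Bigr)

so "yarn"/"yarn_linear" blend the two frequencies linearly and "yarn_log" blends them in log space, while "original" keeps 1/θ_i unchanged and "theta_scale" rescales the RoPE base instead.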
@@ -158,7 +345,9 @@ class Grok1Attention(nn.Module):
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         load_presharded_attn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -184,7 +373,9 @@ class Grok1Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        rope_scaling = get_rope_scaling(config)
         self.load_presharded_attn = load_presharded_attn
+        self.alt_stream = alt_stream or torch.cuda.Stream()

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -196,6 +387,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             load_presharded_attn=self.load_presharded_attn,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -206,6 +398,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             use_presharded_weights=self.load_presharded_attn,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -215,7 +408,37 @@ class Grok1Attention(nn.Module):
             is_neox_style=True,
         )

+        self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
+
+        if rope_scaling is not None:
+            self.rotary_emb = ScalingRotaryEmbedding(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                base=int(self.rope_theta),
+                is_neox_style=True,
+                **rope_scaling,
+            )
+            pos_encoding_mode = "NONE"
+        else:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                max_position=max_position,
+                base=int(self.rope_theta),
+                is_neox_style=True,
+            )
+            pos_encoding_mode = "NONE"
+
         logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+        logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")

         self.attn = RadixAttention(
             self.num_heads,
@@ -225,7 +448,11 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
             quant_config=quant_config,
+            pos_encoding_mode=pos_encoding_mode,
+            logit_capping_method=logit_capping_method,
+            prefix=add_prefix("attn", prefix),
         )
+        self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)

     def forward(
         self,
@@ -257,6 +484,8 @@ class Grok1Attention(nn.Module):
             )

         qkv, _ = self.qkv_proj(hidden_states)
+        dispose_tensor(hidden_states)
+
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)

@@ -289,6 +518,7 @@ class Grok1Attention(nn.Module):
             )

         attn_output = self.attn(q, k, v, forward_batch)
+        del q, k, v, qkv

         if debug_tensor_dump_output_folder:
             dump_to_file(
@@ -313,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
         load_presharded_moe: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_moe: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.residual_moe = getattr(config, "residual_moe", False)
         self.layer_id = layer_id
+        self.alt_stream = alt_stream or torch.cuda.Stream()

         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            max_position=
+            max_position=(
+                config.context_len
+                if hasattr(config, "context_len")
+                else config.max_position_embeddings
+            ),
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
             reduce_results=False,
+            alt_stream=self.alt_stream,
             load_presharded_attn=load_presharded_attn,
+            prefix=add_prefix("attn", prefix),
         )
-
-
-
-
-
-
-
-            config,
-
-            getattr(
-
-
-
-
-
-
-
+
+        split_gate_up = not getattr(config, "merge_gate_up", True)
+        if self.num_experts > 0:
+            self.block_sparse_moe = Grok1MoE(
+                config=config,
+                layer_id=layer_id,
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=getattr(
+                    config,
+                    "moe_intermediate_size",
+                    getattr(config, "intermediate_size", None),
+                ),
+                quant_config=quant_config,
+                reduce_results=not self.residual_moe,
+                use_presharded_weights=load_presharded_moe,
+                inplace=False,  # not self.residual_moe,
+                no_combine=False,  # self.residual_moe, # just a suggestion to not combine topk
+                prefix=add_prefix("block_sparse_moe", prefix),
+            )
+            if self.residual_moe:
+                self.mlp = Grok1MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.intermediate_size,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    use_presharded_weights=load_presharded_mlp,
+                    layer_id=layer_id,
+                    split_gate_up=split_gate_up,
+                )
+        else:
+            raise NotImplementedError()

         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.
+        if self.num_experts > 0:
+            if self.residual_moe:
+                # NOTE: self.block_sparse_moe modifies the input in-place,
+                # so we have to call it later. Be aware of any possible related errors.
+                if get_tensor_model_parallel_world_size() > 1:
+                    self.ffn = lambda x: tensor_model_parallel_all_reduce(
+                        self.moe_with_rmoe(x)
+                    )
+                else:
+                    self.ffn = self.moe_with_rmoe
+            else:
+                self.ffn = self.block_sparse_moe
+        else:
+            raise NotImplementedError()

     def forward(
         self,
@@ -365,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor] = None,
         deferred_norm: Optional[RMSNorm] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
+
+        hidden_states_original = hidden_states
+        residual_original = residual
+
         # Self Attention
         if deferred_norm is not None:
             assert residual is not None
@@ -387,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
                 hidden_states,
             )

+        if residual_original is not None:
+            dispose_tensor(residual_original)
+
+        dispose_flag = False
+        if residual is not hidden_states_original:
+            dispose_flag = True
+            dispose_tensor(hidden_states_original)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -404,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
                 self.post_attn_norm.variance_epsilon,
             )

+        if not dispose_flag:
+            dispose_tensor(hidden_states_original)
+
         # Fully Connected
         hidden_states = self.ffn(hidden_states)
         return hidden_states, residual, self.post_moe_norm  # defer layernorm

+    def moe_with_rmoe(self, x):
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        mlp_result = self.mlp(x)
+        with torch.cuda.stream(self.alt_stream):
+            # moe should not be inplace because of stream race condition
+            moe_result = self.block_sparse_moe(x)
+        current_stream.wait_stream(self.alt_stream)
+        return (mlp_result + moe_result) / 1.4142135623730951
+

 class Grok1Model(nn.Module):
     def __init__(
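Note on moe_with_rmoe added above: it overlaps the dense Grok1MLP (run on the current CUDA stream) with the sparse MoE (run on alt_stream) and then averages the two branches; the constant 1.4142135623730951 is √2, so per layer

    \mathrm{ffn}(x) = \frac{\mathrm{mlp}(x) + \mathrm{moe}(x)}{\sqrt{2}}

with a tensor_model_parallel_all_reduce applied on top when the tensor-parallel world size is greater than 1, per the self.ffn wiring in the constructor hunk.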
@@ -418,6 +713,8 @@ class Grok1Model(nn.Module):
         load_presharded_embedding: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        replicate_embedding: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -428,7 +725,11 @@ class Grok1Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             use_presharded_weights=load_presharded_embedding,
+            enable_tp=not replicate_embedding,
+            prefix=add_prefix("embed_tokens", prefix),
         )
+
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 Grok1DecoderLayer(
@@ -438,6 +739,7 @@ class Grok1Model(nn.Module):
                     load_presharded_moe=load_presharded_moe,
                     load_presharded_attn=load_presharded_attn,
                     load_presharded_mlp=load_presharded_mlp,
+                    alt_stream=self.alt_stream,
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -507,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -515,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
         # Get presharded weights.
         self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
         self.load_presharded_moe = (
-
+            getattr(config, "load_presharded_moe", True)
+            and self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
         )
         self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
@@ -530,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
             or self.load_presharded_embedding
         )

+        default_replicate_lm_head = False
+        self.replicate_lm_head = getattr(
+            config, "replicate_lm_head", default_replicate_lm_head
+        )
+
         if self.is_weights_presharded:
             setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

@@ -537,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
         self.replicate_lm_head = getattr(
             config, "replicate_lm_head", default_replicate_lm_head
         )
+        self.replicate_embedding = getattr(config, "replicate_embedding", False)

         self.model = Grok1Model(
             config,
@@ -545,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
             load_presharded_embedding=self.load_presharded_embedding,
             load_presharded_attn=self.load_presharded_attn,
             load_presharded_mlp=self.load_presharded_mlp,
+            replicate_embedding=self.replicate_embedding,
+            prefix=add_prefix("model", prefix),
         )

         lm_head_params_dtype = None
@@ -554,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.vocab_size,
                 bias=False,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
@@ -562,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.hidden_size,
                 use_presharded_weights=self.load_presharded_embedding,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)

@@ -578,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
                 f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
                 f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
             )
+        self.loaded_param_names = set()

     def forward(
         self,
@@ -597,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        num_experts: Optional[int] = None,
         ignore_parent_name: bool = False,
+        check_hit_names: bool = True,
+        model_config: PretrainedConfig | None = None,
     ) -> dict[str, torch.Tensor]:
-        if
-
+        if model_config is None:
+            model_config = self.config
+
         stacked_params_mapping = []
         stacked_params_mapping += [
             # (param_name, shard_name, shard_id)
@@ -617,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
+        num_experts = model_config.num_local_experts
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
@@ -631,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
         def load_weight_wrapper(
             name: str, loaded_weight: torch.Tensor, *args, **kwargs
         ):
-            if ignore_parent_name:
-                name = name.split(".")[-1]
-
-            if name not in params_dict:
-                return
-
             # Fuse constant multipliers into the weights
             if "lm_head" in name:
                 loaded_weight = (
                     loaded_weight.to(torch.float32)
-                    *
+                    * model_config.output_multiplier_scale
                 )

+            original_name = name
+            if ignore_parent_name:
+                name = name.split(".")[-1]
+
+            if name not in params_dict:
+                logger.info(f"Skipping {name=} in load_weights_wrapper")
+                return
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight, *args, **kwargs)
             hit_names.add(name)
+            self.loaded_param_names.add(original_name)

         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -686,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module):

             load_weight_wrapper(name=name, loaded_weight=loaded_weight)

-        if
-
-
-
-
-
-            if len(missing_exclude_scales) > 0:
-                raise ValueError(
-                    f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+        if check_hit_names:
+            if len(hit_names) > 5:
+                missing = all_names - hit_names
+                missing_exclude_scales = {x for x in missing if "scale" not in x}
+                logger.info(
+                    f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
                 )
+                if len(missing_exclude_scales) > 0:
+                    raise ValueError(
+                        f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+                    )

-
-
+            elif len(hit_names) == 0:
+                raise ValueError(
+                    f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
+                )

         return hit_names

@@ -709,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
             "moe_intermediate_size",
             getattr(cfg, "intermediate_size", None),
         )
-
+        residual_moe = getattr(cfg, "residual_moe", False)
+        if cfg.num_local_experts > 0:
+            num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
+        else:
+            num_experts = 1

         wq = (
             cfg.num_hidden_layers
|