sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -16,7 +16,6 @@
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
  """Inference-only Grok1 model."""
  import functools
- import json
  import logging
  import math
  import os
@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
  tensor_model_parallel_all_gather,
  tensor_model_parallel_all_reduce,
  )
- from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
+ from sglang.srt.layers.activation import GeluAndMul
+ from sglang.srt.layers.elementwise import (
+ experts_combine_triton,
+ fused_dual_residual_rmsnorm,
+ fused_rmsnorm,
+ gelu_and_mul_triton,
+ )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
+ MergedColumnParallelLinear,
  QKVParallelLinear,
  ReplicatedLinear,
  RowParallelLinear,
@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
  from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.rotary_embedding import (
+ RotaryEmbedding,
+ _yarn_find_correction_range,
+ _yarn_get_mscale,
+ get_rope,
+ )
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.loader import DefaultModelLoader
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import dump_to_file
+ from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file

  logger = logging.getLogger(__name__)


+ # Dump tensors for debugging
  debug_tensor_dump_output_folder = None
+ debug_tensor_dump_prefill_only = False
+ # Skip all the other tensor dumps, only dump the target logits
+ debug_tensor_dump_only_target_logprobs = False
  debug_tensor_dump_inject = False
+ debug_tensor_dump_layers = None
+ debug_tensor_dump_test = False
+
+
+ class Grok1MLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ layer_id: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ reduce_results=True,
+ use_presharded_weights: bool = False,
+ split_gate_up: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("gate_up_proj", prefix),
+ use_presharded_weights=use_presharded_weights,
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("down_proj", prefix),
+ reduce_results=reduce_results,
+ use_presharded_weights=use_presharded_weights,
+ )
+ self.act_fn = GeluAndMul(approximate="tanh")
+ self.layer_id = layer_id
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x, _ = gelu_and_mul_triton(gate_up)
+ x, _ = self.down_proj(x)
+ return x


  class Grok1MoE(nn.Module):
@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
  params_dtype: Optional[torch.dtype] = None,
  quant_config: Optional[QuantizationConfig] = None,
  tp_size: Optional[int] = None,
- reduce_results=True,
+ reduce_results: bool = True,
  use_presharded_weights: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
+ prefix: str = "",
  ):
  super().__init__()
  self.hidden_size = hidden_size
@@ -135,7 +194,6 @@ class Grok1MoE(nn.Module):
  intermediate_size=intermediate_size,
  params_dtype=params_dtype,
  quant_config=quant_config,
- tp_size=tp_size,
  activation="gelu",
  **kwargs,
  )
@@ -146,6 +204,135 @@ class Grok1MoE(nn.Module):
  return self.experts(hidden_states, topk_output)


+ def _yarn_linear_ramp_mask(
+ low: float, high: float, dim: int, dtype: torch.dtype
+ ) -> torch.Tensor:
+ if low == high:
+ low -= 0.001 # Prevent singularity
+
+ linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+ ramp_func = torch.clamp(linear_func, 0, 1)
+ return ramp_func
+
+
+ def get_rope_scaling(config):
+ rope_type = getattr(config, "rope_type", None)
+ if rope_type:
+ original_max_position_embeddings = getattr(
+ config, "original_max_position_embeddings", None
+ )
+ scaling_factor = getattr(config, "scaling_factor", None)
+ extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
+ attn_factor = getattr(config, "attn_factor", 1.0)
+ beta_fast = getattr(config, "beta_fast", 32)
+ beta_slow = getattr(config, "beta_slow", 1)
+ rope_scaling = {
+ "extra_method": rope_type,
+ "max_position_embeddings": original_max_position_embeddings,
+ "scaling_factor": scaling_factor,
+ "extrapolation_factor": extrapolation_factor,
+ "attn_factor": attn_factor,
+ "beta_fast": beta_fast,
+ "beta_slow": beta_slow,
+ "dtype": torch.float,
+ }
+ return rope_scaling
+ else:
+ return None
+
+
+ class ScalingRotaryEmbedding(RotaryEmbedding):
+ """Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: int,
+ is_neox_style: bool,
+ scaling_factor: float,
+ dtype: torch.dtype,
+ *,
+ extra_method: str = "yarn_log",
+ extrapolation_factor: float = 1,
+ attn_factor: float = 1,
+ beta_fast: int = 32,
+ beta_slow: int = 1,
+ ) -> None:
+ self.scaling_factor = scaling_factor
+ self.extra_method = extra_method
+ self.extrapolation_factor = extrapolation_factor
+ self.attn_factor = attn_factor
+ self.beta_fast = beta_fast
+ self.beta_slow = beta_slow
+ # Get n-d magnitude scaling corrected for interpolation
+ self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
+ super().__init__(
+ head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+ )
+
+ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+ pos_freqs = self.base ** (
+ torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+ )
+ inv_freq_extrapolation = 1.0 / pos_freqs
+ inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+ low, high = _yarn_find_correction_range(
+ self.beta_fast,
+ self.beta_slow,
+ self.rotary_dim,
+ self.base,
+ self.max_position_embeddings,
+ )
+ # Get n-d rotational scaling corrected for extrapolation
+ inv_freq_mask = (
+ 1
+ - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+ ) * self.extrapolation_factor
+ if self.extra_method in ["original"]:
+ inv_freq = inv_freq_extrapolation
+ elif self.extra_method in ["yarn", "yarn_linear"]:
+ inv_freq = (
+ inv_freq_interpolation * (1 - inv_freq_mask)
+ + inv_freq_extrapolation * inv_freq_mask
+ )
+ elif self.extra_method == "yarn_log":
+ inv_freq = torch.exp(
+ torch.log(inv_freq_extrapolation) * inv_freq_mask
+ + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
+ )
+ elif self.extra_method == "theta_scale":
+ exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+ theta_scale_exponent = self.base ** (
+ math.log(
+ self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
+ )
+ / math.log(self.max_position_embeddings / (2 * math.pi))
+ )
+ inv_freq = torch.tensor(
+ 1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
+ dtype=torch.float32,
+ )
+ else:
+ raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
+ return inv_freq
+
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
+ inv_freq = self._compute_inv_freq(self.scaling_factor)
+ t = torch.arange(
+ self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
+ )
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
+ # cos = freqs.cos() * self.mscale
+ # sin = freqs.sin() * self.mscale
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cache = torch.cat((cos, sin), dim=-1)
+ return cache
+
+
  class Grok1Attention(nn.Module):
  def __init__(
  self,
@@ -158,7 +345,9 @@ class Grok1Attention(nn.Module):
  rope_theta: float = 10000,
  quant_config: Optional[QuantizationConfig] = None,
  reduce_results: bool = True,
+ alt_stream: Optional[torch.cuda.Stream] = None,
  load_presharded_attn: bool = False,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -184,7 +373,9 @@ class Grok1Attention(nn.Module):
  self.kv_size = self.num_kv_heads * self.head_dim
  self.scaling = self.head_dim**-0.5
  self.rope_theta = rope_theta
+ rope_scaling = get_rope_scaling(config)
  self.load_presharded_attn = load_presharded_attn
+ self.alt_stream = alt_stream or torch.cuda.Stream()

  self.qkv_proj = QKVParallelLinear(
  hidden_size,
@@ -196,6 +387,7 @@ class Grok1Attention(nn.Module):
  tp_rank=attn_tp_rank,
  tp_size=attn_tp_size,
  load_presharded_attn=self.load_presharded_attn,
+ prefix=add_prefix("qkv_proj", prefix),
  )
  self.o_proj = RowParallelLinear(
  self.total_num_heads * self.head_dim,
@@ -206,6 +398,7 @@ class Grok1Attention(nn.Module):
  tp_rank=attn_tp_rank,
  tp_size=attn_tp_size,
  use_presharded_weights=self.load_presharded_attn,
+ prefix=add_prefix("o_proj", prefix),
  )
  self.rotary_emb = get_rope(
  self.head_dim,
@@ -215,7 +408,37 @@ class Grok1Attention(nn.Module):
  is_neox_style=True,
  )

+ self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
+
+ if rope_scaling is not None:
+ self.rotary_emb = ScalingRotaryEmbedding(
+ self.head_dim,
+ rotary_dim=(
+ self.head_dim
+ if not self.rope_rotate_half_dims
+ else self.head_dim // 2
+ ),
+ base=int(self.rope_theta),
+ is_neox_style=True,
+ **rope_scaling,
+ )
+ pos_encoding_mode = "NONE"
+ else:
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=(
+ self.head_dim
+ if not self.rope_rotate_half_dims
+ else self.head_dim // 2
+ ),
+ max_position=max_position,
+ base=int(self.rope_theta),
+ is_neox_style=True,
+ )
+ pos_encoding_mode = "NONE"
+
  logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+ logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")

  self.attn = RadixAttention(
  self.num_heads,
@@ -225,7 +448,11 @@ class Grok1Attention(nn.Module):
  layer_id=layer_id,
  logit_cap=logit_cap,
  quant_config=quant_config,
+ pos_encoding_mode=pos_encoding_mode,
+ logit_capping_method=logit_capping_method,
+ prefix=add_prefix("attn", prefix),
  )
+ self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)

  def forward(
  self,
@@ -257,6 +484,8 @@ class Grok1Attention(nn.Module):
  )

  qkv, _ = self.qkv_proj(hidden_states)
+ dispose_tensor(hidden_states)
+
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)

@@ -289,6 +518,7 @@ class Grok1Attention(nn.Module):
  )

  attn_output = self.attn(q, k, v, forward_batch)
+ del q, k, v, qkv

  if debug_tensor_dump_output_folder:
  dump_to_file(
@@ -313,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
  load_presharded_moe: bool = False,
  load_presharded_attn: bool = False,
  load_presharded_mlp: bool = False,
+ alt_stream: Optional[torch.cuda.Stream] = None,
+ skip_moe: bool = False,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.num_experts = config.num_local_experts
  self.hidden_size = config.hidden_size
+ self.residual_moe = getattr(config, "residual_moe", False)
  self.layer_id = layer_id
+ self.alt_stream = alt_stream or torch.cuda.Stream()

  rope_theta = getattr(config, "rope_theta", 10000)
  self.self_attn = Grok1Attention(
  config=config,
  hidden_size=self.hidden_size,
  num_heads=config.num_attention_heads,
- max_position=config.max_position_embeddings,
+ max_position=(
+ config.context_len
+ if hasattr(config, "context_len")
+ else config.max_position_embeddings
+ ),
  num_kv_heads=config.num_key_value_heads,
  layer_id=layer_id,
  rope_theta=rope_theta,
  quant_config=quant_config,
  reduce_results=False,
+ alt_stream=self.alt_stream,
  load_presharded_attn=load_presharded_attn,
+ prefix=add_prefix("attn", prefix),
  )
- self.block_sparse_moe = Grok1MoE(
- config=config,
- layer_id=layer_id,
- num_experts=config.num_local_experts,
- top_k=config.num_experts_per_tok,
- hidden_size=config.hidden_size,
- intermediate_size=getattr(
- config,
- "moe_intermediate_size",
- getattr(config, "intermediate_size", None),
- ),
- quant_config=quant_config,
- reduce_results=True,
- use_presharded_weights=load_presharded_moe,
- inplace=True,
- no_combine=False, # just a suggestion to not combine topk
- )
+
+ split_gate_up = not getattr(config, "merge_gate_up", True)
+ if self.num_experts > 0:
+ self.block_sparse_moe = Grok1MoE(
+ config=config,
+ layer_id=layer_id,
+ num_experts=config.num_local_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=getattr(
+ config,
+ "moe_intermediate_size",
+ getattr(config, "intermediate_size", None),
+ ),
+ quant_config=quant_config,
+ reduce_results=not self.residual_moe,
+ use_presharded_weights=load_presharded_moe,
+ inplace=False, # not self.residual_moe,
+ no_combine=False, # self.residual_moe, # just a suggestion to not combine topk
+ prefix=add_prefix("block_sparse_moe", prefix),
+ )
+ if self.residual_moe:
+ self.mlp = Grok1MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ quant_config=quant_config,
+ reduce_results=False,
+ use_presharded_weights=load_presharded_mlp,
+ layer_id=layer_id,
+ split_gate_up=split_gate_up,
+ )
+ else:
+ raise NotImplementedError()

  self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

- self.ffn = self.block_sparse_moe
+ if self.num_experts > 0:
+ if self.residual_moe:
+ # NOTE: self.block_sparse_moe modifies the input in-place,
+ # so we have to call it later. Be aware of any possible related errors.
+ if get_tensor_model_parallel_world_size() > 1:
+ self.ffn = lambda x: tensor_model_parallel_all_reduce(
+ self.moe_with_rmoe(x)
+ )
+ else:
+ self.ffn = self.moe_with_rmoe
+ else:
+ self.ffn = self.block_sparse_moe
+ else:
+ raise NotImplementedError()

  def forward(
  self,
@@ -365,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
  residual: Optional[torch.Tensor] = None,
  deferred_norm: Optional[RMSNorm] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
+
+ hidden_states_original = hidden_states
+ residual_original = residual
+
  # Self Attention
  if deferred_norm is not None:
  assert residual is not None
@@ -387,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
  hidden_states,
  )

+ if residual_original is not None:
+ dispose_tensor(residual_original)
+
+ dispose_flag = False
+ if residual is not hidden_states_original:
+ dispose_flag = True
+ dispose_tensor(hidden_states_original)
+
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
@@ -404,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
  self.post_attn_norm.variance_epsilon,
  )

+ if not dispose_flag:
+ dispose_tensor(hidden_states_original)
+
  # Fully Connected
  hidden_states = self.ffn(hidden_states)
  return hidden_states, residual, self.post_moe_norm # defer layernorm

+ def moe_with_rmoe(self, x):
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ mlp_result = self.mlp(x)
+ with torch.cuda.stream(self.alt_stream):
+ # moe should not be inplace because of stream race condition
+ moe_result = self.block_sparse_moe(x)
+ current_stream.wait_stream(self.alt_stream)
+ return (mlp_result + moe_result) / 1.4142135623730951
+

  class Grok1Model(nn.Module):
  def __init__(
@@ -418,6 +713,8 @@ class Grok1Model(nn.Module):
  load_presharded_embedding: bool = False,
  load_presharded_attn: bool = False,
  load_presharded_mlp: bool = False,
+ replicate_embedding: bool = False,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -428,7 +725,11 @@ class Grok1Model(nn.Module):
  config.vocab_size,
  config.hidden_size,
  use_presharded_weights=load_presharded_embedding,
+ enable_tp=not replicate_embedding,
+ prefix=add_prefix("embed_tokens", prefix),
  )
+
+ self.alt_stream = torch.cuda.Stream()
  self.layers = nn.ModuleList(
  [
  Grok1DecoderLayer(
@@ -438,6 +739,7 @@ class Grok1Model(nn.Module):
  load_presharded_moe=load_presharded_moe,
  load_presharded_attn=load_presharded_attn,
  load_presharded_mlp=load_presharded_mlp,
+ alt_stream=self.alt_stream,
  )
  for i in range(config.num_hidden_layers)
  ]
@@ -507,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -515,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
  # Get presharded weights.
  self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
  self.load_presharded_moe = (
- self.config.num_local_experts > 0
+ getattr(config, "load_presharded_moe", True)
+ and self.config.num_local_experts > 0
  and get_tensor_model_parallel_world_size() > 1
  )
  self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
@@ -530,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
  or self.load_presharded_embedding
  )

+ default_replicate_lm_head = False
+ self.replicate_lm_head = getattr(
+ config, "replicate_lm_head", default_replicate_lm_head
+ )
+
  if self.is_weights_presharded:
  setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

@@ -537,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
  self.replicate_lm_head = getattr(
  config, "replicate_lm_head", default_replicate_lm_head
  )
+ self.replicate_embedding = getattr(config, "replicate_embedding", False)

  self.model = Grok1Model(
  config,
@@ -545,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
  load_presharded_embedding=self.load_presharded_embedding,
  load_presharded_attn=self.load_presharded_attn,
  load_presharded_mlp=self.load_presharded_mlp,
+ replicate_embedding=self.replicate_embedding,
+ prefix=add_prefix("model", prefix),
  )

  lm_head_params_dtype = None
@@ -554,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
  config.vocab_size,
  bias=False,
  params_dtype=lm_head_params_dtype,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
  else:
@@ -562,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
  config.hidden_size,
  use_presharded_weights=self.load_presharded_embedding,
  params_dtype=lm_head_params_dtype,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config)

@@ -578,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
  f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
  f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
  )
+ self.loaded_param_names = set()

  def forward(
  self,
@@ -597,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
  def load_weights(
  self,
  weights: Iterable[Tuple[str, torch.Tensor]],
- num_experts: Optional[int] = None,
  ignore_parent_name: bool = False,
+ check_hit_names: bool = True,
+ model_config: PretrainedConfig | None = None,
  ) -> dict[str, torch.Tensor]:
- if num_experts is None:
- num_experts = self.config.num_local_experts
+ if model_config is None:
+ model_config = self.config
+
  stacked_params_mapping = []
  stacked_params_mapping += [
  # (param_name, shard_name, shard_id)
@@ -617,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
+ num_experts = model_config.num_local_experts
  expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="w1",
  ckpt_down_proj_name="w2",
@@ -631,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
  def load_weight_wrapper(
  name: str, loaded_weight: torch.Tensor, *args, **kwargs
  ):
- if ignore_parent_name:
- name = name.split(".")[-1]
-
- if name not in params_dict:
- return
-
  # Fuse constant multipliers into the weights
  if "lm_head" in name:
  loaded_weight = (
  loaded_weight.to(torch.float32)
- * self.config.output_multiplier_scale
+ * model_config.output_multiplier_scale
  )

+ original_name = name
+ if ignore_parent_name:
+ name = name.split(".")[-1]
+
+ if name not in params_dict:
+ logger.info(f"Skipping {name=} in load_weights_wrapper")
+ return
+
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight, *args, **kwargs)
  hit_names.add(name)
+ self.loaded_param_names.add(original_name)

  for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name:
@@ -686,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module):

  load_weight_wrapper(name=name, loaded_weight=loaded_weight)

- if len(hit_names) > 5:
- missing = all_names - hit_names
- missing_exclude_scales = {x for x in missing if "scale" not in x}
- logger.info(
- f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
- )
- if len(missing_exclude_scales) > 0:
- raise ValueError(
- f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+ if check_hit_names:
+ if len(hit_names) > 5:
+ missing = all_names - hit_names
+ missing_exclude_scales = {x for x in missing if "scale" not in x}
+ logger.info(
+ f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
  )
+ if len(missing_exclude_scales) > 0:
+ raise ValueError(
+ f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+ )

- elif len(hit_names) == 0:
- raise ValueError("load_weights failed because it did not hit any names.")
+ elif len(hit_names) == 0:
+ raise ValueError(
+ f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
+ )

  return hit_names

@@ -709,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
  "moe_intermediate_size",
  getattr(cfg, "intermediate_size", None),
  )
- num_experts = cfg.num_local_experts
+ residual_moe = getattr(cfg, "residual_moe", False)
+ if cfg.num_local_experts > 0:
+ num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
+ else:
+ num_experts = 1

  wq = (
  cfg.num_hidden_layers