sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
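
Of the changes listed above, the largest single addition is the new Arcee (AFM) model implementation (sglang/srt/models/arcee.py, shown in full below). For orientation only, here is a minimal, hypothetical usage sketch via sglang's offline Engine API; the model path arcee-ai/AFM-4.5B and the sampling values are illustrative assumptions, not taken from this release.

# Hypothetical sketch (not part of the diff): exercising the newly added
# Arcee (AFM) support through sglang's offline Engine API. The model path
# and sampling values are illustrative assumptions.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="arcee-ai/AFM-4.5B")  # assumed HF repo id
    sampling_params = {"temperature": 0.0, "max_new_tokens": 32}
    outputs = llm.generate(["The capital of France is"], sampling_params)
    print(outputs[0]["text"])
    llm.shutdown()
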
sglang/srt/models/arcee.py
@@ -0,0 +1,532 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Arcee Foundational Model (AFM) compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class ArceeMLP(nn.Module):
+    """
+    MLP block for the Arcee model, using a ReLU-squared activation function.
+    This differs from the Llama SwiGLU activation.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        # Arcee uses a single up-projection, not a merged gate/up projection.
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "relu2":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Arcee model in SGLang only supports 'relu2'."
+            )
+        # The activation function is relu(x)^2
+        self.act_fn = get_act_fn("relu2")
+
+    def forward(self, x, forward_batch=None):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ArceeAttention(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(self.partial_rotary_factor * self.head_dim)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class ArceeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = ArceeAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=attention_bias,
+        )
+        self.mlp = ArceeMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ArceeModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: ArceeDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="model.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class ArceeForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        # Note: gate_proj is removed compared to Llama
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        # Note: gate_proj and up_proj are removed as they are not stacked in ArceeMLP
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+    }
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
+        # Arcee does not tie word embeddings
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # Parameters that are stacked in a single tensor in this model
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        self.capture_aux_hidden_states = False
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return ArceeModel(config, quant_config=quant_config, prefix=prefix)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                continue
+
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                is_stacked = True
+                break
+
+            if not is_stacked:
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in model.")

+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = [ArceeForCausalLM]
sglang/srt/models/deepseek_v2.py
@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -594,41 +595,13 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
+
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
 
         if shared_output is not None:
             x = shared_output
@@ -689,8 +662,7 @@ class DeepseekV2MoE(nn.Module):
 
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
@@ -703,46 +675,32 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                (
-                    state.hidden_states_experts_input,
-                    state.topk_idx_dispatched,
-                    state.topk_weights_dispatched,
-                    state.reorder_topk_ids,
-                    state.num_recv_tokens_per_expert,
-                    state.seg_indptr,
-                    state.masked_m,
-                    state.expected_m,
-                ) = self.deepep_dispatcher.dispatch_b(
+                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
                 )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.pop("hidden_states_experts_input"),
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )
 
     def op_output(self, state):
@@ -2155,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
sglang/srt/models/glm4_moe.py
@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
sglang/srt/models/granitemoe.py
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )