sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
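Items 4, 19, 69, and 71 above add Step3-VL support end to end (config, tool-call detector, model implementation, and multimodal processor); the full text of the new model file is reproduced in the diff below. As a rough usage sketch only — the checkpoint name `stepfun-ai/step3`, the image path, and the `image_data` argument are assumptions based on how other vision-language models are driven in sglang, not something this diff confirms:

```python
# Hypothetical offline-engine sketch; model path and image handling are assumptions.
import sglang as sgl

if __name__ == "__main__":
    # Engine() accepts the same arguments as launch_server; tp_size depends on hardware.
    llm = sgl.Engine(
        model_path="stepfun-ai/step3",  # assumed checkpoint name
        tp_size=8,
        trust_remote_code=True,
    )

    out = llm.generate(
        prompt="Describe this image.",
        sampling_params={"temperature": 0.0, "max_new_tokens": 128},
        # image_data follows the pattern used by other sglang VLMs (path/URL/base64).
        image_data="example.jpg",
    )
    print(out["text"])

    llm.shutdown()
```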
sglang/srt/models/step3_vl.py (new file)
@@ -0,0 +1,994 @@
import logging
import math
from collections.abc import Iterable
from math import sqrt
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from transformers import PretrainedConfig
from transformers.activations import ACT2FN

from sglang.srt.configs.step3_vl import (
    Step3TextConfig,
    Step3VisionEncoderConfig,
    Step3VLConfig,
)
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
    global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, log_info_on_rank0, make_layers

logger = logging.getLogger(__name__)


"""
Text Model
"""


class Step3TextMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Step3TextMoEMLP(nn.Module):
    # Native
    def __init__(
        self,
        layer_id: int,
        config: Step3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        if self.tp_size > config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}."
            )

        self.topk = TopK(
            top_k=config.moe_top_k,
            renormalize=config.norm_expert_weight,
            use_grouped_topk=False,
        )

        self.experts = get_moe_impl_class()(
            num_experts=config.moe_num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            output_size=config.moe_num_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("gate", prefix),
        )

        if global_server_args_dict["enable_deepep_moe"]:
            raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits, _ = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_dim)


class Step3TextAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        share_q_dim: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        rms_norm_eps=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.all_tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        self.attn_tp_rank = attn_tp_rank
        self.layer_id = layer_id
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
        self.head_dim = head_dim
        self.q_size = share_q_dim if share_q_dim else head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = MergedColumnParallelLinear(
            hidden_size,
            [self.q_size, self.kv_size, self.kv_size],
            bias=False,
            quant_config=quant_config,
            tp_rank=0,  # In fact, we need a MergedReplicatedLinear
            tp_size=1,
            prefix=add_prefix("qkv_proj", prefix),
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
            prefix=add_prefix("o_proj", prefix),
        )

        self.inter_norm = RMSNorm(self.q_size, eps=rms_norm_eps)

        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            prefix=add_prefix("wq", prefix),
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = self.inter_norm(q.contiguous())
        q, _ = self.wq(q)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Step3TextDecoderLayer(nn.Module):
    def __init__(
        self,
        config: Step3TextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        # TODO: support shared experts fusion
        # self.n_shared_experts = 1
        # self.num_fused_shared_experts = (
        #     0
        #     if global_server_args_dict["disable_shared_experts_fusion"]
        #     else self.n_shared_experts
        # )
        self.num_fused_shared_experts = 0
        rms_norm_eps = config.rms_norm_eps
        self.self_attn = Step3TextAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=1,
            head_dim=head_dim,
            share_q_dim=config.share_q_dim,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=rms_norm_eps,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )

        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is not None:
            moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
        else:
            # Default to 1dense.
            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]

        self.use_moe = False

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_id = layer_id
        self.is_layer_sparse = True if layer_id in moe_layers_idx else False
        self.is_previous_layer_sparse = (
            True if layer_id - 1 in moe_layers_idx else False
        )

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=self.is_previous_layer_sparse,
        )

        if not self.is_layer_sparse:
            self.mlp = Step3TextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.use_moe = True
            if self.num_fused_shared_experts == 0:
                self.moe = Step3TextMoEMLP(
                    layer_id=layer_id,
                    config=config,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp", prefix),
                )
                self.share_expert = Step3TextMLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=config.share_expert_dim,
                    hidden_act="silu",
                    quant_config=quant_config,
                    prefix=add_prefix("share_expert", prefix),
                )
            else:
                self.moe = Step3TextMoEMLP(
                    layer_id=layer_id,
                    config=config,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp", prefix),
                )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def moe_mlp_forward(self, hidden_states):
        if not self.num_fused_shared_experts:
            h = hidden_states.clone()
            hidden_states = self.moe(hidden_states)
            hidden_states += self.share_expert(h)
        else:
            hidden_states = self.moe(hidden_states)
        return hidden_states

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )
        if self.use_moe:
            hidden_states = self.moe_mlp_forward(hidden_states)
        else:
            hidden_states = self.mlp(hidden_states)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual


class Step3TextModel(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
            prefix=add_prefix("embed_tokens", prefix),
        )

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Step3TextDecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )

        if hidden_states.shape[0] != 0:
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


"""
Vision Model
"""


def get_abs_pos(abs_pos, tgt_size):
    dim = abs_pos.size(-1)
    abs_pos_new = abs_pos.squeeze(0)
    cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]

    src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        old_pos_embed = (
            old_pos_embed.view(1, src_size, src_size, dim)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        old_pos_embed = old_pos_embed.to(torch.float32)
        new_pos_embed = F.interpolate(
            old_pos_embed,
            size=(tgt_size, tgt_size),
            mode="bicubic",
            antialias=True,
            align_corners=False,
        ).to(dtype)
        new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
        new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
        vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
        vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
        return vision_pos_embed
    else:
        return abs_pos


class Step3VisionMLP(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_size: int,
        bias: bool = True,
        hidden_act="quick_gelu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            dim,
            intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_proj", prefix),
        )
        self.act = ACT2FN[hidden_act]  # quick_gelu
        self.fc2 = RowParallelLinear(
            intermediate_size,
            dim,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class Step3VisionAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        qkv_backend="fa3",
        quant_config=None,
        prefix: str = "",
    ) -> None:

        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.out_proj = RowParallelLinear(
            dim,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("out_proj", prefix),
        )
        self.scale = self.head_dim**-0.5

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        attn_output = self.attn(hidden_states)
        return attn_output


class Step3VisionEmbeddings(nn.Module):

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.pad_tp_size = 4  # hard code for padding
        # To load the pretrained weights, we still use P+1 as the seqlen
        self.position_embedding = torch.nn.Embedding(
            self.num_patches + 1, self.embed_dim
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_patches + 1).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        patch_embeds = self.patch_embedding(
            pixel_values
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        # pad
        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + get_abs_pos(
            self.position_embedding(self.position_ids), patch_embeds.size(1)
        )
        embeddings = torch.cat(
            [
                embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1),
                embeddings,
            ],
            dim=1,
        )
        return embeddings


class Step3VisionEncoderLayer(nn.Module):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = LayerNorm(self.embed_dim, eps=1e-6)
        self.layer_norm2 = LayerNorm(self.embed_dim, eps=1e-6)

        self.self_attn = Step3VisionAttention(
            self.embed_dim, num_heads=config.num_attention_heads
        )
        self.mlp = Step3VisionMLP(
            dim=self.embed_dim,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states))
        hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states))
        return hidden_states


class Step3VisionTransformer(nn.Module):
    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.embeddings = Step3VisionEmbeddings(config)
        self.transformer = Step3VisionEncoder(config)

    @property
    def dtype(self) -> torch.dtype:
        return self.embeddings.patch_embedding.weight.dtype

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.transformer(inputs_embeds=hidden_states)
        return hidden_states


class Step3VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Step3VisionEncoderLayer`].

    Args:
        config: StepVisionEncoderConfig
    """

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [Step3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        inputs_embeds,
    ) -> torch.Tensor:

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
            )

        return hidden_states


class Step3VLForConditionalGeneration(nn.Module):

    def __init__(
        self,
        config: Step3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Step3TextModel(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )

        self.vision_model = Step3VisionTransformer(config.vision_config)

        self.vit_downsampler = nn.Conv2d(
            config.vision_config.hidden_size,
            config.vision_config.output_hidden_size,
            kernel_size=2,
            stride=config.understand_projector_stride,
        )
        self.vit_downsampler2 = nn.Conv2d(
            config.vision_config.output_hidden_size,
            config.vision_config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.vit_large_projector = nn.Linear(
            config.vision_config.output_hidden_size * 2,
            config.hidden_size,
            bias=config.projector_bias,
        )

        # TODO: support shared experts fusion
        # self.n_shared_experts = 1
        # self.num_fused_shared_experts = (
        #     0
        #     if global_server_args_dict["disable_shared_experts_fusion"]
        #     else self.n_shared_experts
        # )
        self.num_fused_shared_experts = 0
        self.config.tie_word_embeddings = False
        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.text_config.vocab_size,
                config.text_config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config.text_config)

    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.vision_model(input_tensor)[:, 4:]

    def _flatten_embeddings(self, embeddings) -> torch.Tensor:

        if isinstance(embeddings, torch.Tensor):
            # Flatten all but the last dimension.
            return embeddings.flatten(0, -2)

        return torch.cat(tuple(self._flatten_embeddings(t) for t in embeddings))

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        assert len(items) == 1  # We only have images.

        item = items[0]
        pixel_values = item.feature.type(self.vision_model.dtype)
        num_patches = item.model_specific_data.get("num_patches")
        patch_pixel_values = item.model_specific_data.get("patch_pixel_values", None)
        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.type(self.vision_model.dtype)

        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.to("cuda")

        image_features = self._get_vision_model_output(pixel_values)
        patch_image_features = (
            self._get_vision_model_output(patch_pixel_values)
            if patch_pixel_values is not None
            else None
        )

        image_features = self._process_image_features(image_features)
        patch_image_features = (
            self._process_image_features(patch_image_features)
            if patch_image_features is not None
            else None
        )

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx : cur_patch_idx + num_patch
                ]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
            )
        return self._flatten_embeddings(merged_image_features)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            positions=positions,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # TODO:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", 0),
            (".qkv_proj", ".k_proj", 1),
            (".qkv_proj", ".v_proj", 2),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.text_config.moe_num_experts
            + self.num_fused_shared_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params = set()

        def match_expert_and_shard_ids(name_path: str, weight_path: str) -> bool:
            name_parts = name_path.split(".")
            weight_parts = weight_path.split(".")
            shard_id_matches = name_parts[4] == weight_parts[2]
            return shard_id_matches

        for name, loaded_weight in weights:
            if "vision_model" in name:
                # 1.It's not great, but let's leave it like this for now
                name = name.replace("self_attn", "self_attn.attn")
                # 2.
                name = name.replace("out_proj", "proj")

            # TODO: support vision model
            if self.num_fused_shared_experts > 0 and "share" in name:
                # assert False
                name = name.replace("share_expert", "moe")
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if (
                        expert_id != self.config.text_config.moe_num_experts
                        or not match_expert_and_shard_ids(name, weight_name)
                    ):
                        continue

                    part_name = weight_name.split(".")[-2]
                    fake_weight_name = name.replace(part_name, weight_name[:-1])
                    actual_param_name = name.replace(part_name + ".", param_name)
                    param = params_dict[actual_param_name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "gate." not in name and "moe" in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if "moe" not in name:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
                else:
                    if "gate." in name:
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
                        continue

                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if expert_id == self.config.text_config.moe_num_experts:
                            continue
                        if not match_expert_and_shard_ids(name, weight_name):
                            continue
                        part_name = weight_name.split(".")[-2]
                        fake_weight_name = name.replace(part_name, weight_name[:-1])
                        actual_param_name = name.replace(part_name + ".", param_name)
                        param = params_dict[actual_param_name]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight[expert_id],
                            name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                        loaded_params.add(actual_param_name)
                        # Don't break here, because this 'loaded_weight' includes all the weights for this layer

    @classmethod
    def get_model_config_for_expert_location(cls, config: Step3VLConfig):
        return ModelConfigForExpertLocation(
            num_layers=config.text_config.num_hidden_layers,
            num_logical_experts=config.text_config.moe_num_experts,
            num_groups=None,
        )


EntryClass = Step3VLForConditionalGeneration
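The `EntryClass` assignment on the final line is the hook that lets sglang's model registry associate the `Step3VLForConditionalGeneration` architecture name from a checkpoint's `config.json` with this implementation. The sketch below is only an illustration of that convention, not sglang's actual registry code; it assumes every module under `sglang.srt.models` is importable in the current environment:

```python
# Minimal sketch of the EntryClass discovery convention (assumption: illustrative only).
import importlib
import pkgutil
from typing import Dict, Type

import torch.nn as nn


def collect_entry_classes(package: str = "sglang.srt.models") -> Dict[str, Type[nn.Module]]:
    """Import each model module and index its EntryClass export by class name."""
    registry: Dict[str, Type[nn.Module]] = {}
    pkg = importlib.import_module(package)
    for info in pkgutil.iter_modules(pkg.__path__):
        module = importlib.import_module(f"{package}.{info.name}")
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            continue
        # EntryClass may be a single class or a list of classes.
        classes = entry if isinstance(entry, (list, tuple)) else [entry]
        for cls in classes:
            registry[cls.__name__] = cls
    return registry


# A checkpoint whose config lists architectures = ["Step3VLForConditionalGeneration"]
# would then resolve to the class defined in the new step3_vl.py:
# model_cls = collect_entry_classes()["Step3VLForConditionalGeneration"]
```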