sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama4.py

@@ -1,17 +1,24 @@
  import json as json_lib
  import logging
+ import math
  import os
  from collections.abc import Iterable
  from typing import List, Optional, Set, Tuple

  import torch
  from torch import nn
- from transformers import Llama4Config
+ from transformers import Llama4Config, Llama4VisionConfig
  from transformers.models.llama4.modeling_llama4 import (
  Llama4MultiModalProjector,
- Llama4VisionModel,
+ vision_apply_rotary_emb,
  )

+ from sglang.srt.layers.attention.vision import VisionAttention
+ from sglang.srt.layers.linear import (
+ ColumnParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization import QuantizationConfig
@@ -26,10 +33,10 @@ from sglang.srt.managers.schedule_batch import (
  global_server_args_dict,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import add_prefix, is_cpu
+ from sglang.srt.utils import is_cpu

  _is_cpu = is_cpu()
+
  from sglang.srt.model_loader.weight_utils import (
  default_weight_loader,
  maybe_remap_kv_scale_name,
@@ -39,6 +46,376 @@ from sglang.srt.utils import add_prefix
  logger = logging.getLogger(__name__)


+ class Llama4VisionMLP(nn.Module):
+
+ def __init__(
+ self,
+ input_size: int,
+ intermediate_size: int,
+ output_size: int,
+ bias: bool,
+ output_activation: bool,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
+ self.fc1 = cls_fc1(
+ input_size=input_size,
+ output_size=intermediate_size,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc1",
+ )
+ cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
+ self.fc2 = cls_fc2(
+ input_size=intermediate_size,
+ output_size=output_size,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
+ )
+ self.activation_fn = nn.GELU()
+ self.output_activation = output_activation
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states, _ = self.fc2(hidden_states)
+ if self.output_activation:
+ return self.activation_fn(hidden_states)
+ return hidden_states
+
+
+ def pixel_shuffle(input_tensor, shuffle_ratio):
+ # input_tensor: [batch_size, num_patches, channels]
+ batch_size, num_patches, channels = input_tensor.shape
+ patch_size = int(math.sqrt(num_patches))
+
+ input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
+ batch_size, height, width, channels = input_tensor.size()
+
+ reshaped_tensor = input_tensor.view(
+ batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+ )
+ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+
+ reshaped_tensor = reshaped_tensor.view(
+ batch_size,
+ int(height * shuffle_ratio),
+ int(width * shuffle_ratio),
+ int(channels / (shuffle_ratio**2)),
+ )
+ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+
+ output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
+ return output_tensor
+
+
+ class Llama4VisionPixelShuffleMLP(nn.Module):
+
+ def __init__(
+ self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
+ self.mlp = Llama4VisionMLP(
+ input_size=config.intermediate_size,
+ intermediate_size=config.projector_input_dim,
+ output_size=config.projector_output_dim,
+ bias=config.multi_modal_projector_bias,
+ output_activation=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=use_data_parallel,
+ )
+
+ def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
+ encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
+ return self.mlp(encoded_patches)
+
+
+ def apply_position_embedding(q, k, freqs_ci, shape):
+ # [batch_size_times_num_tiles, num_channels]
+ input_shape = shape[:2]
+ # [batch_size_times_num_tiles, num_channels, num_heads, head_dim]
+ hidden_shape = (*input_shape, *q.shape[-2:])
+ q = q.view(hidden_shape)
+ k = k.view(hidden_shape)
+ q, k = vision_apply_rotary_emb(q, k, freqs_ci)
+ return q, k
+
+
+ class Llama4VisionEncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.intermediate_size = config.intermediate_size
+
+ self.self_attn = VisionAttention(
+ self.hidden_size,
+ self.num_attention_heads,
+ self.hidden_size,
+ use_qkv_parallel=True,
+ # vision_model is explicitly ignored in Maverick-17B-128E-Instruct-FP8
+ quant_config=None,
+ dropout=0.0,
+ qkv_backend="sdpa",
+ softmax_in_single_precision=False,
+ flatten_batch=False,
+ prefix=add_prefix("self_attn", prefix),
+ qkv_bias=True,
+ customized_position_embedding_applier=apply_position_embedding,
+ )
+ self.mlp = Llama4VisionMLP(
+ input_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ output_size=config.hidden_size,
+ bias=True,
+ output_activation=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=use_data_parallel,
+ )
+
+ self.input_layernorm = nn.LayerNorm(config.hidden_size)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ freqs_ci: torch.Tensor,
+ ):
+ # Self Attention
+ residual = hidden_state
+ hidden_state = self.input_layernorm(hidden_state)
+ hidden_state = self.self_attn(hidden_state, position_embeddings=freqs_ci)
+ hidden_state = residual + hidden_state
+
+ # Feed forward
+ residual = hidden_state
+ hidden_state = self.post_attention_layernorm(hidden_state)
+ hidden_state = self.mlp(hidden_state)
+ hidden_state = residual + hidden_state
+
+ outputs = hidden_state
+ return outputs
+
+
+ class Llama4VisionEncoder(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Llama4VisionEncoderLayer(
+ config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ use_data_parallel=use_data_parallel,
+ )
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ freqs_ci: torch.Tensor, # TODO: move this to an attribute instead of keeping it around
+ ) -> torch.Tensor:
+ r"""
+ Args:
+ hidden_states (`torch.FloatTensor` of shape
+ `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to
+ directly pass an embedded representation. This is useful if you
+ want more control over how to convert `input_ids` indices into
+ associated vectors than the model's internal embedding
+ lookup matrix.
+ """
+
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(hidden_states, freqs_ci=freqs_ci)
+ hidden_states = layer_outputs
+
+ return hidden_states
+
+
+ class Llama4UnfoldConvolution(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ kernel_size = config.patch_size
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
+ params = {
+ "input_size": config.num_channels * kernel_size[0] * kernel_size[1],
+ "output_size": config.hidden_size,
+ "bias": False,
+ "quant_config": quant_config,
+ "prefix": f"{prefix}.linear",
+ }
+ if use_data_parallel:
+ cls = ReplicatedLinear
+ else:
+ cls = ColumnParallelLinear
+ params["gather_output"] = True
+ self.linear = cls(**params)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.unfold(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 1)
+ hidden_states, _ = self.linear(hidden_states)
+ return hidden_states
+
+
+ class Llama4VisionRotaryEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ idx = config.image_size // config.patch_size
+ img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
+ img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+ img_idx[-1, -1] = -2 # ID_CLS_TOKEN
+ frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x
+ frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y
+ freq_dim = config.hidden_size // config.num_attention_heads // 2
+ rope_freq = 1.0 / (
+ config.rope_theta
+ ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)
+ )
+ freqs_x = (
+ (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
+ ).repeat_interleave(2, dim=-1)
+ freqs_y = (
+ (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
+ ).repeat_interleave(2, dim=-1)
+ freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
+ freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+ freq_cis = torch.view_as_complex(
+ torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+ )
+ self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2
+
+ def forward(self, hidden_states):
+ return self.freqs_ci.to(hidden_states.device)
+
+
+ class Llama4VisionModel(nn.Module):
+
+ def __init__(
+ self,
+ config: Llama4VisionConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+ self.scale = config.hidden_size**-0.5
+
+ self.patch_embedding = Llama4UnfoldConvolution(
+ config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.patch_embedding",
+ )
+
+ self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
+ self.positional_embedding_vlm = nn.Parameter(
+ self.scale * torch.randn(self.num_patches, self.hidden_size)
+ )
+
+ self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
+
+ # layer norms
+ self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
+ self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
+
+ # encoders
+ self.model = Llama4VisionEncoder(
+ config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.model",
+ )
+ self.vision_adapter = Llama4VisionPixelShuffleMLP(
+ config,
+ quant_config,
+ prefix=f"{prefix}.vision_adapter",
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ ) -> torch.Tensor:
+ # Patch embedding
+ hidden_state = self.patch_embedding(pixel_values)
+ num_tiles, num_patches, hidden_dim = hidden_state.shape
+
+ # Add cls token
+ class_embedding = self.class_embedding.expand(
+ hidden_state.shape[0], 1, hidden_state.shape[-1]
+ )
+ hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
+ num_patches += 1
+
+ # Position embeddings
+ hidden_state = hidden_state.reshape(
+ num_tiles,
+ 1,
+ num_patches,
+ hidden_dim,
+ )
+ positional_embedding = self.positional_embedding_vlm.to(
+ dtype=hidden_state.dtype, device=hidden_state.device
+ )
+ hidden_state = hidden_state + positional_embedding
+ hidden_state = self.layernorm_pre(hidden_state)
+ hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
+ freqs_ci = self.rotary_embedding(pixel_values)
+ # Apply encoder
+ hidden_state = self.model(hidden_state, freqs_ci=freqs_ci)
+ hidden_state = self.layernorm_post(hidden_state)
+
+ # Remove CLS token output
+ hidden_state = hidden_state[:, :-1, :]
+
+ # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+ hidden_state = self.vision_adapter(hidden_state)
+
+ return hidden_state
+
+
  class Llama4ForConditionalGeneration(nn.Module):
  packed_modules_mapping = {
  "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -60,7 +437,8 @@ class Llama4ForConditionalGeneration(nn.Module):
  if not self.has_vision_weights:
  logger.warning(
  "No vision weights found in checkpoint. Model will run in text-only mode. "
- "Multimodal capabilities (image processing) will be unavailable."
+ "Multimodal capabilities (vision understanding) will be unavailable. "
+ "Please not that this warning might be inaccurate if the weights haven't been fully downloaded"
  )

  self.has_vision = (
@@ -68,7 +446,12 @@ class Llama4ForConditionalGeneration(nn.Module):
  )

  if self.has_vision:
- self.vision_model = Llama4VisionModel(config.vision_config)
+ self.vision_model = Llama4VisionModel(
+ config.vision_config,
+ quant_config=quant_config,
+ prefix=add_prefix("vision_model", prefix),
+ )
+
  self.multi_modal_projector = Llama4MultiModalProjector(config)
  else:
  self.vision_model = None
@@ -112,7 +495,6 @@ class Llama4ForConditionalGeneration(nn.Module):
  filename="model.safetensors.index.json",
  cache_dir=None,
  )
-
  if index_file_path and os.path.exists(index_file_path):
  return self._check_vision_weights_in_index(index_file_path)

@@ -120,7 +502,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  # If we can't access the cache, fall back to config-based detection
  pass

- # Fallback assume text-only
+ # Fallback, assume text-only
  return False

  def _check_vision_weights_in_index(self, index_file: str) -> bool:
@@ -131,7 +513,6 @@ class Llama4ForConditionalGeneration(nn.Module):

  vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
  weight_names = index_data.get("weight_map", {}).keys()
-
  return any(
  pattern in weight_name
  for weight_name in weight_names
@@ -150,17 +531,17 @@ class Llama4ForConditionalGeneration(nn.Module):
  # For text-only models, return None or raise an error
  if not self.has_vision or self.vision_model is None:
  raise ValueError("Vision model not available for text-only checkpoint")
-
  pixel_values = (
  torch.concat([item.feature for item in items])
  .to(next(self.vision_model.parameters()).device)
  .type(next(self.vision_model.parameters()).dtype)
  )
+ image_features = self.vision_model(pixel_values)

- image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
- image_features = image_outputs.last_hidden_state
  vision_flat = image_features.view(-1, image_features.size(-1))
+
  projected_vision_flat = self.multi_modal_projector(vision_flat)
+
  return projected_vision_flat

  def forward(
@@ -246,31 +627,47 @@ class Llama4ForConditionalGeneration(nn.Module):
  num_experts=num_experts,
  )

+ loaded_params = set()
+
  for name, loaded_weight in weights:
  if self._should_skip_weight(name):
  continue

  name = self._transform_weight_name(name)

- if "vision" not in name:
+ if "vision" in name:
+ name = name.replace(".self_attn.o_proj", ".self_attn.proj")
+ else:
  name, loaded_weight = self.permute_qk_weight_for_rotary(
  name, loaded_weight
  )

  if self._handle_scale_remapping(name, params_dict):
+ loaded_params.add(name)
  continue

  if self._handle_stacked_params(
- name, loaded_weight, stacked_params_mapping, params_dict
+ name, loaded_weight, stacked_params_mapping, params_dict, loaded_params
  ):
  continue

  if self._handle_expert_weights(
- name, loaded_weight, expert_params_mapping, params_dict, num_experts
+ name,
+ loaded_weight,
+ expert_params_mapping,
+ params_dict,
+ num_experts,
+ loaded_params,
  ):
  continue

+ loaded_params.add(name)
  self._handle_default_weight(name, loaded_weight, params_dict)
+ unloaded_params = params_dict.keys() - loaded_params
+ if unloaded_params:
+ logger.warning(
+ f"Some weights are not initialized from checkpoints {unloaded_params}"
+ )

  def _should_skip_weight(self, name: str) -> bool:
  """Check if we should skip loading this weight."""
@@ -301,11 +698,13 @@ class Llama4ForConditionalGeneration(nn.Module):
  loaded_weight: torch.Tensor,
  stacked_params_mapping: list,
  params_dict: dict,
+ loaded_params: set,
  ) -> bool:
  """Handle stacked parameter loading. Returns True if handled."""
  for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name in name and "vision" not in name:
+ if weight_name in name:
  transformed_name = name.replace(weight_name, param_name)
+ loaded_params.add(transformed_name)
  param = params_dict[transformed_name]
  param.weight_loader(param, loaded_weight, shard_id)
  return True
@@ -318,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  expert_params_mapping: list,
  params_dict: dict,
  num_experts: int,
+ loaded_params: set,
  ) -> bool:
  """Handle expert weight loading for MoE (Mixture of Experts) layers.

@@ -336,16 +736,16 @@ class Llama4ForConditionalGeneration(nn.Module):

  if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
  return self._handle_other_expert_params(
- name, loaded_weight, expert_params_mapping, params_dict
+ name, loaded_weight, expert_params_mapping, params_dict, loaded_params
  )

  if "scale" in name:
  return self._handle_expert_scale_params(
- name, loaded_weight, params_dict, num_experts
+ name, loaded_weight, params_dict, num_experts, loaded_params
  )
  else:
  return self._handle_expert_weight_params(
- name, loaded_weight, params_dict, num_experts
+ name, loaded_weight, params_dict, num_experts, loaded_params
  )

  def _handle_other_expert_params(
@@ -354,6 +754,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  loaded_weight: torch.Tensor,
  expert_params_mapping: list,
  params_dict: dict,
+ loaded_params: set,
  ) -> bool:
  """Handle expert parameters that are not gate_up_proj or down_proj weights.

@@ -362,6 +763,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  loaded_weight: The weight tensor to be loaded
  expert_params_mapping: List of tuples mapping checkpoint names to model parameters
  params_dict: Dictionary of model parameters
+ loaded_params: Set of loaded parameter names

  Returns:
  bool: True if parameter was found and handled, False otherwise
@@ -373,6 +775,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  param.weight_loader(
  param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
  )
+ loaded_params.add(transformed_name)
  return True
  return False

@@ -411,6 +814,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  loaded_weight: torch.Tensor,
  params_dict: dict,
  num_experts: int,
+ loaded_params: set,
  ) -> bool:
  """Handle quantization scale parameters for expert weights.

@@ -419,6 +823,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  loaded_weight: Scale tensor to be loaded
  params_dict: Dictionary of model parameters
  num_experts: Total number of experts for broadcast operations
+ loaded_params: Set of loaded parameter names

  Returns:
  bool: True (always handles scale parameters)
@@ -447,6 +852,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  # Load the same scale for all experts
  for expert_id in range(num_experts):
  param.data[expert_id] = loaded_weight
+ loaded_params.add(transformed_name)

  return True

@@ -456,6 +862,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  loaded_weight: torch.Tensor,
  params_dict: dict,
  num_experts: int,
+ loaded_params: set,
  ) -> bool:
  """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).

@@ -464,6 +871,7 @@ class Llama4ForConditionalGeneration(nn.Module):
  loaded_weight: Weight tensor(s) to be loaded
  params_dict: Dictionary of model parameters
  num_experts: Total number of experts for tensor distribution
+ loaded_params: Set of loaded parameter names

  Returns:
  bool: True (always handles weight parameters)
@@ -486,6 +894,7 @@ class Llama4ForConditionalGeneration(nn.Module):

  param = params_dict[param_name]
  weight_loader = param.weight_loader
+ loaded_params.add(param_name)

  # Handle the case where loaded_weight might be a single tensor for all experts
  if weight_chunk.dim() == 2:
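
The thread running through these load_weights changes is the new loaded_params set: every parameter name that actually receives a checkpoint tensor is recorded, and whatever remains in params_dict afterwards is reported. A minimal sketch of that bookkeeping pattern in isolation; load_weights_demo and the plain dict assignment below are illustrative stand-ins, not sglang APIs.

    import logging

    logger = logging.getLogger(__name__)

    def load_weights_demo(params_dict: dict, weights) -> None:
        # Track which model parameters received a tensor from the checkpoint.
        loaded_params = set()
        for name, tensor in weights:
            if name not in params_dict:
                continue                   # stand-in for the skip/transform logic above
            params_dict[name] = tensor     # stand-in for param.weight_loader(param, tensor, ...)
            loaded_params.add(name)
        # Anything never touched by the loop is flagged, matching the new warning.
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                f"Some weights are not initialized from checkpoints {unloaded_params}"
            )
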
sglang/srt/models/qwen2_moe.py

@@ -148,7 +148,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
  **(
  dict(
  enable_flashinfer_cutlass_moe=True,
- enable_ep_moe=global_server_args_dict["enable_ep_moe"],
  )
  if global_server_args_dict["enable_flashinfer_cutlass_moe"]
  else {}
@@ -616,9 +615,7 @@ class Qwen2MoeForCausalLM(nn.Module):
  ("gate_up_proj", "up_proj", 1),
  ]

- MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-
- expert_params_mapping = MoEImpl.make_expert_params_mapping(
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",