sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/nvila.py
@@ -0,0 +1,355 @@
+ import itertools
+ import math
+ from collections.abc import Iterable
+ from typing import Any
+
+ import einops
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+ from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+ import sglang.srt.managers.mm_utils as mm_utils
+ import sglang.srt.model_loader.weight_utils as weight_utils
+ import sglang.srt.utils as utils
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+ from sglang.srt.managers.schedule_batch import (
+     Modality,
+     MultimodalDataItem,
+     MultimodalInputs,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+ MM_HIDDEN_SIZE = 3456
+
+
+ class NVILAConfig(PretrainedConfig):
+     model_type = "nvila"
+     sub_configs = {
+         "text_config": Qwen2Config,
+         "vision_config": SiglipVisionConfig,
+     }
+     _auto_class = "AutoConfig"
+
+     def __init__(
+         self,
+         *,
+         text_config: dict[str, Any] | None = None,
+         vision_config: dict[str, Any] | None = None,
+         image_token_id: int | None = None,
+         video_token_id: int | None = None,
+         **kwargs,
+     ):
+         self.text_config = (
+             Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
+         )
+         self.vision_config = (
+             SiglipVisionConfig(**vision_config)
+             if vision_config is not None
+             else SiglipVisionConfig()
+         )
+
+         self.image_token_id = image_token_id if image_token_id is not None else -1
+         self.video_token_id = video_token_id if video_token_id is not None else -1
+
+         super().__init__(**kwargs)
+
+
+ class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
+     def forward(self, x: Tensor) -> Tensor:
+         batch_size, sequence_length, hidden_size = x.shape
+
+         feat_size = math.isqrt(sequence_length)
+
+         features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+         pad_after = feat_size % 2
+         if pad_after > 0:
+             features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+             feat_size = feat_size + pad_after
+
+         features = features.reshape(
+             batch_size, feat_size // 2, 2, feat_size // 2, 2, hidden_size
+         )
+         features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+         features = features.reshape(batch_size, -1, 4 * hidden_size)
+
+         return features
+
+
+ class NVILAMultiModalProjector(nn.Module):
+     def __init__(self, config: NVILAConfig):
+         super().__init__()
+
+         self.layers = nn.Sequential(
+             NVILAMultiModalProjectorDownsampleBlock(),
+             nn.LayerNorm(MM_HIDDEN_SIZE * 4),
+             nn.Linear(MM_HIDDEN_SIZE * 4, config.text_config.hidden_size),
+             nn.GELU(),
+             nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.layers(x)
+
+
+ class NVILAForConditionalGeneration(nn.Module):
+     def __init__(
+         self,
+         config: NVILAConfig,
+         quant_config: QuantizationConfig | None = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+
+         self.vision_tower = SiglipVisionModel(config.vision_config)
+         self.mm_projector = NVILAMultiModalProjector(config)
+         self.llm = Qwen2ForCausalLM(
+             config=config.text_config,
+             quant_config=quant_config,
+             prefix=utils.add_prefix("llm", prefix),
+         )
+
+     def forward(
+         self,
+         input_ids: Tensor,
+         positions: Tensor,
+         forward_batch: ForwardBatch,
+         get_embedding: bool = False,
+     ) -> LogitsProcessorOutput:
+         output = mm_utils.general_mm_embed_routine(
+             input_ids=input_ids,
+             forward_batch=forward_batch,
+             language_model=self.llm,
+             data_embedding_funcs={
+                 Modality.IMAGE: self.get_image_feature,
+                 Modality.VIDEO: self.get_image_feature,
+             },
+             get_embedding=get_embedding,
+             positions=positions,
+         )
+
+         assert isinstance(output, LogitsProcessorOutput)
+
+         return output
+
+     def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
+         block_sizes = (
+             list(
+                 itertools.chain.from_iterable(
+                     x.block_sizes for x in mm_input if hasattr(x, "block_sizes")
+                 )
+             )
+             or None
+         )
+         pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
+
+         vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
+             pixel_values.to(
+                 device=self.vision_tower.device, dtype=self.vision_tower.dtype
+             ),
+             output_hidden_states=True,
+         )
+         assert vision_tower_output.hidden_states is not None
+
+         vision_features: Tensor = vision_tower_output.hidden_states[-2]
+
+         vision_features_list, block_sizes = merge_features_for_dynamic_s2(
+             vision_features,
+             block_sizes=(
+                 block_sizes
+                 if block_sizes is not None
+                 else [None] * vision_features.shape[0]
+             ),
+             resize_output_to_scale_idx=-1,
+             scales=[448, 896, 1344],
+         )
+
+         vision_features_list = [
+             split_chessboard(x, block_size[0], block_size[1])
+             for x, block_size in zip(vision_features_list, block_sizes)
+         ]
+
+         vision_features = torch.cat(
+             [einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list]
+         )
+
+         vision_features = self.mm_projector(vision_features)
+
+         vision_features_list = list(
+             vision_features.split(
+                 [block_size[0] * block_size[1] for block_size in block_sizes], dim=0
+             )
+         )
+         vision_features_list = [
+             merge_chessboard(x, block_size[0], block_size[1])
+             for x, block_size in zip(vision_features_list, block_sizes)
+         ]
+
+         vision_features = torch.stack(
+             [einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list]
+         )
+
+         vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
+
+         return vision_features
+
+     def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
+         params_dict = dict(self.named_parameters())
+
+         for name, loaded_weight in weights:
+             if name.startswith("llm."):
+                 self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+             else:
+                 param = params_dict[name]
+                 weight_loader = getattr(
+                     param, "weight_loader", weight_utils.default_weight_loader
+                 )
+                 weight_loader(param, loaded_weight)
+
+     def pad_input_ids(
+         self, input_ids: list[int], mm_inputs: MultimodalInputs
+     ) -> list[int]:
+         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+         return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+
+ def merge_chessboard(x, num_split_h, num_split_w):
+     """
+     x: b * n * c or b * h * w * c
+     out: b * c * h * w
+     Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square.
+     """
+     B = x.shape[0]
+     if x.dim() == 3:
+         N = x.shape[1]
+         x = einops.rearrange(
+             x, "b (h w) c -> b c h w", h=math.isqrt(N), w=math.isqrt(N)
+         )
+
+     assert B % (num_split_h * num_split_w) == 0
+     b = B // (num_split_h * num_split_w)
+
+     x_merge = torch.cat(
+         [
+             torch.cat(
+                 [
+                     x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b]
+                     for j in range(num_split_w)
+                 ],
+                 dim=-1,
+             )
+             for i in range(num_split_h)
+         ],
+         dim=-2,
+     )
+
+     return x_merge
+
+
+ def merge_features_for_dynamic_s2(
+     image_features, block_sizes, *, scales, resize_output_to_scale_idx
+ ):
+     image_features_each_image = []
+     new_block_sizes = []
+     block_cnt = 0
+     for block_size_each_image in block_sizes:
+         if block_size_each_image is None:
+             cur_features = image_features[block_cnt : block_cnt + 1]
+             cur_features = einops.rearrange(
+                 cur_features,
+                 "1 (h w) c -> 1 c h w",
+                 h=math.isqrt(cur_features.shape[1]),
+             )
+             cur_features = cur_features.repeat(1, len(scales), 1, 1)
+             image_features_each_image.append(cur_features)
+             new_block_sizes.append((1, 1))
+             block_cnt += 1
+         else:
+             cur_features_each_scale = []
+             for scale in scales[:-1]:
+                 num_blocks_this_scale = (scale // scales[0]) ** 2
+                 cur_features_each_scale.append(
+                     merge_chessboard(
+                         image_features[block_cnt : block_cnt + num_blocks_this_scale],
+                         num_split_h=scale // scales[0],
+                         num_split_w=scale // scales[0],
+                     )
+                 )  # 1 * C * H * W
+                 block_cnt += num_blocks_this_scale
+             num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
+             cur_features_each_scale.append(
+                 merge_chessboard(
+                     image_features[block_cnt : block_cnt + num_blocks_last_scale],
+                     num_split_h=block_size_each_image[0],
+                     num_split_w=block_size_each_image[1],
+                 )
+             )  # 1 * C * H * W
+             block_cnt += num_blocks_last_scale
+
+             # resize and concat features from different scales
+             output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
+             cur_features = torch.cat(
+                 [
+                     F.interpolate(
+                         cur_features_each_scale[i].to(torch.float32),
+                         size=output_size,
+                         mode="area",
+                     ).to(cur_features_each_scale[i].dtype)
+                     for i in range(len(cur_features_each_scale))
+                 ],
+                 dim=1,
+             )
+
+             image_features_each_image.append(cur_features)
+
+             if (
+                 resize_output_to_scale_idx == len(scales) - 1
+                 or resize_output_to_scale_idx == -1
+             ):
+                 new_block_sizes.append(block_size_each_image)
+             else:
+                 new_block_sizes.append(
+                     (
+                         scales[resize_output_to_scale_idx] // scales[0],
+                         scales[resize_output_to_scale_idx] // scales[0],
+                     )
+                 )
+
+     assert block_cnt == len(
+         image_features
+     ), f"The number of blocks ({block_cnt}) does not match length of image_features ({len(image_features)})!"
+
+     return image_features_each_image, new_block_sizes
+
+
+ def split_chessboard(x, num_split_h, num_split_w):
+     """
+     x: b * c * h * w
+     out: b * c * h * w
+     Dividing x into num_split**2 sub-squares, and concatenate all the sub-squares on the batch dimension
+     """
+     B, C, H, W = x.shape
+     assert H % num_split_h == 0 and W % num_split_w == 0
+     h, w = H // num_split_h, W // num_split_w
+     x_split = torch.cat(
+         [
+             x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w]
+             for i in range(num_split_h)
+             for j in range(num_split_w)
+         ],
+         dim=0,
+     )
+     return x_split
+
+
+ EntryClass = [NVILAForConditionalGeneration]
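
For orientation (not part of the package), here is a minimal standalone round-trip check of the chessboard tiling that `split_chessboard` and `merge_chessboard` above implement. It is a sketch assuming only PyTorch and re-implements the two helpers rather than importing them:

```python
# Sketch of the chessboard tiling used by nvila.py's dynamic S2 merging (assumes only torch).
# split: cut a (B, C, H, W) map into h_splits * w_splits tiles stacked on the batch dim.
# merge: invert the split, reassembling the tiles into the original map.
import torch


def split(x: torch.Tensor, h_splits: int, w_splits: int) -> torch.Tensor:
    B, C, H, W = x.shape
    h, w = H // h_splits, W // w_splits
    return torch.cat(
        [
            x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w]
            for i in range(h_splits)
            for j in range(w_splits)
        ],
        dim=0,
    )


def merge(x: torch.Tensor, h_splits: int, w_splits: int) -> torch.Tensor:
    b = x.shape[0] // (h_splits * w_splits)
    rows = [
        torch.cat(
            [x[(i * w_splits + j) * b : (i * w_splits + j + 1) * b] for j in range(w_splits)],
            dim=-1,  # stitch one row of tiles back together along width
        )
        for i in range(h_splits)
    ]
    return torch.cat(rows, dim=-2)  # stack the rows along height


x = torch.randn(2, 3, 8, 8)
tiles = split(x, 2, 2)                     # (8, 3, 4, 4): four tiles per image on the batch dim
assert torch.equal(merge(tiles, 2, 2), x)  # round trip restores the original map
```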
sglang/srt/models/nvila_lite.py
@@ -0,0 +1,184 @@
+ import math
+ from collections.abc import Iterable
+ from typing import Any
+
+ import einops
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+ from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+ import sglang.srt.managers.mm_utils as mm_utils
+ import sglang.srt.model_loader.weight_utils as weight_utils
+ import sglang.srt.utils as utils
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+ from sglang.srt.managers.schedule_batch import (
+     Modality,
+     MultimodalDataItem,
+     MultimodalInputs,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+ MM_HIDDEN_SIZE = 1152
+
+
+ class NVILALiteConfig(PretrainedConfig):
+     model_type = "nvila_lite"
+     sub_configs = {
+         "text_config": Qwen2Config,
+         "vision_config": SiglipVisionConfig,
+     }
+     _auto_class = "AutoConfig"
+
+     def __init__(
+         self,
+         *,
+         text_config: dict[str, Any] | None = None,
+         vision_config: dict[str, Any] | None = None,
+         image_token_id: int | None = None,
+         video_token_id: int | None = None,
+         **kwargs,
+     ):
+         self.text_config = (
+             Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
+         )
+         self.vision_config = (
+             SiglipVisionConfig(**vision_config)
+             if vision_config is not None
+             else SiglipVisionConfig()
+         )
+
+         self.image_token_id = image_token_id if image_token_id is not None else -1
+         self.video_token_id = video_token_id if video_token_id is not None else -1
+
+         super().__init__(**kwargs)
+
+
+ class NVILALiteMultiModalProjectorDownsampleBlock(nn.Module):
+     def forward(self, x: Tensor) -> Tensor:
+         batch_size, sequence_length, hidden_size = x.shape
+
+         feat_size = math.isqrt(sequence_length)
+
+         features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+         pad_after = (3 - feat_size % 3) % 3
+         if pad_after > 0:
+             features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+             feat_size = feat_size + pad_after
+
+         features = features.reshape(
+             batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
+         )
+         features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+         features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+         return features
+
+
+ class NVILALiteMultiModalProjector(nn.Module):
+     def __init__(self, config: NVILALiteConfig):
+         super().__init__()
+
+         self.layers = nn.Sequential(
+             NVILALiteMultiModalProjectorDownsampleBlock(),
+             nn.LayerNorm(MM_HIDDEN_SIZE * 9),
+             nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
+             nn.GELU(),
+             nn.LayerNorm(MM_HIDDEN_SIZE * 3),
+             nn.Linear(MM_HIDDEN_SIZE * 3, config.text_config.hidden_size),
+             nn.GELU(),
+             nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.layers(x)
+
+
+ class NVILALiteForConditionalGeneration(nn.Module):
+     def __init__(
+         self,
+         config: NVILALiteConfig,
+         quant_config: QuantizationConfig | None = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+
+         self.vision_tower = SiglipVisionModel(config.vision_config)
+         self.mm_projector = NVILALiteMultiModalProjector(config)
+         self.llm = Qwen2ForCausalLM(
+             config=config.text_config,
+             quant_config=quant_config,
+             prefix=utils.add_prefix("llm", prefix),
+         )
+
+     def forward(
+         self,
+         input_ids: Tensor,
+         positions: Tensor,
+         forward_batch: ForwardBatch,
+         get_embedding: bool = False,
+     ) -> LogitsProcessorOutput:
+         output = mm_utils.general_mm_embed_routine(
+             input_ids=input_ids,
+             forward_batch=forward_batch,
+             language_model=self.llm,
+             data_embedding_funcs={
+                 Modality.IMAGE: self.get_image_feature,
+                 Modality.VIDEO: self.get_image_feature,
+             },
+             get_embedding=get_embedding,
+             positions=positions,
+         )
+
+         assert isinstance(output, LogitsProcessorOutput)
+
+         return output
+
+     def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
+         pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
+
+         vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
+             pixel_values,
+             output_hidden_states=True,
+         )
+         assert vision_tower_output.hidden_states is not None
+
+         vision_features = vision_tower_output.hidden_states[-2]
+
+         vision_features = self.mm_projector(vision_features)
+
+         vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
+
+         return vision_features
+
+     def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
+         params_dict = dict(self.named_parameters())
+
+         for name, loaded_weight in weights:
+             if name.startswith("llm."):
+                 self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+             else:
+                 param = params_dict[name]
+                 weight_loader = getattr(
+                     param, "weight_loader", weight_utils.default_weight_loader
+                 )
+                 weight_loader(param, loaded_weight)
+
+     def pad_input_ids(
+         self, input_ids: list[int], mm_inputs: MultimodalInputs
+     ) -> list[int]:
+         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+         return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+
+ EntryClass = [NVILALiteForConditionalGeneration]
sglang/srt/models/qwen2.py
@@ -49,6 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
      default_weight_loader,
      kv_cache_scales_loader,
  )
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.utils import add_prefix, make_layers

  Qwen2Config = None
@@ -89,6 +90,9 @@ class Qwen2MLP(nn.Module):
          self.act_fn = SiluAndMul()

      def forward(self, x):
+         if get_global_server_args().rl_on_policy_target == "fsdp":
+             x = x.bfloat16()
+
          gate_up, _ = self.gate_up_proj(x)
          x = self.act_fn(gate_up)
          x, _ = self.down_proj(x)
@@ -275,6 +279,11 @@ class Qwen2Model(nn.Module):
                  quant_config=quant_config,
                  enable_tp=not is_dp_attention_enabled(),
                  prefix=add_prefix("embed_tokens", prefix),
+                 params_dtype=(
+                     torch.float32
+                     if get_global_server_args().rl_on_policy_target == "fsdp"
+                     else None
+                 ),
              )
          else:
              self.embed_tokens = PPMissingLayer()
@@ -295,7 +304,19 @@ class Qwen2Model(nn.Module):
              prefix=add_prefix("layers", prefix),
          )
          if self.pp_group.is_last_rank:
-             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+             norm_kwargs = (
+                 dict(
+                     weight_dtype=torch.float32,
+                     cast_x_before_out_mul=True,
+                     override_orig_dtype=torch.float32,
+                     fp32_residual=True,
+                 )
+                 if get_global_server_args().rl_on_policy_target == "fsdp"
+                 else {}
+             )
+             self.norm = RMSNorm(
+                 config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+             )
          else:
              self.norm = PPMissingLayer(return_tuple=True)

sglang/srt/models/qwen3.py
@@ -29,6 +29,7 @@ from sglang.srt.model_loader.weight_utils import (
  )
  from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
  from sglang.srt.models.qwen2 import Qwen2Model
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.utils import (
      add_prefix,
      get_cmo_stream,
@@ -88,8 +89,16 @@ class Qwen3Attention(nn.Module):
          self.max_position_embeddings = max_position_embeddings
          self.tp_rank = get_tensor_model_parallel_rank()

-         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+         norm_kwargs = (
+             dict(
+                 weight_dtype=torch.float32,
+                 cast_x_before_out_mul=True,
+             )
+             if get_global_server_args().rl_on_policy_target == "fsdp"
+             else {}
+         )
+         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)

          self.qkv_proj = QKVParallelLinear(
              hidden_size,
@@ -158,10 +167,18 @@ class Qwen3Attention(nn.Module):
          hidden_states: torch.Tensor,
          forward_batch: ForwardBatch,
      ) -> torch.Tensor:
+         if get_global_server_args().rl_on_policy_target == "fsdp":
+             hidden_states = hidden_states.bfloat16()
+
          qkv, _ = self.qkv_proj(hidden_states)
          q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
          q, k = self._apply_qk_norm(q, k)
          q, k = self.rotary_emb(positions, q, k)
+
+         if get_global_server_args().rl_on_policy_target == "fsdp":
+             q = q.to(torch.bfloat16)
+             k = k.to(torch.bfloat16)
+
          attn_output = self.attn(q, k, v, forward_batch)
          output, _ = self.o_proj(attn_output)
          return output
@@ -204,9 +221,22 @@ class Qwen3DecoderLayer(nn.Module):
              quant_config=quant_config,
              prefix=add_prefix("mlp", prefix),
          )
-         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         norm_kwargs = (
+             dict(
+                 weight_dtype=torch.float32,
+                 cast_x_before_out_mul=True,
+                 override_orig_dtype=torch.float32,
+                 fp32_residual=True,
+             )
+             if get_global_server_args().rl_on_policy_target == "fsdp"
+             else {}
+         )
+         self.input_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+         )
          self.post_attention_layernorm = RMSNorm(
-             config.hidden_size, eps=config.rms_norm_eps
+             config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
          )

          self.layer_scatter_modes = LayerScatterModes.init_new(
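
The qwen2.py and qwen3.py hunks above gate extra numerics handling behind `rl_on_policy_target == "fsdp"`: activations are forced to bfloat16 around the MLP and attention, while the RMSNorm layers are constructed with `weight_dtype=torch.float32` and `cast_x_before_out_mul=True`. As a rough illustration only, here is a plain-PyTorch sketch of what an fp32-weight RMSNorm with that cast order can look like; it is not sglang's implementation, and the exact semantics of those kwargs may differ:

```python
# Rough sketch: an RMSNorm that keeps its weight in fp32, normalizes in fp32, and casts the
# normalized activations back to the input dtype before the output multiply.
import torch
import torch.nn as nn


class Fp32WeightRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        h = x.float()                                                # normalize in fp32
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.eps)
        h = h.to(orig_dtype)                                         # cast back before the weight multiply
        return (h * self.weight).to(orig_dtype)                      # fp32 weight, output in input dtype


norm = Fp32WeightRMSNorm(8)
out = norm(torch.randn(2, 8, dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16
```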
sglang/srt/models/qwen3_moe.py
@@ -241,16 +241,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
          )

      def op_experts(self, state):
-         state.hidden_states_experts_output = self.experts.run_moe_core(
+         state.combine_input = self.experts.run_moe_core(
              dispatch_output=state.dispatch_output,
          )

      def op_combine_a(self, state):
          if self.ep_size > 1:
              self.experts.dispatcher.combine_a(
-                 hidden_states=state.pop("hidden_states_experts_output"),
-                 topk_ids=state.dispatch_output.topk_ids,
-                 topk_weights=state.dispatch_output.topk_weights,
+                 combine_input=state.pop("combine_input"),
                  tbo_subbatch_index=state.get("tbo_subbatch_index"),
              )
              state.pop("dispatch_output")
sglang/srt/multimodal/processors/base_processor.py
@@ -185,6 +185,7 @@ class BaseMultimodalProcessor(ABC):
          "aspect_ratio_mask": Modality.IMAGE,
          "num_patches": Modality.IMAGE,
          "patch_pixel_values": Modality.IMAGE,
+         "block_sizes": Modality.IMAGE,
          # Audio-related attributes
          "audio_features": Modality.AUDIO,
          "audio_feature_lens": Modality.AUDIO,
sglang/srt/multimodal/processors/glm4v.py
@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
      def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
          super().__init__(hf_config, server_args, _processor, *args, **kwargs)

-         # GLM-4.1V and GLM-4.5V specific tokens
+         # GLM-V specific tokens
          self.IMAGE_TOKEN = "<|image|>"
          self.VIDEO_TOKEN = "<|video|>"
          self.IMAGE_START_TOKEN = "<|begin_of_image|>"