sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/nvila.py (new file)
@@ -0,0 +1,355 @@
+import itertools
+import math
+from collections.abc import Iterable
+from typing import Any
+
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+MM_HIDDEN_SIZE = 3456
+
+
+class NVILAConfig(PretrainedConfig):
+    model_type = "nvila"
+    sub_configs = {
+        "text_config": Qwen2Config,
+        "vision_config": SiglipVisionConfig,
+    }
+    _auto_class = "AutoConfig"
+
+    def __init__(
+        self,
+        *,
+        text_config: dict[str, Any] | None = None,
+        vision_config: dict[str, Any] | None = None,
+        image_token_id: int | None = None,
+        video_token_id: int | None = None,
+        **kwargs,
+    ):
+        self.text_config = (
+            Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
+        )
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config is not None
+            else SiglipVisionConfig()
+        )
+
+        self.image_token_id = image_token_id if image_token_id is not None else -1
+        self.video_token_id = video_token_id if video_token_id is not None else -1
+
+        super().__init__(**kwargs)
+
+
+class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = math.isqrt(sequence_length)
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = feat_size % 2
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 2, 2, feat_size // 2, 2, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 4 * hidden_size)
+
+        return features
+
+
+class NVILAMultiModalProjector(nn.Module):
+    def __init__(self, config: NVILAConfig):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            NVILAMultiModalProjectorDownsampleBlock(),
+            nn.LayerNorm(MM_HIDDEN_SIZE * 4),
+            nn.Linear(MM_HIDDEN_SIZE * 4, config.text_config.hidden_size),
+            nn.GELU(),
+            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layers(x)
+
+
+class NVILAForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: NVILAConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+        self.mm_projector = NVILAMultiModalProjector(config)
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.VIDEO: self.get_image_feature,
+            },
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        assert isinstance(output, LogitsProcessorOutput)
+
+        return output
+
+    def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
+        block_sizes = (
+            list(
+                itertools.chain.from_iterable(
+                    x.block_sizes for x in mm_input if hasattr(x, "block_sizes")
+                )
+            )
+            or None
+        )
+        pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
+            pixel_values.to(
+                device=self.vision_tower.device, dtype=self.vision_tower.dtype
+            ),
+            output_hidden_states=True,
+        )
+        assert vision_tower_output.hidden_states is not None
+
+        vision_features: Tensor = vision_tower_output.hidden_states[-2]
+
+        vision_features_list, block_sizes = merge_features_for_dynamic_s2(
+            vision_features,
+            block_sizes=(
+                block_sizes
+                if block_sizes is not None
+                else [None] * vision_features.shape[0]
+            ),
+            resize_output_to_scale_idx=-1,
+            scales=[448, 896, 1344],
+        )
+
+        vision_features_list = [
+            split_chessboard(x, block_size[0], block_size[1])
+            for x, block_size in zip(vision_features_list, block_sizes)
+        ]
+
+        vision_features = torch.cat(
+            [einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list]
+        )
+
+        vision_features = self.mm_projector(vision_features)
+
+        vision_features_list = list(
+            vision_features.split(
+                [block_size[0] * block_size[1] for block_size in block_sizes], dim=0
+            )
+        )
+        vision_features_list = [
+            merge_chessboard(x, block_size[0], block_size[1])
+            for x, block_size in zip(vision_features_list, block_sizes)
+        ]
+
+        vision_features = torch.stack(
+            [einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list]
+        )
+
+        vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
+
+        return vision_features
+
+    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self, input_ids: list[int], mm_inputs: MultimodalInputs
+    ) -> list[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+
+def merge_chessboard(x, num_split_h, num_split_w):
+    """
+    x: b * n * c or b * h * w * c
+    out: b * c * h * w
+    Assuming x contains num_split_h * num_split_w sub-squares concatenated along the batch dimension, merge the sub-squares back into the original whole square.
+    """
+    B = x.shape[0]
+    if x.dim() == 3:
+        N = x.shape[1]
+        x = einops.rearrange(
+            x, "b (h w) c -> b c h w", h=math.isqrt(N), w=math.isqrt(N)
+        )
+
+    assert B % (num_split_h * num_split_w) == 0
+    b = B // (num_split_h * num_split_w)
+
+    x_merge = torch.cat(
+        [
+            torch.cat(
+                [
+                    x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b]
+                    for j in range(num_split_w)
+                ],
+                dim=-1,
+            )
+            for i in range(num_split_h)
+        ],
+        dim=-2,
+    )
+
+    return x_merge
+
+
+def merge_features_for_dynamic_s2(
+    image_features, block_sizes, *, scales, resize_output_to_scale_idx
+):
+    image_features_each_image = []
+    new_block_sizes = []
+    block_cnt = 0
+    for block_size_each_image in block_sizes:
+        if block_size_each_image is None:
+            cur_features = image_features[block_cnt : block_cnt + 1]
+            cur_features = einops.rearrange(
+                cur_features,
+                "1 (h w) c -> 1 c h w",
+                h=math.isqrt(cur_features.shape[1]),
+            )
+            cur_features = cur_features.repeat(1, len(scales), 1, 1)
+            image_features_each_image.append(cur_features)
+            new_block_sizes.append((1, 1))
+            block_cnt += 1
+        else:
+            cur_features_each_scale = []
+            for scale in scales[:-1]:
+                num_blocks_this_scale = (scale // scales[0]) ** 2
+                cur_features_each_scale.append(
+                    merge_chessboard(
+                        image_features[block_cnt : block_cnt + num_blocks_this_scale],
+                        num_split_h=scale // scales[0],
+                        num_split_w=scale // scales[0],
+                    )
+                )  # 1 * C * H * W
+                block_cnt += num_blocks_this_scale
+            num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
+            cur_features_each_scale.append(
+                merge_chessboard(
+                    image_features[block_cnt : block_cnt + num_blocks_last_scale],
+                    num_split_h=block_size_each_image[0],
+                    num_split_w=block_size_each_image[1],
+                )
+            )  # 1 * C * H * W
+            block_cnt += num_blocks_last_scale
+
+            # resize and concat features from different scales
+            output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
+            cur_features = torch.cat(
+                [
+                    F.interpolate(
+                        cur_features_each_scale[i].to(torch.float32),
+                        size=output_size,
+                        mode="area",
+                    ).to(cur_features_each_scale[i].dtype)
+                    for i in range(len(cur_features_each_scale))
+                ],
+                dim=1,
+            )
+
+            image_features_each_image.append(cur_features)
+
+            if (
+                resize_output_to_scale_idx == len(scales) - 1
+                or resize_output_to_scale_idx == -1
+            ):
+                new_block_sizes.append(block_size_each_image)
+            else:
+                new_block_sizes.append(
+                    (
+                        scales[resize_output_to_scale_idx] // scales[0],
+                        scales[resize_output_to_scale_idx] // scales[0],
+                    )
+                )
+
+    assert block_cnt == len(
+        image_features
+    ), f"The number of blocks ({block_cnt}) does not match the length of image_features ({len(image_features)})!"
+
+    return image_features_each_image, new_block_sizes
+
+
+def split_chessboard(x, num_split_h, num_split_w):
+    """
+    x: b * c * h * w
+    out: b * c * h * w
+    Divide x into num_split_h * num_split_w sub-squares and concatenate the sub-squares along the batch dimension.
+    """
+    B, C, H, W = x.shape
+    assert H % num_split_h == 0 and W % num_split_w == 0
+    h, w = H // num_split_h, W // num_split_w
+    x_split = torch.cat(
+        [
+            x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w]
+            for i in range(num_split_h)
+            for j in range(num_split_w)
+        ],
+        dim=0,
+    )
+    return x_split
+
+
+EntryClass = [NVILAForConditionalGeneration]
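
Note: the two chessboard helpers above are exact inverses for 4-D inputs: split_chessboard stacks num_split_h * num_split_w row-major tiles on the batch dimension, and merge_chessboard stitches them back together. A minimal round-trip sketch (shapes are illustrative; the import path assumes this file lands as sglang/srt/models/nvila.py, per the file list above):

    import torch

    from sglang.srt.models.nvila import merge_chessboard, split_chessboard

    x = torch.randn(1, 8, 6, 6)             # b=1, c=8, 6x6 feature map
    tiles = split_chessboard(x, 2, 3)       # (6, 8, 3, 2): 2*3 row-major tiles on the batch dim
    merged = merge_chessboard(tiles, 2, 3)  # back to (1, 8, 6, 6)
    assert torch.equal(merged, x)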
sglang/srt/models/nvila_lite.py (new file)
@@ -0,0 +1,184 @@
+import math
+from collections.abc import Iterable
+from typing import Any
+
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+MM_HIDDEN_SIZE = 1152
+
+
+class NVILALiteConfig(PretrainedConfig):
+    model_type = "nvila_lite"
+    sub_configs = {
+        "text_config": Qwen2Config,
+        "vision_config": SiglipVisionConfig,
+    }
+    _auto_class = "AutoConfig"
+
+    def __init__(
+        self,
+        *,
+        text_config: dict[str, Any] | None = None,
+        vision_config: dict[str, Any] | None = None,
+        image_token_id: int | None = None,
+        video_token_id: int | None = None,
+        **kwargs,
+    ):
+        self.text_config = (
+            Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
+        )
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config is not None
+            else SiglipVisionConfig()
+        )
+
+        self.image_token_id = image_token_id if image_token_id is not None else -1
+        self.video_token_id = video_token_id if video_token_id is not None else -1
+
+        super().__init__(**kwargs)
+
+
+class NVILALiteMultiModalProjectorDownsampleBlock(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = math.isqrt(sequence_length)
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = (3 - feat_size % 3) % 3
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+        return features
+
+
+class NVILALiteMultiModalProjector(nn.Module):
+    def __init__(self, config: NVILALiteConfig):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            NVILALiteMultiModalProjectorDownsampleBlock(),
+            nn.LayerNorm(MM_HIDDEN_SIZE * 9),
+            nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
+            nn.GELU(),
+            nn.LayerNorm(MM_HIDDEN_SIZE * 3),
+            nn.Linear(MM_HIDDEN_SIZE * 3, config.text_config.hidden_size),
+            nn.GELU(),
+            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layers(x)
+
+
+class NVILALiteForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: NVILALiteConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+        self.mm_projector = NVILALiteMultiModalProjector(config)
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.VIDEO: self.get_image_feature,
+            },
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        assert isinstance(output, LogitsProcessorOutput)
+
+        return output
+
+    def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
+        pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
+            pixel_values,
+            output_hidden_states=True,
+        )
+        assert vision_tower_output.hidden_states is not None
+
+        vision_features = vision_tower_output.hidden_states[-2]
+
+        vision_features = self.mm_projector(vision_features)
+
+        vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
+
+        return vision_features
+
+    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self, input_ids: list[int], mm_inputs: MultimodalInputs
+    ) -> list[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+
+EntryClass = [NVILALiteForConditionalGeneration]
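
Note: the Lite projector's downsample block folds each 3x3 patch neighborhood into the channel dimension, padding the feature grid up to a multiple of 3 first, so a (batch, s, d) input with s a perfect square comes out as (batch, (ceil(sqrt(s) / 3))**2, 9 * d). A quick shape check with hypothetical sizes (a 27x27 grid of 1152-d SigLIP features; the import path assumes this file lands as sglang/srt/models/nvila_lite.py, per the file list above):

    import torch

    from sglang.srt.models.nvila_lite import NVILALiteMultiModalProjectorDownsampleBlock

    block = NVILALiteMultiModalProjectorDownsampleBlock()
    x = torch.randn(2, 729, 1152)  # 27x27 grid of features per image
    y = block(x)
    print(y.shape)  # torch.Size([2, 81, 10368]): a 9x9 grid with 9 * 1152 channels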
sglang/srt/models/qwen2.py
@@ -49,6 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 
 Qwen2Config = None
@@ -89,6 +90,9 @@ class Qwen2MLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            x = x.bfloat16()
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -275,6 +279,11 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
+                params_dtype=(
+                    torch.float32
+                    if get_global_server_args().rl_on_policy_target == "fsdp"
+                    else None
+                ),
             )
         else:
             self.embed_tokens = PPMissingLayer()
@@ -295,7 +304,19 @@ class Qwen2Model(nn.Module):
             prefix=add_prefix("layers", prefix),
         )
         if self.pp_group.is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            norm_kwargs = (
+                dict(
+                    weight_dtype=torch.float32,
+                    cast_x_before_out_mul=True,
+                    override_orig_dtype=torch.float32,
+                    fp32_residual=True,
+                )
+                if get_global_server_args().rl_on_policy_target == "fsdp"
+                else {}
+            )
+            self.norm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+            )
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
@@ -441,7 +462,7 @@ class Qwen2ForCausalLM(nn.Module):
                 self.pp_group.send(
                     self.model.embed_tokens.weight, dst=self.pp_group.last_rank
                 )
-            else:
+            elif self.pp_group.is_last_rank:
                 emb_token_weight = self.pp_group.recv(
                     size=(config.vocab_size, config.hidden_size),
                     dtype=next(self.model.parameters()).dtype,
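
Note: these qwen2.py changes gate several numerics tweaks on rl_on_policy_target == "fsdp": the MLP input is forced to bfloat16, the token embedding is kept in float32, and the final RMSNorm runs in float32 via the weight_dtype / cast_x_before_out_mul / override_orig_dtype / fp32_residual kwargs. A generic sketch of the fp32-norm pattern those kwargs appear to select (a standalone reimplementation, not sglang's RMSNorm class):

    import torch

    def rmsnorm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Normalize and scale entirely in float32, then cast back to the
        # activation dtype, so sampling-time numerics track an fp32 trainer.
        orig_dtype = x.dtype
        x32 = x.float()
        x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
        return (x32 * weight.float()).to(orig_dtype)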
sglang/srt/models/qwen2_moe.py
@@ -473,10 +473,16 @@ class Qwen2MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )
 
         if hidden_states.shape[0] != 0:
@@ -553,6 +559,11 @@ class Qwen2MoeModel(nn.Module):
         # For EAGLE3 support
         self.layers_to_capture = []
 
+    def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -585,12 +596,6 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
-                if i in self.layers_to_capture:
-                    aux_hidden_states.append(
-                        hidden_states + residual
-                        if residual is not None
-                        else hidden_states
-                    )
                 ctx = (
                     nullcontext()
                     if get_global_server_args().enable_piecewise_cuda_graph
@@ -599,7 +604,15 @@ class Qwen2MoeModel(nn.Module):
                 with ctx:
                     layer = self.layers[i]
                     hidden_states, residual = layer(
-                        positions, hidden_states, forward_batch, residual
+                        positions,
+                        hidden_states,
+                        forward_batch,
+                        residual,
+                        captured_last_layer_outputs=(
+                            aux_hidden_states
+                            if getattr(layer, "_is_layer_to_capture", False)
+                            else None
+                        ),
                     )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
@@ -830,13 +843,15 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
 
 
 EntryClass = Qwen2MoeForCausalLM
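
Note: the net effect of the qwen2_moe.py changes is that EAGLE3 aux-hidden-state capture moves out of the model's layer loop and into the layers themselves: set_eagle3_layers_to_capture tags the chosen layers once, and each tagged layer appends its incoming hidden state (presumably with the residual folded in, as the removed inline code did) to the list passed down as captured_last_layer_outputs. A toy sketch of the tagging pattern (simplified modules, not the sglang classes):

    import torch
    import torch.nn as nn

    class ToyLayer(nn.Module):
        def __init__(self, d: int):
            super().__init__()
            self.proj = nn.Linear(d, d)

        def forward(self, x, captured=None):
            if captured is not None:  # only layers tagged for capture receive the list
                captured.append(x)
            return self.proj(x)

    layers = nn.ModuleList(ToyLayer(16) for _ in range(6))
    for i in (2, 3):  # cf. [2, num_layers // 2, num_layers - 3] in the diff
        layers[i]._is_layer_to_capture = True  # mirrors set_eagle3_layers_to_capture

    aux, x = [], torch.randn(1, 16)
    for layer in layers:
        capture = getattr(layer, "_is_layer_to_capture", False)
        x = layer(x, captured=aux if capture else None)
    assert len(aux) == 2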