sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176) hide show
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,358 @@
1
+ from typing import Iterable, List, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+ from torch import nn
7
+
8
+ from sglang.srt.configs.deepseekvl2 import (
9
+ DeepseekVL2Config,
10
+ DeepseekVL2MlpProjectorConfig,
11
+ )
12
+ from sglang.srt.layers.linear import ReplicatedLinear
13
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
14
+ from sglang.srt.managers.schedule_batch import MultimodalInputs
15
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
16
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
17
+ from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
18
+
19
+
20
class DeepseekVL2MlpProjector(nn.Module):
    """Map vision-encoder features to ``config.n_embed``-wide embeddings.

    The layer stack is selected by ``config.projector_type``:

    * ``"identity"``            - a single ``nn.Identity``.
    * ``"linear"``              - one ``ReplicatedLinear``.
    * ``"mlp_gelu"``            - ``config.depth`` linears separated by GELU.
    * ``"downsample_mlp_gelu"`` - like ``mlp_gelu`` but the first linear takes
      ``downsample_ratio**2`` concatenated neighbours and the hidden width
      is ``n_embed * mlp_ratio``.

    When ``config.token_pooling`` is truthy an extra linear merges every
    2x2 patch group before the layer stack runs.

    Raises:
        ValueError: if ``config.projector_type`` is not one of the above.
    """

    def __init__(
        self,
        config: "DeepseekVL2MlpProjectorConfig",
        quant_config: Optional["QuantizationConfig"] = None,
    ):
        super().__init__()

        self.config = config

        if config.projector_type == "identity":
            # The identity module must live in self.layers: forward() always
            # iterates self.layers, so binding it to a throw-away local (as
            # before) made this projector type crash with AttributeError.
            self.layers = nn.ModuleList([nn.Identity()])

        elif config.projector_type == "linear":
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )

        elif config.projector_type == "mlp_gelu":
            mlp_depth = config.depth
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )
            # Remaining (depth - 1) stages: GELU followed by a square linear.
            for _ in range(1, mlp_depth):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        config.n_embed,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                )

        elif config.projector_type == "downsample_mlp_gelu":
            mlp_depth = config.depth
            mlp_ratio = config.mlp_ratio
            # Input width accounts for the ratio x ratio neighbourhood that
            # forward() concatenates along the channel axis.
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim
                        * config.downsample_ratio
                        * config.downsample_ratio,
                        config.n_embed * mlp_ratio,
                        quant_config=quant_config,
                    )
                ]
            )
            # Hidden stages keep the widened dimension.
            for _ in range(1, mlp_depth - 1):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        config.n_embed * mlp_ratio,
                        config.n_embed * mlp_ratio,
                        quant_config=quant_config,
                    )
                )
            # Final stage narrows back down to n_embed.
            self.layers.append(nn.GELU())
            self.layers.append(
                ReplicatedLinear(
                    config.n_embed * mlp_ratio,
                    config.n_embed,
                    quant_config=quant_config,
                )
            )

        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")

        if config.token_pooling:
            self.token_pooling_layer = ReplicatedLinear(
                config.input_dim * 4, config.input_dim, quant_config=quant_config
            )

    def forward(self, x):
        """Run the projector on ``x``, a 3-D tensor unpacked as (batch, tokens, channels).

        The token axis is treated as a square grid (side = int(sqrt(tokens))).
        Returns the output of the last layer; tuple-returning layers
        contribute only their first element.
        """
        if self.config.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)

            # Carve the grid into non-overlapping 2x2 windows.
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            patches = patches.contiguous().view(
                batch_size, channels, h_patches * w_patches, -1
            )
            # Move the window axis ahead of channels, then flatten the four
            # window members into the channel axis.
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            # The pooling layer returns a tuple; only its first item is used.
            x = self.token_pooling_layer(patches)[0]

        elif self.config.projector_type == "downsample_mlp_gelu":
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            # compute padding: round the grid side up to a multiple of the
            # downsample ratio
            if h % self.config.downsample_ratio:
                pad = self.config.downsample_ratio - h % self.config.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                # Zero-pad the trailing edge of both grid axes; channels untouched.
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # 4 to 1 concat
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=self.config.downsample_ratio,
                stride=self.config.downsample_ratio,
                padding=0,
            )  # B, C*4, HW // 4
            x = x.permute(0, 2, 1)

        for layer in self.layers:
            x = layer(x)
            # Layers that return a tuple contribute only their first element.
            if isinstance(x, tuple):
                x = x[0]
        return x
151
+
152
+
153
+ # todo
154
class DeepseekVL2ForCausalLM(nn.Module):
    """DeepSeek-VL2 model: vision encoder + MLP projector + DeepseekV2 LM.

    Image tiles are encoded by a timm vision model, projected with
    ``DeepseekVL2MlpProjector``, interleaved with learned newline and
    view-separator embeddings, and scattered into the token embeddings
    before the language model is invoked.
    """

    def __init__(
        self,
        config: DeepseekVL2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = self._init_vision_module(vision_config, quant_config)

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)

        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # Scale the random special-token embeddings by 1/sqrt(n_embed).
        embed_std = 1 / torch.sqrt(
            torch.tensor(projector_config.n_embed, dtype=torch.float32)
        )
        if self.tile_tag == "2D":
            self.image_newline = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
            # NOTE: "seperator" (sic) is a parameter name and therefore part
            # of the weight-name lookup in load_weights(); do not rename.
            self.view_seperator = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
        else:
            raise ValueError(f"tile tag should be 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        # NOTE(review): quant_config is not forwarded to the language model;
        # confirm whether that is intentional.
        language_config = config.language_config
        self.language_model = DeepseekV2ForCausalLM(language_config)

    def _init_vision_module(
        self, vision_config, quant_config: Optional[QuantizationConfig]
    ) -> nn.Module:
        """Build the timm vision model and cast it to the default dtype.

        ``vision_config`` and ``quant_config`` are currently unused; the
        architecture name is hard-coded below.

        Raises:
            ImportError: if ``timm`` is not installed.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain from the caught exception instance (not the ImportError
            # class) so the traceback keeps the real cause.
            raise ImportError("Please install timm") from err

        model = timm.create_model(
            "vit_so400m_patch14_siglip_384.webli",
            pretrained=False,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=True,
        )

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ):
        """Embed tokens, splice in image features, and run the language model.

        Image features are only merged when the batch is in extend mode and
        reports image inputs; otherwise the plain token embeddings are used.
        Returns whatever ``self.language_model.forward`` returns.
        """
        input_embeds = self.language_model.model.embed_tokens(input_ids)
        if (
            forward_batch.forward_mode.is_extend()
            and forward_batch.contains_image_inputs()
        ):
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
            for idx, image in enumerate(forward_batch.mm_inputs):
                if image is None:
                    continue
                # Slice of the flattened batch that belongs to request idx.
                start_idx = extend_start_loc_cpu[idx]
                end_idx = start_idx + extend_seq_lens_cpu[idx]
                # Put the mask on the embeddings' own device instead of a
                # hard-coded "cuda" so masked_scatter sees matching devices.
                images_emb_mask = image.images_emb_mask.to(
                    device=input_embeds.device
                )
                image_features = self.get_image_feature(image)
                input_embeds[start_idx:end_idx] = input_embeds[
                    start_idx:end_idx
                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

        outputs = self.language_model.forward(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module.

        Names containing ``"language"`` have ``"language."`` stripped and are
        delegated to ``self.language_model.load_weights``; every other name
        must match one of this module's own parameters.

        Raises:
            KeyError: if a non-language weight name is not a known parameter.
        """
        params_dict = dict(self.named_parameters())
        # Stream the iterable directly; materialising it with list() would
        # only pin every tensor in memory at once.
        for name, loaded_weight in weights:
            if "language" in name:
                name = name.replace("language.", "")
                self.language_model.load_weights([(name, loaded_weight)])
            else:
                param = params_dict[name]
                weights_loader = getattr(param, "weight_loader", default_weight_loader)
                weights_loader(param, loaded_weight)

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Return ``input_ids`` unchanged; no padding is applied here."""
        return input_ids

    def get_image_feature(self, image_input: MultimodalInputs):
        """Encode image tiles and lay them out as one sequence of embeddings.

        For each image listed in ``image_input.image_spatial_crop[0]`` the
        first tile is taken as the global view and the following
        ``num_width_tiles * num_height_tiles`` tiles as local views. A
        newline embedding is appended to every row, and the global and local
        parts are joined around the view-separator embedding in the order
        given by ``self.global_view_pos``.

        Returns:
            The per-image sequences concatenated along dim 0.
        """
        # Match the vision tower's dtype and device (looked up once).
        vision_param = next(self.vision.parameters())
        pixel_values = image_input.pixel_values.type(vision_param.dtype).to(
            device=vision_param.device
        )
        image_feature = self.vision.forward_features(pixel_values)
        images_embeds = self.projector(image_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw**0.5)
        tile_index = 0
        images_in_this_batch = []
        images_spatial_crop = image_input.image_spatial_crop
        for jdx in range(images_spatial_crop.shape[1]):
            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
            # A zero tile count ends the list of images.
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[
                tile_index + 1 : tile_index + 1 + num_tiles_in_image
            ]
            # Skip past this image's global tile plus its local tiles.
            tile_index += num_tiles_in_image + 1

            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global], dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(
                local_features,
                "(th tw) (h w) d -> (th h) (tw w) d",
                th=num_height_tiles,
                tw=num_width_tiles,
                h=h,
                w=w,
            )

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(
                self.image_newline,
                "d -> (th h) 1 d",
                th=num_height_tiles,
                h=h,
            )

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local], dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                global_local_features = torch.cat(
                    [
                        global_features,
                        self.view_seperator[None, :],
                        local_features,
                    ]
                )
            else:
                global_local_features = torch.cat(
                    [
                        local_features,
                        self.view_seperator[None, :],
                        global_features,
                    ]
                )

            images_in_this_batch.append(global_local_features)

        return torch.cat(images_in_this_batch, dim=0)
356
+
357
+
358
# Module-level alias for the model class defined in this file; presumably the
# name sglang's model loader looks up to register this architecture - TODO confirm.
+ EntryClass = DeepseekVL2ForCausalLM