sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,358 @@
1
+ from typing import Iterable, List, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+ from torch import nn
7
+
8
+ from sglang.srt.configs.deepseekvl2 import (
9
+ DeepseekVL2Config,
10
+ DeepseekVL2MlpProjectorConfig,
11
+ )
12
+ from sglang.srt.layers.linear import ReplicatedLinear
13
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
14
+ from sglang.srt.managers.schedule_batch import MultimodalInputs
15
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
16
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
17
+ from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
18
+
19
+
20
class DeepseekVL2MlpProjector(nn.Module):
    """Project vision-encoder features to the width expected by the language model.

    The layer stack stored in ``self.layers`` is chosen by
    ``config.projector_type``:

    * ``"identity"``: a single ``nn.Identity``.
    * ``"linear"``: one linear layer ``input_dim -> n_embed``.
    * ``"mlp_gelu"``: ``input_dim -> n_embed`` followed by ``depth - 1``
      ``GELU`` + ``n_embed -> n_embed`` pairs.
    * ``"downsample_mlp_gelu"``: an input layer that consumes
      ``downsample_ratio ** 2`` concatenated tokens, ``depth - 2`` hidden
      ``GELU`` + linear pairs of width ``n_embed * mlp_ratio``, and a final
      ``GELU`` + linear pair down to ``n_embed``.

    When ``config.token_pooling`` is set, an extra linear layer
    (``input_dim * 4 -> input_dim``) merges every 2x2 token patch before the
    stack runs.

    Args:
        config: Projector configuration (type, dimensions, depth, ...).
        quant_config: Optional quantization config passed to every
            ``ReplicatedLinear``.

    Raises:
        ValueError: If ``config.projector_type`` is not one of the types above.
    """

    def __init__(
        self,
        config: DeepseekVL2MlpProjectorConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.config = config
        projector_type = config.projector_type

        if projector_type == "identity":
            # Register the identity in ``self.layers`` so that ``forward`` can
            # iterate it like any other projector type.
            layers = [nn.Identity()]

        elif projector_type == "linear":
            layers = [
                ReplicatedLinear(
                    config.input_dim,
                    config.n_embed,
                    quant_config=quant_config,
                )
            ]

        elif projector_type == "mlp_gelu":
            layers = [
                ReplicatedLinear(
                    config.input_dim,
                    config.n_embed,
                    quant_config=quant_config,
                )
            ]
            for _ in range(1, config.depth):
                layers.append(nn.GELU())
                layers.append(
                    ReplicatedLinear(
                        config.n_embed,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                )

        elif projector_type == "downsample_mlp_gelu":
            hidden_dim = config.n_embed * config.mlp_ratio
            layers = [
                ReplicatedLinear(
                    config.input_dim
                    * config.downsample_ratio
                    * config.downsample_ratio,
                    hidden_dim,
                    quant_config=quant_config,
                )
            ]
            for _ in range(1, config.depth - 1):
                layers.append(nn.GELU())
                layers.append(
                    ReplicatedLinear(
                        hidden_dim,
                        hidden_dim,
                        quant_config=quant_config,
                    )
                )
            layers.append(nn.GELU())
            layers.append(
                ReplicatedLinear(
                    hidden_dim,
                    config.n_embed,
                    quant_config=quant_config,
                )
            )

        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")

        self.layers = nn.ModuleList(layers)

        if config.token_pooling:
            self.token_pooling_layer = ReplicatedLinear(
                config.input_dim * 4, config.input_dim, quant_config=quant_config
            )

    def forward(self, x):
        """Run the projector on a ``[batch, tokens, channels]`` tensor.

        The token axis is treated as a square grid (its side length is
        ``int(tokens ** 0.5)``). With ``token_pooling`` every 2x2 patch of the
        grid is concatenated and reduced by ``token_pooling_layer``; otherwise,
        for ``"downsample_mlp_gelu"``, the grid is zero-padded to a multiple of
        ``downsample_ratio`` and each ``ratio x ratio`` patch is concatenated
        along the channel axis. The result is then passed through
        ``self.layers``.
        """
        cfg = self.config

        if cfg.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)

            # Cut both spatial axes into non-overlapping windows of size 2.
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            patches = patches.contiguous().view(
                batch_size, channels, h_patches * w_patches, -1
            )
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            # ReplicatedLinear's output is indexed; element 0 is the tensor.
            x = self.token_pooling_layer(patches)[0]

        elif cfg.projector_type == "downsample_mlp_gelu":
            ratio = cfg.downsample_ratio
            bs, hw, input_dim = x.shape
            h = w = int(hw**0.5)

            # Rows/columns needed to reach the next multiple of ``ratio``
            # (0 when ``h`` is already a multiple).
            pad = -h % ratio
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # Concatenate each ratio x ratio patch into one token.
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=ratio,
                stride=ratio,
                padding=0,
            )  # B, C * ratio**2, (H // ratio) * (W // ratio)
            x = x.permute(0, 2, 1)

        for layer in self.layers:
            x = layer(x)
            # Some layers return a tuple; keep only the first element.
            if isinstance(x, tuple):
                x = x[0]
        return x
151
+
152
+
153
+ # todo
154
class DeepseekVL2ForCausalLM(nn.Module):
    """DeepSeek-VL2 model: vision encoder + MLP projector + DeepSeek-V2 LM.

    Image tiles are encoded by ``self.vision``, projected by
    ``self.projector``, laid out with learned newline / view-separator
    embeddings, and scattered into the token embeddings before the language
    model runs.

    Args:
        config: Full model configuration; ``vision_config``,
            ``projector_config``, ``language_config``, ``tile_tag`` and
            ``global_view_pos`` are read from it.
        quant_config: Optional quantization config, passed to the projector.

    Raises:
        ValueError: If ``config.tile_tag`` is not ``"2D"``.
    """

    def __init__(
        self,
        config: DeepseekVL2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = self._init_vision_module(vision_config, quant_config)

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)

        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # Special embeddings are initialised with std 1 / sqrt(n_embed).
        embed_std = 1 / torch.sqrt(
            torch.tensor(projector_config.n_embed, dtype=torch.float32)
        )
        if self.tile_tag == "2D":
            self.image_newline = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
            # The attribute name (including its spelling) is a registered
            # parameter name that weights are looked up by in ``load_weights``.
            self.view_seperator = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
        else:
            raise ValueError(f"tile tag should be 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        # NOTE(review): ``quant_config`` is not forwarded to the language
        # model here -- confirm this is intended.
        language_config = config.language_config
        self.language_model = DeepseekV2ForCausalLM(language_config)

    def _init_vision_module(
        self, vision_config, quant_config: Optional[QuantizationConfig]
    ) -> nn.Module:
        """Build the timm SigLIP ViT used as the vision encoder.

        NOTE(review): neither ``vision_config`` nor ``quant_config`` is used;
        the architecture name is fixed below.

        Raises:
            ImportError: If ``timm`` cannot be imported. The original import
                error is attached as ``__cause__``.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain from the caught exception (not the ImportError class) so
            # the real reason for the failure stays in the traceback.
            raise ImportError("Please install timm") from err

        model = timm.create_model(
            "vit_so400m_patch14_siglip_384.webli",
            pretrained=False,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=True,
        )

        return model.to(dtype=torch.get_default_dtype())

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ):
        """Embed tokens, splice in image features, and run the language model.

        Image features are only inserted during extend passes for batches that
        contain image inputs. For each request with multimodal input, the
        positions selected by its ``images_emb_mask`` inside that request's
        slice of ``input_embeds`` are overwritten with the image features.
        """
        input_embeds = self.language_model.model.embed_tokens(input_ids)
        if (
            forward_batch.forward_mode.is_extend()
            and forward_batch.contains_image_inputs()
        ):
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
            for idx, image in enumerate(forward_batch.mm_inputs):
                if image is None:
                    continue
                start_idx = extend_start_loc_cpu[idx]
                end_idx = start_idx + extend_seq_lens_cpu[idx]
                # Follow the embeddings' device instead of assuming "cuda".
                images_emb_mask = image.images_emb_mask.to(
                    device=input_embeds.device
                )
                image_features = self.get_image_feature(image)
                input_embeds[start_idx:end_idx] = input_embeds[
                    start_idx:end_idx
                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

        outputs = self.language_model.forward(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this module.

        Entries whose name contains ``"language"`` are handed to the language
        model with the ``"language."`` substring removed. All other names must
        match a parameter of this module exactly and are loaded with the
        parameter's ``weight_loader`` if it has one, else
        ``default_weight_loader``.

        ``weights`` is consumed lazily, one entry at a time.

        Raises:
            KeyError: If a non-language weight name is not a parameter of this
                module.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "language" in name:
                name = name.replace("language.", "")
                self.language_model.load_weights([(name, loaded_weight)])
            else:
                param = params_dict[name]
                weights_loader = getattr(param, "weight_loader", default_weight_loader)
                weights_loader(param, loaded_weight)

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Return ``input_ids`` unchanged; no padding is applied here."""
        return input_ids

    def get_image_feature(self, image_input: MultimodalInputs):
        """Encode image tiles and lay them out as a flat ``[tokens, D]`` tensor.

        ``pixel_values`` is expected to hold, for each image, one global-view
        tile followed by its local tiles. Only row 0 of
        ``image_input.image_spatial_crop`` is read; each entry is a
        ``(num_width_tiles, num_height_tiles)`` pair and a zero entry ends the
        list. For every image a newline embedding is appended to each row of
        the global view and of the stitched local view, and the two views are
        joined by the view-separator embedding in the order given by
        ``self.global_view_pos``.

        Returns:
            The per-image features concatenated along dim 0, or an empty
            ``[0, D]`` tensor when no image is described by the crop table.
        """
        vision_param = next(self.vision.parameters())
        pixel_values = image_input.pixel_values.type(vision_param.dtype).to(
            device=vision_param.device
        )
        image_feature = self.vision.forward_features(pixel_values)
        images_embeds = self.projector(image_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw**0.5)
        tile_index = 0
        images_in_this_batch = []
        images_spatial_crop = image_input.image_spatial_crop
        for jdx in range(images_spatial_crop.shape[1]):
            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[
                tile_index + 1 : tile_index + 1 + num_tiles_in_image
            ]
            tile_index += num_tiles_in_image + 1

            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global], dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(
                local_features,
                "(th tw) (h w) d -> (th h) (tw w) d",
                th=num_height_tiles,
                tw=num_width_tiles,
                h=h,
                w=w,
            )

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(
                self.image_newline,
                "d -> (th h) 1 d",
                th=num_height_tiles,
                h=h,
            )

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local], dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            separator = self.view_seperator[None, :]
            if self.global_view_pos == "head":
                ordered_views = [global_features, separator, local_features]
            else:
                ordered_views = [local_features, separator, global_features]

            images_in_this_batch.append(torch.cat(ordered_views))

        if not images_in_this_batch:
            # torch.cat rejects an empty list; return zero feature rows instead.
            return images_embeds.new_empty((0, n_dim))

        return torch.cat(images_in_this_batch, dim=0)


EntryClass = DeepseekVL2ForCausalLM