optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. optimum/rbln/__init__.py +164 -36
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +107 -78
  31. optimum/rbln/transformers/__init__.py +87 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +108 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  76. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  77. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  78. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  79. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  80. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  81. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  82. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  83. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  84. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  85. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  86. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  87. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  88. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  89. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  90. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  91. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  92. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  93. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  94. optimum/rbln/utils/runtime_utils.py +33 -2
  95. optimum/rbln/utils/submodule.py +26 -43
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
  97. optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
  98. optimum/rbln/modeling_config.py +0 -310
  99. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  100. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
  101. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
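
The hunk below is the new file optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (item 74), the headline addition of this release: Qwen2.5-VL support compiled for RBLN NPUs. As a point of reference, a minimal usage sketch, assuming the new class is re-exported from the package root (consistent with the __init__.py change in item 1) and follows the usual optimum from_pretrained(..., export=True) compile-on-load convention; the model id is illustrative:

    from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration

    # Compile for the RBLN NPU; the vision transformer ("visual") is compiled as a submodule.
    model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct",
        export=True,
    )
    model.save_pretrained("qwen2_5_vl_rbln")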
@@ -0,0 +1,608 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import (
+     AutoModelForVision2Seq,
+     PretrainedConfig,
+     PreTrainedModel,
+     Qwen2_5_VLForConditionalGeneration,
+ )
+ from transformers.modeling_utils import no_init_weights
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+     Qwen2_5_VisionPatchEmbed,
+     Qwen2_5_VisionRotaryEmbedding,
+     Qwen2_5_VisionTransformerPretrainedModel,
+     Qwen2_5_VLRotaryEmbedding,
+ )
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput
+ from .configuration_qwen2_5_vl import (
+     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+ )
+ from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_VL_LanguageModelWrapper
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import (
+         AutoFeatureExtractor,
+         AutoProcessor,
+         AutoTokenizer,
+         PretrainedConfig,
+     )
+
+
+ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
+     auto_model_class = None
+
+     def __post_init__(self, **kwargs):
+         self.transformer = self.model[0]
+         self.max_seq_lens = torch.tensor(sorted(self.rbln_config.max_seq_lens, reverse=False))
+         config = self.config
+         self.window_size = config.window_size
+         self.patch_size = config.spatial_patch_size
+         self.spatial_merge_size = config.spatial_merge_size
+         self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
+         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding((config.hidden_size // config.num_heads) // 2)
+         with no_init_weights():
+             self.patch_embed = Qwen2_5_VisionPatchEmbed(
+                 patch_size=config.patch_size,
+                 temporal_patch_size=config.temporal_patch_size,
+                 in_channels=config.in_channels,
+                 embed_dim=config.hidden_size,
+             )
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+         self.patch_embed.load_state_dict(artifacts["patch_embed"])
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "Qwen2_5_VLForConditionalGeneration",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+     ):
+         save_dict = {}
+         save_dict["patch_embed"] = model.patch_embed.state_dict()
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     @classmethod
+     def wrap_model_if_needed(
+         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
+     ):
+         return Qwen2_5_VisionTransformerWrapper(model).eval()
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(Qwen2_5_VisionTransformerPretrainedModel, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: "PretrainedConfig" = None,
+         rbln_config: Optional[RBLNQwen2_5_VisionTransformerPretrainedModelConfig] = None,
+     ) -> RBLNQwen2_5_VisionTransformerPretrainedModelConfig:
+         window_size = getattr(model_config, "window_size")
+         patch_size = getattr(model_config, "patch_size")
+         hidden_size = getattr(model_config, "hidden_size")
+         num_heads = getattr(model_config, "num_heads")
+         head_dim = hidden_size // num_heads
+         window_seq_len = (window_size // patch_size) ** 2
+
+         input_infos = []
+         for max_seq_len in rbln_config.max_seq_lens:
+             if max_seq_len % window_seq_len > 0:
+                 raise ValueError(
+                     f"max_seq_len ({max_seq_len}) must be a multiple of window_seq_len ({window_seq_len})."
+                 )
+
+             input_info = [
+                 ("hidden_states", [max_seq_len, hidden_size], "float32"),
+                 ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], "float32"),
+                 (
+                     "window_attn_masks",
+                     [max_seq_len // window_seq_len, 1, window_seq_len, window_seq_len],
+                     "float32",
+                 ),
+                 (
+                     "cos",
+                     [1, 1, max_seq_len, head_dim],
+                     "float32",
+                 ),
+                 (
+                     "sin",
+                     [1, 1, max_seq_len, head_dim],
+                     "float32",
+                 ),
+             ]
+             input_infos.append(input_info)
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+
+         return rbln_config
+
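
_update_rbln_config above registers one static compiled graph per entry in rbln_config.max_seq_lens, and the ValueError enforces that each entry splits evenly into attention windows. A quick sanity check using the values shipped with the published Qwen2.5-VL checkpoints (window_size=112, patch_size=14; illustrative here, the real values come from model_config):

    window_seq_len = (112 // 14) ** 2   # 64 patch tokens per attention window
    assert 6400 % window_seq_len == 0   # 6400 is an acceptable max_seq_len
    assert 6000 % window_seq_len != 0   # 6000 would raise the ValueError above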
+     @staticmethod
+     def _pad_for_window_attn_layers(
+         window_indice: List[int],
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         window_seq_len: int,
+         max_seq_len: int,
+     ):
+         # Padding for Window Attention
+         padded_hidden_state = []
+         padded_cos = []
+         padded_sin = []
+         window_valid_lengths = []
+         for i in range(len(window_indice) - 1):
+             start, end = window_indice[i], window_indice[i + 1]
+             segment = hidden_states[start:end]
+             cos_segment = position_embeddings[0][start:end]
+             sin_segment = position_embeddings[1][start:end]
+             segment_len = end - start
+
+             if segment_len < window_seq_len:
+                 padding_size = window_seq_len - segment_len
+                 padding = torch.zeros(
+                     padding_size,
+                     segment.shape[-1],
+                     dtype=segment.dtype,
+                 )
+                 padding_pos = torch.zeros(
+                     padding_size,
+                     cos_segment.shape[-1],
+                     dtype=cos_segment.dtype,
+                 )
+                 padded_segment = torch.cat([segment, padding], dim=0)
+                 padded_cos_segment = torch.cat([cos_segment, padding_pos], dim=0)
+                 padded_sin_segment = torch.cat([sin_segment, padding_pos], dim=0)
+             else:
+                 padded_segment = segment
+                 padded_cos_segment = cos_segment
+                 padded_sin_segment = sin_segment
+             padded_hidden_state.append(padded_segment)
+             window_valid_lengths.append(segment_len)
+             padded_cos.append(padded_cos_segment)
+             padded_sin.append(padded_sin_segment)
+         hidden_state_padded = torch.cat(padded_hidden_state)
+         cos_padded = torch.cat(padded_cos, dim=0)
+         sin_padded = torch.cat(padded_sin, dim=0)
+
+         window_attn_masks = torch.ones(
+             max_seq_len // window_seq_len,
+             1,
+             window_seq_len,
+             window_seq_len,
+             dtype=torch.float32,
+         )
+         for i, valid_len in enumerate(window_valid_lengths):
+             if valid_len < window_seq_len:
+                 window_attn_masks[i, :, valid_len:, :] = 0
+                 window_attn_masks[i, :, :, valid_len:] = 0
+
+         return hidden_state_padded, cos_padded, sin_padded, window_attn_masks, window_valid_lengths
+
+     @staticmethod
+     def _pad_for_full_attn_layers(
+         hidden_state_padded, cos_padded, sin_padded, max_seq_len, window_valid_lengths, window_seq_len
+     ):
+         if hidden_state_padded.shape[0] < max_seq_len:
+             full_padding_size = max_seq_len - hidden_state_padded.shape[0]
+             full_padding_hidden = torch.zeros(
+                 full_padding_size,
+                 hidden_state_padded.shape[-1],
+                 dtype=hidden_state_padded.dtype,
+             )
+             hidden_state_full_padded = torch.cat([hidden_state_padded, full_padding_hidden], dim=0)  # [5120, 1280]
+             full_padding_pos = torch.zeros(
+                 full_padding_size,
+                 cos_padded.shape[-1],
+                 dtype=cos_padded.dtype,
+             )
+             cos_full_padded = torch.cat([cos_padded, full_padding_pos], dim=0)
+             sin_full_padded = torch.cat([sin_padded, full_padding_pos], dim=0)
+             window_valid_lengths.extend([0] * (max_seq_len // window_seq_len - len(window_valid_lengths)))
+         else:
+             hidden_state_full_padded = hidden_state_padded
+             cos_full_padded = cos_padded
+             sin_full_padded = sin_padded
+
+         full_attn_masks = torch.ones(
+             1,
+             1,
+             max_seq_len,
+             max_seq_len,
+             dtype=torch.float32,
+         )
+         for i, valid_len in enumerate(window_valid_lengths):
+             start = i * window_seq_len
+             end = start + window_seq_len
+             full_attn_masks[:, :, start + valid_len : end, :] = 0
+             full_attn_masks[:, :, :, start + valid_len : end] = 0
+
+         return hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks
+
+     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.patch_embed(hidden_states)
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+         cu_window_seqlens = torch.tensor(
+             cu_window_seqlens,
+             dtype=torch.int32,
+         )
+         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+         seq_len, _ = hidden_states.size()
+         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+         hidden_states = hidden_states[window_index, :, :]
+         hidden_states = hidden_states.reshape(seq_len, -1)
+         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         position_embeddings = (emb.cos(), emb.sin())
+
+         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+             dim=0,
+             dtype=torch.int32,
+         )
+         cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+
+         num_images = len(cu_seqlens) - 1
+         cu_window_seqlens = cu_window_seqlens.tolist()
+         window_seq_len = (self.window_size // self.patch_size) ** 2
+
+         output_hidden_states = []
+
+         # Process each image in the sequence
+         for i in range(num_images):
+             image_s, image_e = cu_seqlens[i], cu_seqlens[i + 1]
+             window_indice = cu_window_seqlens[cu_window_seqlens.index(image_s) : cu_window_seqlens.index(image_e) + 1]
+
+             # Select the nearest higher max_seq_len from the available compiled models.
+             window_padded_len = len(window_indice) * window_seq_len
+             try:
+                 ws_index = torch.searchsorted(self.max_seq_lens, window_padded_len).item()
+                 max_seq_len = self.max_seq_lens[ws_index]
+             except Exception:
+                 raise ValueError(
+                     f"Required seq_len({window_padded_len}) is larger than available max_seq_lens({self.max_seq_lens.tolist()})."
+                 )
+
+             # Padding for Window Attention Layers
+             hidden_state_padded, cos_padded, sin_padded, window_attn_masks, window_valid_lengths = (
+                 self._pad_for_window_attn_layers(
+                     window_indice, hidden_states, position_embeddings, window_seq_len, max_seq_len
+                 )
+             )
+
+             # Padding for Full Attention Layers
+             hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks = (
+                 self._pad_for_full_attn_layers(
+                     hidden_state_padded, cos_padded, sin_padded, max_seq_len, window_valid_lengths, window_seq_len
+                 )
+             )
+
+             # RBLN run with the compiled model
+             output = self.transformer(
+                 hidden_state_full_padded,
+                 full_attn_masks,
+                 window_attn_masks,
+                 cos_full_padded[None, None, :, :],
+                 sin_full_padded[None, None, :, :],
+             )
+
+             # Depadding
+             depadded_output = []
+             for i, valid_len in enumerate(window_valid_lengths):
+                 start = i * (window_seq_len // self.spatial_merge_unit)
+                 end = start + (valid_len // self.spatial_merge_unit)
+                 depadded_output.append(output[start:end])
+             output = torch.cat(depadded_output, dim=0)
+
+             output_hidden_states.append(output)
+         hidden_states = torch.cat(output_hidden_states)
+         reverse_indices = torch.argsort(window_index)
+         hidden_states = hidden_states[reverse_indices, :]
+
+         return hidden_states
+
+
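
Before moving on to the language model, a self-contained sketch of the masking scheme the two padding helpers above produce, with toy sizes (plain PyTorch, not part of the package):

    import torch

    window_seq_len = 4            # tokens per attention window (toy value)
    num_windows = 3               # max_seq_len // window_seq_len
    valid_lengths = [4, 3, 0]     # real tokens per window; an all-zero trailing window is pure padding

    # One mask per window: 1 only where both the query and key positions are valid.
    window_attn_masks = torch.ones(num_windows, 1, window_seq_len, window_seq_len)
    for i, valid_len in enumerate(valid_lengths):
        if valid_len < window_seq_len:
            window_attn_masks[i, :, valid_len:, :] = 0
            window_attn_masks[i, :, :, valid_len:] = 0

    # The full-attention layers get a single (1, 1, max_seq_len, max_seq_len) mask
    # zeroed over the same padded positions, so padding never attends or is attended to.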
+ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
+     auto_model_class = AutoModelForVision2Seq
+     _rbln_submodules = [
+         {"name": "visual"},
+     ]
+     _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+     _use_rotary_emb = False
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+         self.visual = self.rbln_submodules[0]
+         self.mrope_section = self.config.rope_scaling["mrope_section"]
+         self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
+         self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+     def can_generate(self):
+         return True
+
+     @classmethod
+     def update_kwargs(cls, kwargs):
+         kwargs.update(
+             {
+                 "_attn_implementation": "eager",
+             }
+         )
+         return super().update_kwargs(kwargs)
+
+     @classmethod
+     def get_input_info(
+         cls,
+         batch_size: int,
+         query_length: int,
+         use_inputs_embeds: bool,
+         use_attention_mask: bool,
+         max_seq_len: int,
+         kvcache_block_size: int,
+         kvcache_num_blocks: int,
+         num_key_value_heads: int,
+         num_hidden_layers: int,
+         hidden_size: int,
+         head_dim: int,
+     ):
+         input_info = super().get_input_info(
+             batch_size,
+             query_length,
+             use_inputs_embeds,
+             use_attention_mask,
+             max_seq_len,
+             kvcache_block_size,
+             kvcache_num_blocks,
+             num_key_value_heads,
+             num_hidden_layers,
+             hidden_size,
+             head_dim,
+         )
+         pos_idx = 5 if query_length > 1 else 4
+         pos_idx = pos_idx if use_attention_mask else pos_idx - 1
+         input_info.insert(pos_idx, ("position_emb", [2, batch_size, 1, query_length, head_dim], "float32"))
+
+         return input_info
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         generate_idx: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         pixel_values=None,
+         pixel_values_videos=None,
+         image_grid_thw=None,
+         video_grid_thw=None,
+         second_per_grid_ts=None,
+         **kwargs,
+     ):
+         model_inputs = {}
+         is_prefill_phase = generate_idx is None
+
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             cache_position = None
+             model_inputs.update({"input_ids": input_ids})
+         else:
+             if inputs_embeds is not None:
+                 raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+             input_ids = input_ids[:, -1:]
+             cache_position = generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         model_inputs.update(
+             {
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
+                 "pixel_values": pixel_values,
+                 "pixel_values_videos": pixel_values_videos,
+                 "image_grid_thw": image_grid_thw,
+                 "video_grid_thw": video_grid_thw,
+                 "second_per_grid_ts": second_per_grid_ts,
+             }
+         )
+
+         return model_inputs
+
+     def _get_position_embeddings(self, hidden_states, position_ids):
+         cos, sin = self.rotary_emb(hidden_states, position_ids)
+         mrope_section = self.mrope_section * 2
+         cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+         sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+         return torch.stack([cos, sin])
+
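
The split/cat in _get_position_embeddings implements the multimodal rotary embedding (M-RoPE) merge: cos and sin carry three rope axes (temporal, height, width), and section i of the head dimension is taken from axis i % 3. A toy-shape sketch of the same operation, not part of the package (sizes are illustrative):

    import torch

    mrope_section = [16, 24, 24]          # t / h / w sections; sums to head_dim // 2
    cos = torch.randn(3, 1, 10, 128)      # (rope axes, batch, seq_len, head_dim)

    sections = mrope_section * 2          # list repetition: [16, 24, 24, 16, 24, 24]
    merged = torch.cat(
        [chunk[i % 3] for i, chunk in enumerate(cos.split(sections, dim=-1))],
        dim=-1,
    )                                     # (1, 10, 128): t/h/w interleaved section by section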
+     def _preprocess_prefill(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         pixel_values: torch.Tensor = None,
+         pixel_values_videos: torch.FloatTensor = None,
+         image_grid_thw: torch.LongTensor = None,
+         video_grid_thw: torch.LongTensor = None,
+         second_per_grid_ts: torch.Tensor = None,
+     ):
+         batch_size = input_ids.shape[0]
+         inputs_embeds = self.embed_tokens(input_ids)
+
+         if pixel_values is not None:
+             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+             n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+             n_image_features = image_embeds.shape[0]
+             if n_image_tokens != n_image_features:
+                 raise ValueError(
+                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                 )
+
+             mask = input_ids == self.config.image_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+
+             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, image_embeds)
+
+         if pixel_values_videos is not None:
+             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+             n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+             n_video_features = video_embeds.shape[0]
+             if n_video_tokens != n_video_features:
+                 raise ValueError(
+                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                 )
+
+             mask = input_ids == self.config.video_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+             inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
+
+         max_inputs_len = input_ids.shape[1]
+
+         head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
+         all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+         all_rope_deltas = []
+
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         image_idx, video_idx = 0, 0
+
+         for b_idx in range(batch_size):
+             input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
+             vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
+             vision_tokens = input_id[0][vision_start_indices + 1]
+             image_nums = (vision_tokens == image_token_id).sum()
+             video_nums = (vision_tokens == video_token_id).sum()
+             position_ids, rope_deltas = self.get_rope_index(
+                 input_id,
+                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
+                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
+                 second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None,
+             )
+             image_idx += image_nums
+             video_idx += video_nums
+
+             position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
+             mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+             all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
+             all_rope_deltas.append(rope_deltas)
+
+         rope_deltas = torch.stack(all_rope_deltas)
+
+         return inputs_embeds, all_position_embeds, rope_deltas
+
+     def _preprocess_decoder(
+         self,
+         input_ids: torch.LongTensor = None,
+         cache_position: torch.LongTensor = None,
+     ):
+         if self.rbln_config.batch_size != cache_position.shape[0]:
+             raise RuntimeError(
+                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
+             )
+
+         inputs_embeds = self.embed_tokens(input_ids)
+         position_embeds = []
+         for b_idx in range(self.rbln_config.batch_size):
+             delta = cache_position[b_idx] + self.rope_deltas[b_idx]
+             position_ids = torch.arange(1).view(1, -1)
+             position_ids = position_ids.add(delta)
+             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+             position_embed = self._get_position_embeddings(torch.zeros(1, dtype=torch.float32), position_ids)
+             position_embeds.append(position_embed)
+
+         position_embeds = torch.cat(position_embeds, dim=1)
+
+         return inputs_embeds, position_embeds
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         pixel_values_videos: Optional[torch.FloatTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         video_grid_thw: Optional[torch.LongTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         second_per_grid_ts: Optional[torch.Tensor] = None,
+         generate_idx: torch.Tensor = None,
+         **kwargs,
+     ) -> RBLNDecoderOnlyOutput:
+         # Prefill
+         if cache_position is None:
+             inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+                 input_ids,
+                 attention_mask,
+                 pixel_values,
+                 pixel_values_videos,
+                 image_grid_thw,
+                 video_grid_thw,
+                 second_per_grid_ts,
+             )
+
+             self.rope_deltas = rope_deltas
+             batch_size = inputs_embeds.shape[0]
+
+             logits = []
+             for b_idx in range(batch_size):
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+
+                 logit = self.prefill_decoder(
+                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                     cache_position=cache_position,
+                     batch_idx=b_idx,
+                     position_embed=position_embed[:, b_idx : b_idx + 1],
+                 )
+                 logits.append(logit)
+             logits = torch.cat(logits, dim=0)
+         # Decoder
+         else:
+             inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
+             logits = self.decoder(
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 position_embed=position_embed,
+             )
+
+         return RBLNDecoderOnlyOutput(
+             logits=logits,
+             generate_idx=generate_idx,
+         )
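
End to end, forward is driven by generate(): the first call (cache_position is None) runs the vision encoder and one prefill per batch element, and subsequent calls take the single-token decoder path using the cached rope_deltas. A hedged usage sketch, assuming the standard Qwen2.5-VL AutoProcessor workflow and the compiled model from the sketch above (image is any PIL.Image):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    messages = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe this image."},
    ]}]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True)

    output_ids = model.generate(**inputs, max_new_tokens=128)
    print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])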