optimum-rbln 0.8.4a8__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +63 -32
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  11. optimum/rbln/diffusers/modeling_diffusers.py +5 -7
  12. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  13. optimum/rbln/modeling.py +50 -0
  14. optimum/rbln/modeling_base.py +1 -2
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_generic.py +37 -1
  17. optimum/rbln/transformers/models/__init__.py +9 -0
  18. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  19. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +86 -23
  20. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  21. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  22. optimum/rbln/transformers/models/colpali/configuration_colpali.py +34 -18
  23. optimum/rbln/transformers/models/colpali/modeling_colpali.py +73 -80
  24. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  25. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  26. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  27. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  28. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  29. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  30. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  32. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +50 -2
  33. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  34. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +65 -3
  35. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  36. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  37. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  38. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +67 -44
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +24 -19
  42. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  43. optimum/rbln/transformers/models/llava/modeling_llava.py +108 -50
  44. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  45. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -343
  46. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  47. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +9 -8
  50. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +24 -0
  51. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +24 -0
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  55. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  57. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  58. optimum/rbln/utils/runtime_utils.py +25 -15
  59. optimum/rbln/utils/submodule.py +21 -5
  60. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA +7 -6
  61. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/RECORD +64 -55
  62. optimum_rbln-0.9.2.dist-info/entry_points.txt +2 -0
  63. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,446 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from transformers import (
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.modeling_utils import no_init_weights
+from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrievalOutput
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VLModel,
+    Qwen2_5_VLRotaryEmbedding,
+)
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+    Qwen2VLModel,
+    Qwen2VLRotaryEmbedding,
+)
+
+from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModel,
+)
+
+from .configuration_colqwen2 import (
+    RBLNColQwen2ForRetrievalConfig,
+)
+
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PretrainedConfig,
+    )
+
+from .colqwen2_architecture import ColQwen2LanguageModelWrapper
+
+
+class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
+    """
+    The ColQwen Model transformer for document retrieval using vision-language models.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based `ColQwen2ForRetrieval` model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers `ColQwen2ForRetrieval` model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNColQwen2ForRetrievalConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNColQwen2ForRetrievalConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNColQwen2ForRetrievalConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNColQwen2ForRetrieval
+
+        # Using a config dictionary
+        rbln_config = {
+            "visual": {
+                "max_seq_lens": 6400,
+            },
+            "max_seq_len": 32_768,
+            "tensor_parallel_size": 4,
+            "device": [0, 1, 2, 3],
+            "output_hidden_states": False,
+        }
+        model = RBLNColQwen2ForRetrieval.from_pretrained(
+            "vidore/colqwen2-v1.0-hf",
+            export=True,
+            rbln_config=rbln_config
+        )
+
+        # Using a RBLNColQwen2ForRetrievalConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNColQwen2ForRetrievalConfig
+
+        config = RBLNColQwen2ForRetrievalConfig(
+            visual={
+                "max_seq_lens": 6400,
+                "device": 0,
+            },
+            max_seq_len=32_768,
+            tensor_parallel_size=4,
+            device=[0, 1, 2, 3],
+            output_hidden_states=False,
+        )
+        model = RBLNColQwen2ForRetrieval.from_pretrained(
+            "vidore/colqwen2-v1.0-hf",
+            export=True,
+            rbln_config=config
+        )
+        ```
+    """
+
+    main_input_name = "inputs_embeds"
+    auto_model_class = None
+    _rbln_submodules = [
+        {"name": "visual"},
+    ]
+    _decoder_wrapper_cls = ColQwen2LanguageModelWrapper
+    _use_rotary_emb = False
+
+    def __post_init__(self, **kwargs):
+        self.config = self.config.vlm_config if hasattr(self.config, "vlm_config") else self.config
+
+        artifacts = torch.load(
+            self.model_save_dir / self.subfolder / "torch_artifacts.pth",
+            weights_only=False,
+        )
+        self.embed_tokens = self._create_embedding_layer()
+        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        self.visual = self.rbln_submodules[0]
+        self.prefill_runtime = self.model[0]
+        self.mrope_section = self.config.text_config.rope_scaling["mrope_section"]
+        self.is_colqwen2_5 = "qwen2_5_vl" in self.config.model_type
+
+        if self.is_colqwen2_5:
+            self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config.text_config)
+        else:
+            self.rotary_emb = Qwen2VLRotaryEmbedding(self.config.text_config)
+        self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        if hasattr(model, "vlm"):
+            model.visual = model.vlm.visual
+            model.language_model = model.vlm.language_model
+
+        # FIXME: temporary fix for ColQwen2ForRetrieval dtype issue
+        return model.to(torch.float32)
+
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens
+
+    @classmethod
+    def get_input_info(
+        cls,
+        batch_size: int,
+        query_length: int,
+        rbln_config: RBLNColQwen2ForRetrievalConfig,
+        model_config: PretrainedConfig,
+    ):
+        text_config = model_config.text_config
+        input_info = super().get_input_info(
+            batch_size,
+            query_length,
+            rbln_config=rbln_config,
+            model_config=text_config,
+        )
+
+        pos_idx = 3
+        input_info.insert(
+            pos_idx,
+            (
+                "position_emb",
+                [
+                    2,
+                    batch_size,
+                    1,
+                    query_length,
+                    text_config.hidden_size // text_config.num_attention_heads,
+                ],
+                rbln_config.torch_dtype,
+            ),
+        )
+
+        # remove query position from input_info
+        if "query_position" in input_info:
+            query_position = input_info.pop(4)
+            assert query_position[0] == "query_position", print(query_position[0], "is deleted.")
+        return input_info
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNColQwen2ForRetrievalConfig] = None,
+    ) -> RBLNColQwen2ForRetrievalConfig:
+        model_config = model_config.vlm_config if hasattr(model_config, "vlm_config") else model_config
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config.text_config, "output_hidden_states", False)
+
+        return super()._update_rbln_config(
+            preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
+        )
+
+    def _get_position_embeddings(self, hidden_states, position_ids):
+        cos, sin = self.rotary_emb(hidden_states, position_ids)
+        mrope_section = self.mrope_section * 2
+        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+        return torch.stack([cos, sin])
+
+    def get_rope_index(self, *args, **kwargs):
+        if self.is_colqwen2_5:
+            return Qwen2_5_VLModel.get_rope_index(self, *args, **kwargs)
+        else:
+            return Qwen2VLModel.get_rope_index(self, *args, **kwargs)
+
+    def _preprocess_visual(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        pixel_values: torch.Tensor = None,
+        pixel_values_videos: torch.FloatTensor = None,
+        image_grid_thw: torch.LongTensor = None,
+        video_grid_thw: torch.LongTensor = None,
+        second_per_grid_ts: torch.Tensor = None,
+    ):
+        batch_size = input_ids.shape[0]
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        if pixel_values is not None:
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+            n_image_features = image_embeds.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+
+            mask = input_ids == self.config.image_token_id
+            mask_unsqueezed = mask.unsqueeze(-1)
+            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, image_embeds)
+
+        if pixel_values_videos is not None:
+            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+            n_video_features = video_embeds.shape[0]
+            if n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
+
+            mask = input_ids == self.config.video_token_id
+            mask_unsqueezed = mask.unsqueeze(-1)
+            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
+
+        max_inputs_len = input_ids.shape[1]
+        head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
+        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+        all_rope_deltas = []
+
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+        image_idx, video_idx = 0, 0
+
+        for b_idx in range(batch_size):
+            input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
+            vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
+            vision_tokens = input_id[0][vision_start_indices + 1]
+            image_nums = (vision_tokens == image_token_id).sum()
+            video_nums = (vision_tokens == video_token_id).sum()
+            args = [
+                input_id,
+                image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
+                video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
+            ]
+            if self.config.model_type == "qwen2_5_vl":
+                args.append(
+                    second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None
+                )
+            position_ids, rope_deltas = self.get_rope_index(*args)
+            image_idx += image_nums
+            video_idx += video_nums
+
+            position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
+            mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+            all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
+            all_rope_deltas.append(rope_deltas)
+
+        rope_deltas = torch.stack(all_rope_deltas)
+
+        return inputs_embeds, all_position_embeds, rope_deltas
+
+    def _preprocess_chunked_prefill(self, inputs_embeds, attention_mask, position_embed):
+        # valid sequence length of inputs_embeds
+        query_length = inputs_embeds.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
+
+        # extract valid inputs
+        inputs_embeds = inputs_embeds[:, attention_mask.bool()] if attention_mask is not None else inputs_embeds
+        position_embed = (
+            position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+        )
+
+        # add padding for chunked prefill
+        padding_size = (
+            self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
+        ) % self.rbln_config.prefill_chunk_size
+        padded_len = query_length + padding_size
+
+        inputs_embeds = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, padding_size))
+        position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+        cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
+
+        return inputs_embeds, position_embed, cache_position, query_length
+
+    def _chunked_prefill_forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = False,
+    ):
+        padded_inputs_embeds, padded_position_embed, cache_position, query_length = self._preprocess_chunked_prefill(
+            inputs_embeds, attention_mask, position_embed
+        )
+
+        # Chunked prefill
+        projs = []
+        all_hidden_states = [] if output_hidden_states else None
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+            # Extract the current chunk of inputs and cache positions
+            input_chunk = padded_inputs_embeds[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
+            position_embed_chunk = padded_position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
+
+            # Forward pass for the current chunk
+            proj = self.prefill_runtime(
+                inputs_embeds=input_chunk,
+                cache_position=cache_pos_chunk,
+                block_tables=self.block_tables,
+                position_emb=position_embed_chunk,
+            )
+
+            if output_hidden_states:
+                projs.append(proj[0])
+                all_hidden_states.append(proj[1:])
+            else:
+                projs.append(proj)
+
+        projs = torch.concat(projs, dim=-2)[:, :query_length]
+        if output_hidden_states:
+            # Concatenate chunks for each layer
+            concatenated_hidden_states = [
+                torch.concat(hs_chunks, dim=-2)[:, :query_length] for hs_chunks in list(zip(*all_hidden_states))
+            ]
+            all_hidden_states = tuple(concatenated_hidden_states)
+
+        return self._postprocess_chunked_prefill(projs, attention_mask), all_hidden_states
+
+    def _postprocess_chunked_prefill(self, projs, attention_mask):
+        # index copy for attention mask
+        if attention_mask is not None:
+            embedding = torch.full(
+                (1, attention_mask.shape[-1], projs.shape[-1]),
+                fill_value=1e-10,
+                dtype=projs.dtype,
+            )
+            mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+            embedding.index_copy_(dim=-2, index=mask_indices, source=projs)
+        else:
+            embedding = projs
+        return embedding
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+        if pixel_values is not None and image_grid_thw is not None:
+            offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,)
+            pixel_values = torch.cat(
+                [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
+                dim=0,
+            )
+        # visual preprocessing
+        inputs_embeds, position_embed, _ = self._preprocess_visual(
+            input_ids,
+            attention_mask,
+            pixel_values,
+            pixel_values_videos,
+            image_grid_thw,
+            video_grid_thw,
+            second_per_grid_ts,
+        )
+        batch_size = inputs_embeds.shape[0]
+
+        projs = []
+        for b_idx in range(batch_size):
+            proj = self._chunked_prefill_forward(
+                inputs_embeds[b_idx : b_idx + 1],
+                attention_mask[b_idx] if attention_mask is not None else None,
+                position_embed[:, b_idx : b_idx + 1],
+                output_hidden_states=output_hidden_states,
+            )
+            projs.append(proj[0])
+        all_hidden_states = proj[1] if output_hidden_states else ()
+
+        # postprocess
+        projs = torch.cat(projs, dim=0)
+        projs = projs / projs.norm(dim=-1, keepdim=True)
+        projs = projs * attention_mask.unsqueeze(-1)
+
+        return ColQwen2ForRetrievalOutput(
+            embeddings=projs,
+            hidden_states=all_hidden_states,
+        )
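
The new `RBLNColQwen2ForRetrieval.forward` above returns L2-normalized multi-vector embeddings (one vector per token) rather than logits, so retrieval scoring happens outside the model as a late-interaction (MaxSim) step. Each sequence is prefilled in chunks of `prefill_chunk_size`; with the default chunk size of 128, for example, a 300-token page is padded to 384 tokens, run as three chunks, and the padding is dropped again afterwards. A minimal usage sketch follows; the local model directory, the placeholder image, and the exact `ColQwen2Processor` call pattern are illustrative assumptions, not part of this diff.

```python
# Hedged usage sketch: paths, placeholder image, and processor call pattern are assumptions.
import torch
from PIL import Image
from transformers import ColQwen2Processor

from optimum.rbln import RBLNColQwen2ForRetrieval

processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0-hf")
# Load a previously compiled model directory (export=False skips recompilation).
model = RBLNColQwen2ForRetrieval.from_pretrained("./colqwen2-rbln", export=False)

page = Image.new("RGB", (448, 448), "white")  # placeholder document page image
doc_inputs = processor(images=[page], return_tensors="pt")
query_inputs = processor(text=["What is chunked prefill?"], return_tensors="pt")

with torch.no_grad():
    doc_emb = model(**doc_inputs).embeddings      # (n_docs, doc_len, dim)
    query_emb = model(**query_inputs).embeddings  # (n_queries, query_len, dim)

# Late interaction (MaxSim): take the best-matching document token for each
# query token, then sum over query tokens to get one score per (query, doc) pair.
scores = torch.einsum("qnd,pmd->qpnm", query_emb, doc_emb).max(dim=-1).values.sum(dim=-1)
print(scores)  # shape: (n_queries, n_docs)
```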
@@ -23,4 +23,5 @@ from ....ops import (
     paged_flash_causal_attn_prefill,
 )
 from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .configuration_lora import RBLNLoRAAdapterConfig, RBLNLoRAConfig
 from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, get_args
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
 from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig


 logger = get_logger()
@@ -48,6 +49,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
+        lora_config: Optional[Union[Dict[str, Any], RBLNLoRAConfig]] = None,
         prefill_chunk_size: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
         decoder_batch_sizes: Optional[List[int]] = None,
@@ -80,6 +82,12 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
                 in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
                 section below for details.
+            quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
+                quantization. Specifies format, etc.
+            lora_config (Optional[Union[Dict[str, Any], RBLNLoRAConfig]]): Configuration for LoRA
+                (Low-Rank Adaptation) settings when using (multi-)LoRA support. Can be provided as
+                a dictionary or an RBLNLoRAConfig instance. When provided, enables LoRA functionality
+                for the model compilation. Defaults to None (no LoRA).
             prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
                 processing input sequences. Defaults to 128. Must be a positive integer
                 divisible by 64. Affects prefill performance and memory usage.
@@ -185,6 +193,26 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         if self.quantization and isinstance(self.quantization, dict):
             self.quantization = RBLNQuantizationConfig(**self.quantization)

+        self.lora_config = lora_config
+        if self.lora_config and isinstance(self.lora_config, dict):
+            self.lora_config = RBLNLoRAConfig(**self.lora_config)
+
+        # Validate LoRA adapters if LoRA is enabled
+        if self.lora_config is not None:
+            validation_results = self.lora_config.validate_adapter_weights()
+            failed_adapters = [adapter_id for adapter_id, is_valid in validation_results.items() if not is_valid]
+
+            if failed_adapters:
+                raise ValueError(
+                    f"Some LoRA adapters failed validation and may not be accessible at compile time: {failed_adapters}. "
+                    "Please ensure all adapter weights are available and properly formatted."
+                )
+
+            logger.info(
+                f"LoRA configuration initialized with {self.lora_config.num_adapters} adapters: "
+                f"{self.lora_config.adapter_ids}. Max rank: {self.lora_config.max_lora_rank}"
+            )
+
         self.attn_impl = attn_impl
         self.kvcache_partition_len = kvcache_partition_len
         self.kvcache_block_size = kvcache_block_size
@@ -204,6 +232,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         if self.logits_to_keep is not None and self.logits_to_keep > 1:
             raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")

+        self.decoder_batch_sizes = None
         if "decode" in self.phases:
             self.decoder_batch_sizes = decoder_batch_sizes
             if self.decoder_batch_sizes is None:
@@ -243,6 +272,11 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
     def use_multiple_decoder(self) -> bool:
         return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1

+    @property
+    def use_lora(self):
+        """Check if LoRA is enabled for this configuration."""
+        return self.lora_config is not None
+
     @property
     def can_generate(self) -> bool:
         return "decode" in self.phases
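
The new `lora_config` argument and `use_lora` property above are the compile-time entry points for the (multi-)LoRA support added in this release (see `configuration_lora.py` and `lora_architecture.py` in the file list). A hedged sketch of how such a configuration might be passed is shown below; the keys inside the `lora_config` dictionary are assumptions about `RBLNLoRAConfig`'s fields, not a documented API, so consult `configuration_lora.py` for the actual parameter names.

```python
# Hedged sketch: enabling LoRA at compile time for a decoder-only model.
# The keys under "lora_config" are assumed field names, not a documented API.
from optimum.rbln import RBLNLlamaForCausalLM

rbln_config = {
    "max_seq_len": 8192,
    "tensor_parallel_size": 4,
    "lora_config": {
        # Assumed shape of an adapter entry; see configuration_lora.py for the real fields.
        "adapters": [
            {"adapter_id": "my-adapter", "path": "./lora/my-adapter"},
        ],
    },
}

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    export=True,
    rbln_config=rbln_config,
)
```

Note that the config constructor calls `validate_adapter_weights()` on every adapter and raises a `ValueError` listing any that cannot be resolved, so misconfigured adapter paths fail at compile time rather than at inference.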