optimum-rbln: optimum_rbln-0.8.4a8-py3-none-any.whl → optimum_rbln-0.9.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +63 -32
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  11. optimum/rbln/diffusers/modeling_diffusers.py +5 -7
  12. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  13. optimum/rbln/modeling.py +50 -0
  14. optimum/rbln/modeling_base.py +1 -2
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_generic.py +37 -1
  17. optimum/rbln/transformers/models/__init__.py +9 -0
  18. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  19. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +86 -23
  20. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  21. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  22. optimum/rbln/transformers/models/colpali/configuration_colpali.py +34 -18
  23. optimum/rbln/transformers/models/colpali/modeling_colpali.py +73 -80
  24. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  25. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  26. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  27. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  28. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  29. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  30. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  32. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +50 -2
  33. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  34. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +65 -3
  35. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  36. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  37. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  38. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +67 -44
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +24 -19
  42. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  43. optimum/rbln/transformers/models/llava/modeling_llava.py +108 -50
  44. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  45. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -343
  46. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  47. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +9 -8
  50. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +24 -0
  51. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +24 -0
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  55. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  57. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  58. optimum/rbln/utils/runtime_utils.py +25 -15
  59. optimum/rbln/utils/submodule.py +21 -5
  60. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA +7 -6
  61. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/RECORD +64 -55
  62. optimum_rbln-0.9.2.dist-info/entry_points.txt +2 -0
  63. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/colpali/modeling_colpali.py

@@ -26,6 +26,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModal
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
 from .colpali_architecture import RBLNColPaliForRetrievalWrapper
 
 
@@ -33,93 +34,64 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel) -> None:
-        self.vision_tower = vision_tower
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower.model[0])
 
-    def forward(self, pixel_values, **kwargs):
-        batch_size = pixel_values.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.vision_tower(pixel_values[i : i + 1]))
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]
 
-        last_hidden_states = [output.last_hidden_state for output in outputs]
-        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = kwargs["out"][index : index + 1]
+        return ([pixel_values_item], {"out": out_buffer})
 
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
         return BaseModelOutputWithPooling(
-            last_hidden_state=last_hidden_states,
+            last_hidden_state=kwargs["out"],
         )
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-
 
-class LoopLanguageModel:
-    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig) -> None:
-        self.language_model = language_model
+class LoopLanguageModel(LoopProcessor):
+    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig):
+        super().__init__(model=language_model)
         self.rbln_config = rbln_config
 
-    def prepare_inputs(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor):
+    def _get_batch_size(self, inputs_embeds, **kwargs):
+        return inputs_embeds.shape[0]
+
+    def _prepare_inputs_before_loop(self, *, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
         input_len = inputs_embeds.shape[1]
         idx = bisect.bisect_left(self.rbln_config.max_seq_lens, input_len)
         if idx == len(self.rbln_config.max_seq_lens):
             raise ValueError(
                 f"Required seq_len({input_len}) is larger than available max_seq_lens({self.rbln_config.max_seq_lens})."
             )
-        else:
-            max_seq_len = self.rbln_config.max_seq_lens[idx]
-
-        inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
-        attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
-        position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
-
-        return inputs_embed, attn_mask, position_ids
-
-    def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
-        padded_inputs_embed, padded_attn_mask, padded_position_ids = self.prepare_inputs(inputs_embeds, attention_mask)
-        input_batch_size = inputs_embeds.shape[0]
-        input_seq_len = inputs_embeds.shape[1]
-
-        all_embeddings = []
-        all_hidden_states = []
-        for i in range(input_batch_size):
-            outputs = self.language_model(
-                inputs_embeds=padded_inputs_embed[i : i + 1],
-                attention_mask=padded_attn_mask[i : i + 1],
-                position_ids=padded_position_ids,
-            )
-
-            if self.rbln_config.output_hidden_states:
-                embedding = outputs[0]
-                hidden_states = outputs[1:]
-            else:
-                embedding = outputs
-                hidden_states = None
+        max_seq_len = self.rbln_config.max_seq_lens[idx]
+        padded_inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
+        padded_attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
+        padded_position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
+
+        return {
+            "padded_inputs_embed": padded_inputs_embed,
+            "padded_attn_mask": padded_attn_mask,
+            "padded_position_ids": padded_position_ids,
+        }
 
-            all_embeddings.append(embedding)
-            all_hidden_states.append(hidden_states)
+    def _prepare_inputs_for_iteration(self, index: int, common_inputs, *args, **kwargs):
+        item_kwargs = {
+            "inputs_embeds": common_inputs["padded_inputs_embed"][index : index + 1],
+            "attention_mask": common_inputs["padded_attn_mask"][index : index + 1],
+            "position_ids": common_inputs["padded_position_ids"],
+            "out": [tensor[index : index + 1] for tensor in kwargs["out"]],
+        }
+        return ([], item_kwargs)
 
-        embeddings = torch.cat(all_embeddings, dim=0)[:, :input_seq_len]
+    def _process_outputs(self, outputs: list, **kwargs):
         if self.rbln_config.output_hidden_states:
-            hidden_states = [
-                torch.cat(
-                    [batch_hidden_states[layer_idx][:, :input_seq_len] for batch_hidden_states in all_hidden_states],
-                    dim=0,
-                )
-                for layer_idx in range(len(all_hidden_states[0]))
-            ]
-            return embeddings, tuple(hidden_states)
+            return kwargs["out"][0], tuple(kwargs["out"][1:])
         else:
-            return embeddings
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.language_model)
+            return kwargs["out"]
 
 
 class RBLNColPaliForRetrieval(RBLNModel):
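Note on the refactor above: the per-item batching loop that LoopVisionTower and LoopLanguageModel previously implemented by hand is now inherited from a shared LoopProcessor base class added in optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (+79 lines, not shown in this excerpt). The following is only a minimal sketch of how such a base class could drive the hook methods the subclasses override (`_get_batch_size`, `_prepare_inputs_before_loop`, `_prepare_inputs_for_iteration`, `_process_outputs`); it is inferred from the diff, not the actual implementation:

```python
class LoopProcessor:
    """Hypothetical sketch: run a batch-size-1 compiled model once per batch item."""

    def __init__(self, model) -> None:
        self.model = model

    def _get_batch_size(self, *args, **kwargs) -> int:
        raise NotImplementedError

    def _prepare_inputs_before_loop(self, *args, **kwargs) -> dict:
        # Optional hook: compute inputs shared by every iteration (e.g. padding to max_seq_len).
        return {}

    def _prepare_inputs_for_iteration(self, index: int, common_inputs: dict, *args, **kwargs):
        # Return (positional_args, keyword_args) for one batch item, typically
        # slicing the caller-provided `out` buffers down to index : index + 1.
        raise NotImplementedError

    def _process_outputs(self, outputs: list, **kwargs):
        # Assemble the final result, usually from the pre-filled `out` buffers.
        return outputs

    def forward(self, *args, **kwargs):
        common_inputs = self._prepare_inputs_before_loop(*args, **kwargs)
        outputs = []
        for index in range(self._get_batch_size(*args, **kwargs)):
            item_args, item_kwargs = self._prepare_inputs_for_iteration(index, common_inputs, *args, **kwargs)
            outputs.append(self.model(*item_args, **item_kwargs))
        return self._process_outputs(outputs, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def __repr__(self) -> str:
        return repr(self.model)
```

Both subclasses write per-item results straight into pre-allocated `out` tensors sliced per index, which is why `_process_outputs` mostly returns `kwargs["out"]` instead of concatenating fresh tensors; the same output-buffer pattern shows up below in `get_image_features` and `forward`, which now pass `out=` arguments rather than consuming returned tensors.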
@@ -212,7 +184,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
     @classmethod
     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return RBLNColPaliForRetrievalWrapper(
-            causal_lm=model.vlm.language_model,
+            causal_lm=model.vlm,
             embedding_proj_layer=model.embedding_proj_layer,
             max_seq_len=max(rbln_config.max_seq_lens),
             output_hidden_states=rbln_config.output_hidden_states,
@@ -252,9 +224,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
         input_infos = []
         for max_seq_len in rbln_config.max_seq_lens:
             input_info = [
-                ("inputs_embeds", [1, max_seq_len, hidden_size], "float32"),
-                ("attention_mask", [1, max_seq_len], "float32"),
-                ("position_ids", [1, max_seq_len], "int32"),
+                ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
+                ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
+                ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
             ]
             input_infos.append(input_info)
 
@@ -298,7 +270,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
         """
         if not hasattr(model, "vision_tower"):
             model.vision_tower = model.vlm.vision_tower
-            del model.vlm.vision_tower
+            del model.vlm.model.vision_tower
         model = super().from_model(model, config, rbln_config, model_save_dir, subfolder, **kwargs)
         return model
 
@@ -306,8 +278,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
     def get_pytorch_model(cls, *args, **kwargs):
         model = super().get_pytorch_model(*args, **kwargs)
         model.vision_tower = model.vlm.vision_tower
-        del model.vlm.vision_tower
-
+        del model.vlm.model.vision_tower
         return model
 
     def get_image_features(self, pixel_values: torch.Tensor):
@@ -318,8 +289,14 @@ class RBLNColPaliForRetrieval(RBLNModel):
         # Returns:
         # image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
 
-        vision_outputs = self.vision_tower(pixel_values).last_hidden_state
-        image_features = self.multi_modal_projector(vision_outputs)
+        vision_output_size = [
+            pixel_values.shape[0],
+            self.config.vlm_config.vision_config.num_image_tokens,
+            self.config.vlm_config.vision_config.hidden_size,
+        ]
+        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+        self.vision_tower(pixel_values, out=vision_output)
+        image_features = self.multi_modal_projector(vision_output)
         image_features = image_features / (self.config.text_config.hidden_size**0.5)
         return image_features
 
@@ -385,11 +362,27 @@ class RBLNColPaliForRetrieval(RBLNModel):
             input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
         )
 
+        outputs = []
+        language_model_out_size = [inputs_embeds.shape[0], self.rbln_config.max_seq_lens[0], self.config.embedding_dim]
+        language_model_hidden_states_size = [
+            inputs_embeds.shape[0],
+            self.rbln_config.max_seq_lens[0],
+            self.rbln_config.max_seq_lens[0],
+        ]
+        outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
+        if self.rbln_config.output_hidden_states:
+            for i in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
+                outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
+
         # Embedding_proj_layer is fused on the bottom of the language model.
-        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+        self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
 
-        embeddings = outputs if not self.rbln_config.output_hidden_states else outputs[0]
-        hidden_states = None if not self.rbln_config.output_hidden_states else outputs[1]
+        embeddings = outputs[0][:, : inputs_embeds.shape[1]]
+        hidden_states = (
+            None
+            if not self.rbln_config.output_hidden_states
+            else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
+        )
 
         # L2 normalization
         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
optimum/rbln/transformers/models/colqwen2/__init__.py (new file)

@@ -0,0 +1,2 @@
+from .configuration_colqwen2 import RBLNColQwen2ForRetrievalConfig
+from .modeling_colqwen2 import RBLNColQwen2ForRetrieval
optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py (new file)

@@ -0,0 +1,233 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)
+
+from .configuration_colqwen2 import (
+    RBLNColQwen2ForRetrievalConfig,
+)
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    cos = cos[position_ids[0]][None, None, None, :, :]
+    sin = sin[position_ids[0]][None, None, None, :, :]
+
+    return cos, sin
+
+
+class ColQwen2LanguageModelWrapper(DecoderOnlyWrapper):
+    def __init__(
+        self, model: PreTrainedModel, rbln_config: "RBLNColQwen2ForRetrievalConfig", use_rotary_emb: bool = True
+    ):
+        model.config = (
+            model.config.vlm_config.text_config if hasattr(model.config, "vlm_config") else model.config.text_config
+        )
+        super().__init__(model, rbln_config, use_rotary_emb)
+
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.language_model.layers
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+        new_layers = []
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.rbln_config,
+                is_sliding=is_sliding,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layers.append(new_layer)
+
+        new_model = self.get_rbln_model_class()(
+            model.language_model,
+            new_layers,
+            self.rbln_config,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+        )
+
+        # text_projection layer from model
+        self.embedding_proj_layer = (
+            model.embedding_proj_layer if hasattr(model, "embedding_proj_layer") else model.custom_text_proj
+        )
+        return new_model
+
+    def get_rbln_model_class(self):
+        return RBLNColQwen2LanguageModel
+
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        global_block_tables = args.pop(0)
+        local_block_tables = None
+        position_embeds = args.pop(0)
+        position_ids = None
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        past_key_values = args
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )
+
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            position_embeds,
+        )
+
+    def forward(self, *args):
+        (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            rotary_emb,
+        ) = self.prepare_forward_args(*args)
+
+        last_hidden_states = self.model(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            rotary_emb=rotary_emb,
+            global_block_tables=global_block_tables,
+            local_block_tables=local_block_tables,
+        )
+
+        proj = self.embedding_proj_layer(last_hidden_states[0])
+        all_hidden_states = last_hidden_states[1] if self.rbln_config.output_hidden_states else None
+
+        if self.rbln_config.output_hidden_states:
+            return proj, all_hidden_states
+        else:
+            return proj
+
+
+class RBLNColQwen2LanguageModel(DecoderOnlyModel):
+    def __init__(
+        self,
+        model,
+        layers: List["DecoderOnlyLayer"],
+        rbln_config: "RBLNColQwen2ForRetrievalConfig",
+        use_learned_pos_emb=None,
+    ):
+        super().__init__(model, layers, rbln_config, use_learned_pos_emb)
+
+        self.output_hidden_states = rbln_config.output_hidden_states
+
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
+        query_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
+        global_block_tables: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
+    ):
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        # embed positions
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embedding()(input_ids)
+
+        hidden_states = inputs_embeds * self.hidden_multiplier
+
+        # get cos,sin vector if needed
+        position_ids = position_ids if position_ids is not None else cache_position
+        if rotary_emb is not None:
+            if isinstance(rotary_emb, torch.Tensor):
+                cos = rotary_emb[0]
+                sin = rotary_emb[1]
+            else:
+                cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+                cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        # Get sequence positions for flash attention
+        if self.attn_impl == "flash_attn":
+            seq_positions = cache_position[:, 0]
+            seq_positions = self.convert_sequence_positions_for_flash_attn(
+                seq_positions=seq_positions, max_seq_len=self.max_seq_len
+            )
+        else:
+            seq_positions = cache_position[:, :1]
+
+        # Get local cache positions for sliding window layers
+        if len(self.sliding_window_layers) > 0:
+            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+
+        all_hidden_states = () if self.output_hidden_states else None
+        for layer_idx, layer in enumerate(self.layers):
+            if self.output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            is_sliding = True if layer_idx in self.sliding_window_layers else False
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                seq_positions=sliding_cache_pos if is_sliding else seq_positions,
+                past_key_values=past_key_values,
+                cos=cos,
+                sin=sin,
+                block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
+            )
+
+        hidden_states = self.get_last_layernorm()(hidden_states)
+        if self.output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return hidden_states, all_hidden_states
optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py (new file)

@@ -0,0 +1,74 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from optimum.rbln.configuration_utils import RBLNModelConfig
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig
+
+
+class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN ColQwen2 models for document retrieval.
+
+    This class extends RBLNModelConfig with specific configurations for ColQwen2 models,
+    including vision tower settings and multi-sequence length support.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig
+
+    # Create a configuration object
+    config = RBLNColQwen2ForRetrievalConfig(
+        visual={
+            "max_seq_lens": 6400,
+            "device": 0,
+        },
+        max_seq_len=32_768,
+        tensor_parallel_size=4,
+        device=[0, 1, 2, 3],
+        output_hidden_states=False,
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNColQwen2ForRetrieval.from_pretrained(
+        "vidore/colqwen2-v1.0-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    submodules = ["visual"]
+
+    def __init__(
+        self,
+        visual: Optional[RBLNModelConfig] = None,
+        batch_size: Optional[int] = None,
+        use_inputs_embeds: bool = True,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs,
+    ):
+        super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
+        if not self.use_inputs_embeds:
+            raise ValueError(
+                "RBLNColQwen2ForRetrievalConfig does not allow `use_inputs_embeds` to be set to False, "
+                "as RBLNColQwen2ForRetrieval accepts only `inputs_embeds` as input."
+            )
+        if batch_size is not None and batch_size != 1:
+            raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")
+
+        self.visual = visual
+        self.output_hidden_states = output_hidden_states
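For context on how these multi-vector embeddings are typically consumed: ColQwen2-style retrieval scores a query against a document page by ColBERT-style late interaction (MaxSim). The following is only a minimal scoring sketch, assuming the RBLN model mirrors the upstream ColQwen2ForRetrieval and returns L2-normalized per-token embeddings; the actual output API of modeling_colqwen2.py is not shown in this excerpt:

```python
import torch


def maxsim_score(query_embeds: torch.Tensor, doc_embeds: torch.Tensor) -> torch.Tensor:
    # query_embeds: (num_query_tokens, dim), doc_embeds: (num_doc_tokens, dim),
    # both assumed to be L2-normalized multi-vector embeddings from the model.
    sim = query_embeds @ doc_embeds.T          # (num_query_tokens, num_doc_tokens)
    return sim.max(dim=-1).values.sum()        # best document token per query token, summed
```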