optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
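
Of the 162 files changed, two diffs are reproduced below: the new ColPali retrieval model (`optimum/rbln/transformers/models/colpali/modeling_colpali.py`) and the hunks that extend the decoder-only configuration (`optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py`).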
--- /dev/null
+++ b/optimum/rbln/transformers/models/colpali/modeling_colpali.py
@@ -0,0 +1,383 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bisect
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+from transformers import (
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import no_init_weights
+from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
+from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
+
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+from ....modeling import RBLNModel
+from .colpali_architecture import RBLNColPaliForRetrievalWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PretrainedConfig,
+    )
+
+
+class LoopVisionTower:
+    def __init__(self, vision_tower: RBLNModel) -> None:
+        self.vision_tower = vision_tower
+
+    def forward(self, pixel_values, **kwargs):
+        batch_size = pixel_values.shape[0]
+        outputs = []
+        for i in range(batch_size):
+            outputs.append(self.vision_tower(pixel_values[i : i + 1]))
+
+        last_hidden_states = [output.last_hidden_state for output in outputs]
+        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_states,
+        )
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        return self.forward(*args, **kwds)
+
+    def __repr__(self) -> str:
+        return repr(self.vision_tower)
+
+
+class LoopLanguageModel:
+    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig) -> None:
+        self.language_model = language_model
+        self.rbln_config = rbln_config
+
+    def prepare_inputs(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor):
+        input_len = inputs_embeds.shape[1]
+        idx = bisect.bisect_left(self.rbln_config.max_seq_lens, input_len)
+        if idx == len(self.rbln_config.max_seq_lens):
+            raise ValueError(
+                f"Required seq_len({input_len}) is larger than available max_seq_lens({self.rbln_config.max_seq_lens})."
+            )
+        else:
+            max_seq_len = self.rbln_config.max_seq_lens[idx]
+
+        inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
+        attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
+        position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
+
+        return inputs_embed, attn_mask, position_ids
+
+    def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
+        padded_inputs_embed, padded_attn_mask, padded_position_ids = self.prepare_inputs(inputs_embeds, attention_mask)
+        input_batch_size = inputs_embeds.shape[0]
+        input_seq_len = inputs_embeds.shape[1]
+
+        all_embeddings = []
+        all_hidden_states = []
+        for i in range(input_batch_size):
+            outputs = self.language_model(
+                inputs_embeds=padded_inputs_embed[i : i + 1],
+                attention_mask=padded_attn_mask[i : i + 1],
+                position_ids=padded_position_ids,
+            )
+
+            if self.rbln_config.output_hidden_states:
+                embedding = outputs[0]
+                hidden_states = outputs[1:]
+            else:
+                embedding = outputs
+                hidden_states = None
+
+            all_embeddings.append(embedding)
+            all_hidden_states.append(hidden_states)
+
+        embeddings = torch.cat(all_embeddings, dim=0)[:, :input_seq_len]
+        if self.rbln_config.output_hidden_states:
+            hidden_states = [
+                torch.cat(
+                    [batch_hidden_states[layer_idx][:, :input_seq_len] for batch_hidden_states in all_hidden_states],
+                    dim=0,
+                )
+                for layer_idx in range(len(all_hidden_states[0]))
+            ]
+            return embeddings, tuple(hidden_states)
+        else:
+            return embeddings
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        return self.forward(*args, **kwds)
+
+    def __repr__(self) -> str:
+        return repr(self.language_model)
+
+
+class RBLNColPaliForRetrieval(RBLNModel):
+    """
+    The ColPali Model transformer for document retrieval using vision-language models.
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers-based ColPaliForRetrieval models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers ColPaliForRetrieval model into an RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNColPaliForRetrievalConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNColPaliForRetrievalConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNColPaliForRetrievalConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNColPaliForRetrieval
+
+        # Simple usage using rbln_* arguments
+        # `max_seq_lens` is automatically inferred from the model config
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_max_seq_lens=1152,
+        )
+
+        # Using a config dictionary
+        rbln_config = {
+            "max_seq_lens": 1152,
+            "output_hidden_states": False,
+        }
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_config=rbln_config
+        )
+
+        # Using a RBLNColPaliForRetrievalConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNColPaliForRetrievalConfig
+
+        config = RBLNColPaliForRetrievalConfig(
+            max_seq_lens=1152,
+            output_hidden_states=False,
+            tensor_parallel_size=4
+        )
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_config=config
+        )
+        ```
+    """
+
+    auto_model_class = None
+    _rbln_submodules = [
+        {"name": "vision_tower"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+        self.language_model = LoopLanguageModel(self.model[0], self.rbln_config)
+
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self.embed_tokens = self._create_embedding_layer()
+        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        self.multi_modal_projector = self._create_multi_modal_projector()
+        self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
+
+        return super().__post_init__(**kwargs)
+
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens
+
+    def _create_multi_modal_projector(self):
+        with no_init_weights():
+            multi_modal_projector = PaliGemmaMultiModalProjector(self.config.vlm_config)
+        return multi_modal_projector
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+        return RBLNColPaliForRetrievalWrapper(
+            causal_lm=model.vlm.language_model,
+            embedding_proj_layer=model.embedding_proj_layer,
+            max_seq_len=max(rbln_config.max_seq_lens),
+            output_hidden_states=rbln_config.output_hidden_states,
+        )
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        save_dict = {}
+        save_dict["embed_tokens"] = model.vlm.get_input_embeddings().state_dict()
+        save_dict["multi_modal_projector"] = model.vlm.multi_modal_projector.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        hidden_size = model_config.vlm_config.text_config.hidden_size
+        if rbln_config.max_seq_lens is None:
+            rbln_config.max_seq_lens = [model_config.vlm_config.text_config.max_position_embeddings]
+        if isinstance(rbln_config.max_seq_lens, int):
+            rbln_config.max_seq_lens = [rbln_config.max_seq_lens]
+        rbln_config.max_seq_lens = sorted(set(rbln_config.max_seq_lens))
+
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = model_config.vlm_config.text_config.output_hidden_states
+
+        input_infos = []
+        for max_seq_len in rbln_config.max_seq_lens:
+            input_info = [
+                ("inputs_embeds", [1, max_seq_len, hidden_size], "float32"),
+                ("attention_mask", [1, max_seq_len], "float32"),
+                ("position_ids", [1, max_seq_len], "int32"),
+            ]
+            input_infos.append(input_info)
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+
+        return rbln_config
+
+    @classmethod
+    def from_model(cls, model: "PreTrainedModel", *args, **kwargs):
+        if not hasattr(model, "vision_tower"):
+            model.vision_tower = model.vlm.vision_tower
+            del model.vlm.vision_tower
+        model = super().from_model(model, *args, **kwargs)
+        return model
+
+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+        model.vision_tower = model.vlm.vision_tower
+        del model.vlm.vision_tower
+
+        return model
+
+    def get_image_features(self, pixel_values: torch.Tensor):
+        # Projects the last hidden state from the vision model into language model space.
+        # Args:
+        #     pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+        #         The tensors corresponding to the input images.
+        # Returns:
+        #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+        vision_outputs = self.vision_tower(pixel_values).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs)
+        image_features = image_features / (self.config.text_config.hidden_size**0.5)
+        return image_features
+
+    def _preprocess_inputs(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ):
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        # Replace the image token id with PAD if it is out-of-vocabulary, to avoid index errors
+        if input_ids is not None and self.config.vlm_config.image_token_index >= self.config.text_config.vocab_size:
+            special_image_mask = input_ids == self.config.vlm_config.image_token_index
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(llm_input_ids)
+
+        # Merge text and images
+        image_features = None
+        if pixel_values is not None:
+            image_features = self.get_image_features(pixel_values)
+            special_image_mask = (input_ids == self.config.vlm_config.image_token_index).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        return inputs_embeds, image_features
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> ColPaliForRetrievalOutput:
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(dtype=self.dtype)
+
+        if output_attentions:
+            raise ValueError("output_attentions is not supported for RBLNColPaliForRetrieval")
+
+        if output_hidden_states is not None and output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states}. "
+                f"Please compile again with the correct argument."
+            )
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        inputs_embeds, image_features = self._preprocess_inputs(
+            input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+        )
+
+        # The embedding_proj_layer is fused onto the bottom of the language model.
+        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+
+        embeddings = outputs if not self.rbln_config.output_hidden_states else outputs[0]
+        hidden_states = None if not self.rbln_config.output_hidden_states else outputs[1]
+
+        # L2 normalization
+        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+
+        if attention_mask is not None:
+            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+        if not return_dict:
+            return (embeddings, hidden_states, image_features)
+        else:
+            return ColPaliForRetrievalOutput(
+                embeddings=embeddings,
+                hidden_states=hidden_states,
+                image_hidden_states=image_features,
+            )
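
A note on `LoopLanguageModel` above: because RBLN graphs are compiled with static shapes, each request is right-padded up to the smallest compiled `max_seq_lens` bucket, located with `bisect`. Below is a minimal standalone sketch of that bucket-selection step; the function name and tensor sizes are illustrative, not part of the package API.

```python
import bisect

import torch


def pick_bucket_and_pad(inputs_embeds, attention_mask, max_seq_lens):
    # max_seq_lens must be sorted ascending, as _update_rbln_config
    # guarantees via sorted(set(...)).
    input_len = inputs_embeds.shape[1]
    idx = bisect.bisect_left(max_seq_lens, input_len)
    if idx == len(max_seq_lens):
        raise ValueError(f"seq_len({input_len}) exceeds compiled buckets {max_seq_lens}")
    pad = max_seq_lens[idx] - input_len
    # Pad the sequence dimension; the embedding dimension is left untouched.
    padded_embeds = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, pad))
    padded_mask = torch.nn.functional.pad(attention_mask, (0, pad)).to(torch.float32)
    return padded_embeds, padded_mask


embeds, mask = torch.randn(1, 1000, 2048), torch.ones(1, 1000)
padded_embeds, padded_mask = pick_bucket_and_pad(embeds, mask, [1152, 2048])
print(padded_embeds.shape)  # torch.Size([1, 1152, 2048])
```

After the compiled graph runs, `LoopLanguageModel.forward` slices the outputs back to the original `input_seq_len`, so callers never see the padding.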
--- a/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 import rebel
 
@@ -23,8 +23,18 @@ from ...utils.rbln_quantization import RBLNQuantizationConfig
 
 logger = get_logger()
 
+CacheImplType = Literal["static", "sliding_window", "hybrid"]
+
 
 class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLN decoder-only models for Causal Language Modeling.
+
+    This class extends RBLNModelConfig with parameters specific to decoder-only transformer
+    architectures optimized for RBLN devices. It controls aspects like attention implementation,
+    KV cache management, and batching for inference.
+    """
+
     def __init__(
         self,
         batch_size: Optional[int] = None,
@@ -39,36 +49,119 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         prefill_chunk_size: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
         decoder_batch_sizes: Optional[List[int]] = None,
+        cache_impl: Optional[CacheImplType] = None,
+        sliding_window: Optional[int] = None,
+        sliding_window_layers: Optional[List[int]] = None,
         **kwargs,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             max_seq_len (Optional[int]): The maximum sequence length supported by the model.
-            use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
-            use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
-                for RBLN-CA02 devices.
+                If not provided, it attempts to infer from the model's configuration
+                (`max_position_embeddings` or `n_positions`). Must be specified if not available
+                in the model config.
+            use_inputs_embeds (Optional[bool]): Whether to use input embeddings (`inputs_embeds`)
+                directly instead of `input_ids`. Defaults to False. Requires the model to be
+                compiled with this option enabled.
+            use_attention_mask (Optional[bool]): Whether the model requires attention masks during
+                inference. This is typically determined based on the target device and model
+                architecture. Defaults are often set automatically based on the model and RBLN NPU.
             use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
-            attn_impl (Optional[str]): The attention implementation to use.
-            kvcache_partition_len (Optional[int]): The length of each KV cache partition.
-            kvcache_block_size (Optional[int]): The block size for KV cache.
-            quantization (Optional[Dict[str, Any]]): Configuration for model quantization.
-            prefill_chunk_size (Optional[int]): The chunk size for prefilling the KV cache. Defaults to 128,
-                and must be a positive integer divisible by 64.
-            kvcache_num_blocks (Optional[int]): The number of blocks in the KV cache.
+            attn_impl (Optional[str]): Specifies the attention implementation to use.
+                See the "Attention Implementation (`attn_impl`)" section below for details.
+            kvcache_partition_len (Optional[int]): Defines the partition length for the KV cache
+                when using "flash_attn". See the "KV Cache Partition Length (`kvcache_partition_len`)"
+                section below for details.
+            kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
+                section below for details.
+            quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
+                quantization. Specifies the quantization format, etc.
+            prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
+                processing input sequences. Defaults to 128. Must be a positive integer
+                divisible by 64. Affects prefill performance and memory usage.
+            kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+                PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
+                section below for details.
             decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
                 This allows the model to handle varying batch sizes efficiently during generation. If not specified,
                 defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
                 1) All values must be less than or equal to the main batch size.
                 2) The list will be sorted in descending order (larger batch sizes first).
                 3) If using multiple decoders, at least one batch size should match the main batch size.
-
+            cache_impl (Optional[CacheImplType]): Specifies the KV cache implementation strategy. Defaults to "static".
+                - "static": Uses a fixed-size global KV cache for all layers, suitable for standard attention patterns.
+                - "sliding_window": Implements a sliding window KV cache, where each layer maintains a local cache of recent tokens.
+                - "hybrid": Combines both static and sliding window approaches, allowing different layers to use different cache strategies.
+                The choice affects memory usage and attention patterns. When using "sliding_window" or "hybrid",
+                you must specify the `sliding_window` size, and optionally `sliding_window_layers` for hybrid mode.
+            sliding_window (Optional[int]): The size of the sliding window. Defaults to None.
+            sliding_window_layers (Optional[List[int]]): The indices of the layers that use the sliding window
+                when `cache_impl` is "hybrid". Defaults to None.
             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
-            ValueError: If batch_size is not a positive integer or if prefill_chunk_size is not
-                a positive integer divisible by 64.
+            ValueError: If `batch_size` is not a positive integer.
+            ValueError: If `prefill_chunk_size` is not a positive integer divisible by 64.
+            ValueError: If `max_seq_len` cannot be determined and is required.
+            ValueError: If attention parameter constraints are violated (e.g., `max_seq_len` vs
+                `kvcache_partition_len` for flash attention).
+
+
+        Attention Implementation:
+            `attn_impl` determines the underlying attention mechanism used by the model.
+
+            - **`"eager"`** (Default if `kvcache_partition_len` is not set): Uses the standard PyTorch
+              attention implementation. Suitable for sequences up to a certain limit (e.g., 32,768 tokens).
+            - **`"flash_attn"`**: Utilizes an optimized Flash Attention implementation, beneficial for
+              longer sequences and potentially faster execution. Requires `max_seq_len` to be at least
+              8,192. If `kvcache_partition_len` is specified, `attn_impl` automatically defaults
+              to `"flash_attn"`. When using `"flash_attn"`, `kvcache_block_size` must equal
+              `kvcache_partition_len`.
+
+            The choice impacts performance and memory usage, especially for long sequences.
+            Constraints related to `max_seq_len` and `kvcache_partition_len` apply when using
+            `"flash_attn"`.
+
+
+        KV Cache Partition Length:
+            `kvcache_partition_len` is relevant **only** when `attn_impl` is `"flash_attn"`.
+
+            - It defines the length (number of tokens) of each partition within the Key-Value (KV) cache.
+            - Must be between 4,096 and 32,768 (inclusive).
+            - When using `"flash_attn"`, `max_seq_len` must be a multiple of `kvcache_partition_len`
+              and at least twice its value (`max_seq_len >= 2 * kvcache_partition_len`).
+            - If `attn_impl` is `"flash_attn"` and `kvcache_partition_len` is `None`, it defaults to
+              16,384.
+
+
+        KV Cache Number of Blocks:
+            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache.
+            Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+            - **Automatic Estimation (Default)**: If `kvcache_num_blocks` is `None`, the system estimates
+              the maximum number of blocks that can fit into the available RBLN device memory. This
+              calculation considers the model size (kernel memory), required buffer memory, the number
+              of layers and heads, `kvcache_block_size`, tensor parallelism, and available RBLN NPU DRAM.
+              This aims to maximize cache capacity for potentially better performance with long sequences
+              or larger batches without manual tuning.
+            - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
+              but requires careful consideration of memory limits. Setting it too high may lead to
+              compilation errors if it exceeds available memory. The system will issue warnings if your
+              setting exceeds the estimated maximum.
+            - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
+              which is beneficial for tasks involving many long sequences or large batch sizes, enabling
+              higher throughput. However, allocating more blocks consumes more memory.
+            - **Minimum Requirement**: The system requires a minimum number of blocks to function,
+              calculated based on `max_seq_len`, `kvcache_block_size`, and `batch_size`. The number of
+              allocated blocks must be sufficient to hold at least one full sequence length per item
+              in the batch concurrently. The system will log warnings or raise errors if constraints
+              are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).
+
+            The optimal value depends on the specific model, task, hardware, and desired trade-off
+            between performance and memory usage. The automatic estimation provides a robust starting point.
         """
+
         super().__init__(**kwargs)
         self.batch_size = batch_size or 1
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
@@ -121,6 +214,10 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         # Larger batch size should be at the beginning of the list.
         self.decoder_batch_sizes.sort(reverse=True)
 
+        self.cache_impl = cache_impl or "static"
+        self.sliding_window = sliding_window
+        self.sliding_window_layers = sliding_window_layers or []
+
     @property
     def use_multiple_decoder(self):
         return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
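
The new cache parameters ride on the same `rbln_config` plumbing used elsewhere in the library. A hedged sketch, assuming a concrete decoder-only class such as `RBLNLlamaForCausalLM` (shipped in this release) and an illustrative checkpoint id:

```python
from optimum.rbln import RBLNLlamaForCausalLM

# Sliding-window KV cache selection, new in 0.8.1. Per the docstring above,
# "sliding_window" and "hybrid" require an explicit sliding_window size;
# sliding_window_layers narrows the local cache to specific layers in hybrid mode.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative model id
    export=True,
    rbln_config={
        "batch_size": 1,
        "max_seq_len": 8192,
        "cache_impl": "sliding_window",
        "sliding_window": 1024,
    },
)
```

Model-specific config classes (e.g., `RBLNLlamaForCausalLMConfig`) extend this base config, so the same keys should apply there; which `cache_impl` values a given architecture actually supports is determined by those subclasses.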