optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,341 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import types
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ from transformers import SwinConfig
+ from transformers.models.swin.modeling_swin import BackboneOutput
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from .configuration_swin import RBLNSwinBackboneConfig
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import (
+         AutoFeatureExtractor,
+         AutoProcessor,
+         AutoTokenizer,
+         PreTrainedModel,
+         SwinBackbone,
+         SwinEncoder,
+     )
+
+
+ def window_partition(input_feature, window_size):
+     """
+     Partitions the given input into windows.
+     """
+     batch_size, height, width, num_channels = input_feature.shape
+     input_feature = input_feature.view(
+         batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+     )
+     windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+     return windows
+
+
+ def get_attn_mask(self, height, width, dtype, device):
+     if self.shift_size > 0:
+         # calculate attention mask for SW-MSA
+         img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
+         height_slices = (
+             slice(0, -self.window_size),
+             slice(-self.window_size, -self.shift_size),
+             slice(-self.shift_size, None),
+         )
+         width_slices = (
+             slice(0, -self.window_size),
+             slice(-self.window_size, -self.shift_size),
+             slice(-self.shift_size, None),
+         )
+         count = torch.zeros(1)
+         for height_slice in height_slices:
+             for width_slice in width_slices:
+                 img_mask[:, height_slice, width_slice, :] = count
+                 count += 1
+
+         mask_windows = window_partition(img_mask, self.window_size)
+         mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+         attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+         attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+     else:
+         attn_mask = None
+     return attn_mask
+
+
+ class _SwinEncoder(torch.nn.Module):
+     def __init__(self, model: "SwinEncoder"):
+         super().__init__()
+         self.layers = model.layers
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         input_dimensions: Tuple[int, int],
+         head_mask: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         output_hidden_states_before_downsampling: Optional[bool] = False,
+         always_partition: Optional[bool] = False,
+         return_dict: Optional[bool] = True,
+     ):
+         all_hidden_states = () if output_hidden_states else None
+         all_reshaped_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+
+         if output_hidden_states:
+             batch_size, _, hidden_size = hidden_states.shape
+             # rearrange b (h w) c -> b c h w
+             reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+             reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+             all_hidden_states += (hidden_states,)
+             all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+         for i, layer_module in enumerate(self.layers):
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+
+             layer_outputs = layer_module(
+                 hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+             )
+
+             hidden_states = layer_outputs[0]
+             hidden_states_before_downsampling = layer_outputs[1]
+             output_dimensions = layer_outputs[2]
+
+             input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+
+             if output_hidden_states and output_hidden_states_before_downsampling:
+                 batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                 # rearrange b (h w) c -> b c h w
+                 # here we use the original (not downsampled) height and width
+                 reshaped_hidden_state = hidden_states_before_downsampling.view(
+                     batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                 )
+                 reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                 all_hidden_states += (hidden_states_before_downsampling,)
+                 all_reshaped_hidden_states += (reshaped_hidden_state,)
+             elif output_hidden_states and not output_hidden_states_before_downsampling:
+                 batch_size, _, hidden_size = hidden_states.shape
+                 # rearrange b (h w) c -> b c h w
+                 reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+                 reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                 all_hidden_states += (hidden_states,)
+                 all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+             if output_attentions:
+                 all_self_attentions += layer_outputs[3:]
+
+         return tuple(
+             v
+             for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
+             if v is not None
+         )
+
+
+ class _SwinBackbone(torch.nn.Module):
+     def __init__(self, model: "SwinBackbone", output_hidden_states: bool, output_attentions: bool):
+         super().__init__()
+         self.model = model
+         self.embeddings = model.embeddings
+         self.encoder = model.encoder
+         self.stage_names = model.stage_names
+         self.out_features = model.out_features
+         self.hidden_states_norms = model.hidden_states_norms
+         self.output_hidden_states = output_hidden_states
+         self.output_attentions = output_attentions
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+     ):
+         embedding_output, input_dimensions = self.embeddings(pixel_values)
+         outputs = _SwinEncoder(self.encoder)(
+             embedding_output,
+             input_dimensions,
+             head_mask=None,
+             output_attentions=self.output_attentions,
+             output_hidden_states=True,
+             output_hidden_states_before_downsampling=True,
+             always_partition=True,
+             return_dict=False,
+         )
+
+         hidden_states = outputs[-1]
+
+         feature_maps = ()
+         for stage, hidden_state in zip(self.stage_names, hidden_states):
+             if stage in self.out_features:
+                 batch_size, num_channels, height, width = hidden_state.shape
+                 hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                 hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                 hidden_state = self.hidden_states_norms[stage](hidden_state)
+                 hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                 hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                 feature_maps += (hidden_state,)
+
+         output = (feature_maps,)
+
+         if self.output_hidden_states:
+             output += (outputs[1],)
+
+         if self.output_attentions:
+             output += (outputs[2],)
+
+         return output
+
+
+ class RBLNSwinBackbone(RBLNModel):
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
+         for layer in model.encoder.layers:
+             for block in layer.blocks:
+                 block.get_attn_mask = types.MethodType(get_attn_mask, block)
+
+         wrapper_cfg = {
+             "output_hidden_states": rbln_config.output_hidden_states,
+             "output_attentions": rbln_config.output_attentions,
+         }
+         return _SwinBackbone(model, **wrapper_cfg).eval()
+
+     @classmethod
+     def _update_submodule_config(
+         cls,
+         model: "PreTrainedModel",
+         rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
+         for processor in preprocessors:
+             if rbln_config.image_size is None and hasattr(processor, "image_processor"):
+                 if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
+                     rbln_config.image_size = (
+                         processor.image_processor.size["height"],
+                         processor.image_processor.size["width"],
+                     )
+                 elif (
+                     "longest_edge" in processor.image_processor.size
+                     and "shortest_edge" in processor.image_processor.size
+                 ):
+                     rbln_config.image_size = processor.image_processor.size["longest_edge"]
+                 elif "shortest_edge" in processor.image_processor.size:
+                     rbln_config.image_size = processor.image_processor.size["shortest_edge"]
+                 break
+
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: "SwinConfig" = None,
+         rbln_config: Optional[RBLNSwinBackboneConfig] = None,
+     ) -> RBLNSwinBackboneConfig:
+         if rbln_config.image_size is None:
+             for processor in preprocessors:
+                 if hasattr(processor, "size"):
+                     if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
+                         rbln_config.image_size = (processor.size["height"], processor.size["width"])
+                     break
+
+         input_info = [
+             (
+                 "pixel_values",
+                 [
+                     rbln_config.batch_size,
+                     3,
+                     rbln_config.image_height,
+                     rbln_config.image_width,
+                 ],
+                 "float32",
+             ),
+         ]
+
+         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
+         return rbln_config
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         return_dict: bool = True,
+         output_attentions: bool = None,
+         output_hidden_states: bool = None,
+         **kwargs,
+     ) -> Union[Tuple, BackboneOutput]:
+         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
+             logger.warning(
+                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
+             )
+
+         output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+         )
+
+         if output_attentions != self.rbln_config.output_attentions:
+             raise ValueError(
+                 f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
+                 f"Please compile again with the correct argument."
+             )
+
+         if output_hidden_states != self.rbln_config.output_hidden_states:
+             raise ValueError(
+                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                 f"Please compile again with the correct argument."
+             )
+
+         _, _, original_h, original_w = pixel_values.shape
+         if original_h > self.rbln_config.image_height or original_w > self.rbln_config.image_width:
+             raise ValueError(
+                 f"Input image size ({original_h}x{original_w}) exceeds the configured maximum size"
+                 f" ({self.rbln_config.image_height}x{self.rbln_config.image_width})."
+             )
+
+         pad_h = self.rbln_config.image_height - original_h
+         pad_w = self.rbln_config.image_width - original_w
+         padded_pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
+
+         output = self.model[0](padded_pixel_values)
+
+         feature_maps = ()
+         for i in range(len(self.config.out_features)):
+             feature_maps += (output.pop(0),)
+
+         if self.rbln_config.output_hidden_states:
+             hidden_states = ()
+             for i in range(len(self.config.stage_names)):
+                 hidden_states += (output.pop(0),)
+         else:
+             hidden_states = None
+
+         if self.rbln_config.output_attentions:
+             attentions = ()
+             for i in range(len(self.config.depths)):
+                 attentions += (output.pop(0),)
+         else:
+             attentions = None
+
+         if not return_dict:
+             return tuple(item for item in (feature_maps, hidden_states, attentions) if item is not None)
+         else:
+             return BackboneOutput(
+                 feature_maps=feature_maps,
+                 hidden_states=hidden_states,
+                 attentions=attentions,
+             )
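
For orientation, here is a minimal usage sketch for the new RBLNSwinBackbone. The checkpoint name and the rbln_image_size keyword are illustrative assumptions, not taken from the diff:

# Hypothetical sketch: compile a Swin backbone for RBLN and run one image through it.
from PIL import Image
from transformers import AutoImageProcessor
from optimum.rbln import RBLNSwinBackbone  # assumed to be re-exported at the top level

model = RBLNSwinBackbone.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",  # illustrative checkpoint
    export=True,  # compile from the PyTorch weights
    rbln_image_size=(224, 224),  # assumed rbln_* kwarg setting the static input shape
)
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
image = Image.new("RGB", (224, 224))
inputs = processor(images=image, return_tensors="pt")
outputs = model(pixel_values=inputs.pixel_values)  # BackboneOutput with feature_maps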
@@ -32,3 +32,5 @@ class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
      This configuration class stores the configuration parameters specific to
      RBLN-optimized T5 models for conditional text generation tasks.
      """
+
+     support_paged_attention = False
@@ -126,7 +126,14 @@ class T5Decoder(Seq2SeqDecoder):
          b_size = attention_mask.shape[0]
          batch_decoder_position_bias = []
          for i in range(b_size):
-             batch_position_bias = self._dec_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
+             if torch.compiler.is_exporting():
+                 cache_pos = cache_position[i][0].item()
+                 torch._check_is_size(cache_pos)
+                 torch._check(cache_pos >= 0)
+                 torch._check(cache_pos < self._dec_position_bias.shape[2])
+             else:
+                 cache_pos = cache_position[i][0]
+             batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
              batch_decoder_position_bias.append(batch_position_bias)
          position_bias = torch.cat(batch_decoder_position_bias, dim=0)

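The hunk above swaps direct advanced indexing for torch.select guarded by torch._check, so the data-dependent cache position survives torch.export tracing. A self-contained sketch of the same pattern, with illustrative shapes and names:

# Minimal sketch of the export guard used above: bound an unbacked index
# obtained via .item() so the compiler can prove the select stays in range.
import torch

class Gather(torch.nn.Module):
    def forward(self, table: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        idx = pos.item()  # becomes an unbacked SymInt under export
        torch._check_is_size(idx)  # declares idx is size-like (non-negative)
        torch._check(idx < table.shape[0])  # upper bound admits the select
        return torch.select(table, dim=0, index=idx)

ep = torch.export.export(Gather(), (torch.randn(16, 4), torch.tensor(3)))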
@@ -1,4 +1,4 @@
- from typing import Any, Dict, Optional
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig

@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
          enc_max_seq_len: Optional[int] = None,
          dec_max_seq_len: Optional[int] = None,
          num_parallel_samples: Optional[int] = None,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ):
          """
          Args:
@@ -25,7 +25,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
              enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
              dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
              num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
-             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
              ValueError: If batch_size is not a positive integer.
@@ -23,24 +23,20 @@

  import inspect
  import logging
- from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import (
-     PretrainedConfig,
-     TimeSeriesTransformerForPrediction,
-     TimeSeriesTransformerModel,
- )
- from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
+ from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
+ from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
  from transformers.modeling_utils import no_init_weights

  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.runtime_utils import RBLNPytorchRuntime
+ from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
  from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
  from .time_series_transformers_architecture import TimeSeriesTransformersWrapper

@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          )


- @dataclass
- class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
-     last_hidden_states: torch.FloatTensor = None
-     params: Tuple[torch.FloatTensor] = None
-
-
  class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
      """
      The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
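
The RBLNSeq2SeqTSDecoderOutput dataclass removed here is not dropped: the new import above pulls it from optimum/rbln/transformers/modeling_outputs.py (file 46 in the list, +37 lines). Reconstructed from the removed lines, it presumably reads:

# Presumed new home of the output dataclass (reconstructed, not copied from the new file).
from dataclasses import dataclass
from typing import Tuple

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
    last_hidden_states: torch.FloatTensor = None
    params: Tuple[torch.FloatTensor] = None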
@@ -162,7 +162,13 @@ class TimeSeriesTransformersDecoder(nn.Module):
          attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)

          hidden_states = self.value_embedding(inputs_embeds)
-         embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
+         embed_idx = cache_position + self.config.context_length
+         if torch.compiler.is_exporting():
+             embed_idx = embed_idx.item()
+             torch._check_is_size(embed_idx)
+             torch._check(embed_idx >= 0)
+             torch._check(embed_idx < len(self.embed_positions.weight))
+         embed_pos = self.embed_positions.weight[embed_idx]
          hidden_states = self.layernorm_embedding(hidden_states + embed_pos)

          # iterate decoder_layer
@@ -38,6 +38,7 @@ class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
      library implements for all its model.

      It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
+
      - transferring the checkpoint weights of the original into an optimized RBLN graph,
      - compiling the resulting graph using the RBLN compiler.
      """
@@ -12,9 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict
-
- import rebel
+ from typing import Any

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
@@ -38,17 +36,22 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
          use_attention_mask: bool = None,
          enc_max_seq_len: int = None,
          dec_max_seq_len: int = None,
-         **kwargs: Dict[str, Any],
+         kvcache_num_blocks: int = None,
+         kvcache_block_size: int = None,
+         **kwargs: Any,
      ):
          """
          Args:
              batch_size (int, optional): The batch size for inference. Defaults to 1.
              token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
              use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
-                 set to True for RBLN-CA02 devices.
              enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
              dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
-             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+             kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
+                 PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+             kvcache_block_size (int, optional): Sets the size (in number of tokens) of each block
+                 in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
              ValueError: If batch_size is not a positive integer.
@@ -64,10 +67,6 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
          self.dec_max_seq_len = dec_max_seq_len

          self.use_attention_mask = use_attention_mask
-         npu = self.npu or rebel.get_npu_name()
-         if npu == "RBLN-CA02":
-             if self.use_attention_mask is False:
-                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-             self.use_attention_mask = True
-         else:
-             self.use_attention_mask = self.use_attention_mask or False
+         self.use_attention_mask = self.use_attention_mask or False
+         self.kvcache_num_blocks = kvcache_num_blocks
+         self.kvcache_block_size = kvcache_block_size
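
The new keywords make the decoder's PagedAttention KV cache explicit, and checks added later in modeling_whisper.py pin them to the only supported layout. A hedged construction sketch with illustrative values:

# Illustrative: until flash attention is supported, kvcache_num_blocks must equal
# batch_size and kvcache_block_size must equal dec_max_seq_len (enforced by the
# _update_paged_attention_config checks shown further down).
from optimum.rbln import RBLNWhisperForConditionalGenerationConfig  # assumed top-level export

cfg = RBLNWhisperForConditionalGenerationConfig(
    batch_size=2,
    dec_max_seq_len=448,
    kvcache_num_blocks=2,  # defaults to batch_size when omitted
    kvcache_block_size=448,  # defaults to dec_max_seq_len when omitted
)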
@@ -39,14 +39,31 @@ from transformers.models.whisper.generation_whisper import WhisperGenerationMixi


  class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
-     """
-     This class is based on transformers version 4.44.2.
-     It uses the same generate() method, so it's crucial to maintain the inheritance order.
-     Ensure WhisperGenerationMixin is listed before GenerationMixin.
-     """
+     def generate(self, *args, generation_config=None, **kwargs):
+         num_beams = kwargs.get(
+             "num_beams",
+             generation_config.num_beams
+             if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
+             else 1,
+         )
+         if num_beams > 1:
+             raise ValueError(
+                 f"Beam search is not supported in RBLNWhisperGenerationMixin. "
+                 f"Received num_beams={num_beams}, but only num_beams=1 is allowed. "
+                 f"Please set num_beams=1 for greedy search or adjust your configuration."
+             )
+
+         return super().generate(*args, generation_config=generation_config, **kwargs)

      def _postprocess_outputs(
-         self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, *args, **kwargs
+         self,
+         seek_outputs,
+         decoder_input_ids,
+         return_token_timestamps,
+         generation_config,
+         is_shortform,
+         seek,
+         batch_idx_map,
      ):
          # remove all previously passed decoder input ids
          # should happen only if it is the first generated segment
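
In practice any Whisper generate() call through optimum-rbln must stick to greedy search; a quick sketch, where model and input_features are assumed to exist:

# Greedy decoding passes the new guard; any num_beams > 1 raises ValueError.
generated_ids = model.generate(input_features, num_beams=1, max_new_tokens=128)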
@@ -64,6 +81,11 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):

          if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
              num_frames = getattr(generation_config, "num_frames", None)
+
+             if num_frames is not None:
+                 num_frames = num_frames - seek
+                 num_frames = num_frames[batch_idx_map]
+
              if version.parse(transformers.__version__) >= version.parse("4.46.0"):
                  seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                      seek_outputs,
@@ -46,7 +46,7 @@ if TYPE_CHECKING:
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]

-     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+     def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
          output = super().forward(*args, **kwargs)
          return BaseModelOutput(last_hidden_state=output)

@@ -73,6 +73,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          decoder_input_ids: torch.Tensor = None,
          decoder_attention_mask: torch.Tensor = None,
          cache_position: torch.Tensor = None,
+         block_tables: torch.Tensor = None,
      ):
          inputs_bsz = decoder_input_ids.shape[0]
          padded_bsz = self.batch_size - inputs_bsz
@@ -89,11 +90,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
              )
              decoder_attention_mask[b_idx, : decoding_step + 1] = 1

+         if block_tables is None:
+             block_tables = self.default_block_tables
+
          outputs = super().forward(
              decoder_input_ids,
              decoder_attention_mask if self.use_attention_mask else None,
              cache_position,
-             block_tables=self.default_block_tables,
+             block_tables=block_tables,
          )

          if isinstance(outputs, torch.Tensor):
@@ -108,6 +112,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)

      This model inherits from [`RBLNModel`]. It implements the methods to convert and run
      pre-trained transformers based Whisper model on RBLN devices by:
+
      - transferring the checkpoint weights of the original into an optimized RBLN graph,
      - compiling the resulting graph using the RBLN compiler.

@@ -145,7 +150,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
      """

      auto_model_class = AutoModelForSpeechSeq2Seq
-     main_input_name = "input_ids"
+     main_input_name = "input_features"
+     _is_stateful = False

      def __post_init__(self, **kwargs):
          super().__post_init__(**kwargs)
@@ -249,6 +255,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)

          return {"encoder": compiled_encoder, "decoder": compiled_decoder}

+     @classmethod
+     def _update_paged_attention_config(
+         cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
+     ):
+         rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+         rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+         if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+             raise NotImplementedError(
+                 f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+             )
+
+         if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+             raise NotImplementedError(
+                 f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+             )
+
      @classmethod
      def _update_rbln_config(
          cls,
@@ -266,6 +289,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
          if rbln_config.dec_max_seq_len is None:
              rbln_config.dec_max_seq_len = model_config.max_length

+         cls._update_paged_attention_config(model_config, rbln_config)
+
          enc_input_info = [
              ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
              ("block_tables", [1], "int16"),
@@ -12,14 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_xlm_roberta import (
-     RBLNXLMRobertaForSequenceClassificationConfig,
-     RBLNXLMRobertaModelConfig,
- )
- from .modeling_xlm_roberta import (
-     RBLNXLMRobertaForSequenceClassification,
-     RBLNXLMRobertaModel,
- )
+ from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
+ from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel


  __all__ = [