optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,354 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import types
16
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from transformers import SwinConfig
21
+ from transformers.models.swin.modeling_swin import BackboneOutput
22
+
23
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
24
+ from ....modeling import RBLNModel
25
+ from ....utils.logging import get_logger
26
+ from .configuration_swin import RBLNSwinBackboneConfig
27
+
28
+
29
+ logger = get_logger(__name__)
30
+
31
+ if TYPE_CHECKING:
32
+ from transformers import (
33
+ AutoFeatureExtractor,
34
+ AutoProcessor,
35
+ AutoTokenizer,
36
+ PreTrainedModel,
37
+ SwinBackbone,
38
+ SwinEncoder,
39
+ )
40
+
41
+
42
def window_partition(input_feature, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    Args:
        input_feature: Tensor unpacked as ``(batch, height, width, channels)``. ``height``
            and ``width`` must divide evenly by ``window_size`` for the ``view`` to succeed.
        window_size: Side length of each square window.

    Returns:
        Tensor of shape ``(batch * num_windows, window_size, window_size, channels)``.
        Windows are ordered row-major within each batch element.
    """
    batch, height, width, channels = input_feature.shape
    rows = height // window_size
    cols = width // window_size
    grid = input_feature.view(batch, rows, window_size, cols, window_size, channels)
    # Move the two window-grid axes next to each other so they can be flattened together.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)
52
+
53
+
54
def get_attn_mask(self, height, width, dtype, device):
    """
    Build the shifted-window (SW-MSA) attention mask for a Swin block.

    This function is rebound onto Swin blocks with ``types.MethodType``, so ``self`` is the
    block. Only ``self.shift_size`` and ``self.window_size`` are read.

    Returns:
        ``None`` when ``self.shift_size`` is not positive. Otherwise a tensor of shape
        ``(num_windows, window_size**2, window_size**2)``, holding ``0.0`` where two
        positions in a window carry the same region label and ``-100.0`` where they differ.
    """
    if not self.shift_size > 0:
        return None

    win = self.window_size
    shift = self.shift_size
    # The same three bands are used along both spatial axes, giving 3x3 labelled regions.
    bands = (
        slice(0, -win),
        slice(-win, -shift),
        slice(-shift, None),
    )

    region_ids = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
    label = torch.zeros(1)
    for rows in bands:
        for cols in bands:
            region_ids[:, rows, cols, :] = label
            label += 1

    flat_windows = window_partition(region_ids, win).view(-1, win * win)
    # A non-zero label difference marks a pair of positions from different regions.
    label_diff = flat_windows.unsqueeze(1) - flat_windows.unsqueeze(2)
    return label_diff.masked_fill(label_diff != 0, float(-100.0)).masked_fill(label_diff == 0, float(0.0))
81
+
82
+
83
class _SwinEncoder(torch.nn.Module):
    """
    Trace-friendly stand-in for a Hugging Face ``SwinEncoder`` forward pass.

    Keeps a reference to the wrapped encoder's ``layers``. ``forward`` always returns a plain
    tuple of the non-``None`` items among: final hidden states, flat hidden states,
    attentions, and NCHW-reshaped hidden states. ``return_dict`` is accepted only for
    signature compatibility and is ignored.
    """

    def __init__(self, model: "SwinEncoder"):
        super().__init__()
        self.layers = model.layers

    @staticmethod
    def _tokens_to_nchw(tokens: torch.Tensor, *spatial: int) -> torch.Tensor:
        # rearrange b (h w) c -> b c h w
        batch, _, channels = tokens.shape
        return tokens.view(batch, *spatial, channels).permute(0, 3, 1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):
        flat_states = [] if output_hidden_states else None
        nchw_states = [] if output_hidden_states else None
        attentions = [] if output_attentions else None

        if output_hidden_states:
            flat_states.append(hidden_states)
            nchw_states.append(self._tokens_to_nchw(hidden_states, *input_dimensions))

        for stage_idx, stage in enumerate(self.layers):
            stage_head_mask = None if head_mask is None else head_mask[stage_idx]
            stage_outputs = stage(hidden_states, input_dimensions, stage_head_mask, output_attentions, always_partition)

            hidden_states = stage_outputs[0]
            pre_downsample = stage_outputs[1]
            output_dimensions = stage_outputs[2]
            # The next stage's input size comes from the last two entries of output_dimensions.
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states:
                if output_hidden_states_before_downsampling:
                    # Use the original (not downsampled) height and width for this state.
                    flat_states.append(pre_downsample)
                    nchw_states.append(
                        self._tokens_to_nchw(pre_downsample, output_dimensions[0], output_dimensions[1])
                    )
                else:
                    flat_states.append(hidden_states)
                    nchw_states.append(self._tokens_to_nchw(hidden_states, *input_dimensions))

            if output_attentions:
                attentions.extend(stage_outputs[3:])

        candidates = (
            hidden_states,
            None if flat_states is None else tuple(flat_states),
            None if attentions is None else tuple(attentions),
            None if nchw_states is None else tuple(nchw_states),
        )
        return tuple(item for item in candidates if item is not None)
150
+
151
+
152
class _SwinBackbone(torch.nn.Module):
    """
    Compilation wrapper around a Hugging Face ``SwinBackbone``.

    ``forward`` takes only ``pixel_values``. It returns a tuple whose first item is the tuple
    of normalized feature maps for the stages listed in ``out_features``. The encoder's flat
    hidden states and its attentions follow, in that order, when the matching flag was set
    at construction time.
    """

    def __init__(self, model: "SwinBackbone", output_hidden_states: bool, output_attentions: bool):
        super().__init__()
        self.model = model
        self.embeddings = model.embeddings
        self.encoder = model.encoder
        self.stage_names = model.stage_names
        self.out_features = model.out_features
        self.hidden_states_norms = model.hidden_states_norms
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions

    def _normalize_stage(self, stage: str, feature_map: torch.Tensor) -> torch.Tensor:
        # Apply the stage's norm module on token-major features: NCHW -> (N, H*W, C) -> NCHW.
        batch, channels, height, width = feature_map.shape
        tokens = feature_map.permute(0, 2, 3, 1).contiguous().view(batch, height * width, channels)
        tokens = self.hidden_states_norms[stage](tokens)
        return tokens.view(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        embedding_output, input_dimensions = self.embeddings(pixel_values)
        encoder_outputs = _SwinEncoder(self.encoder)(
            embedding_output,
            input_dimensions,
            head_mask=None,
            output_attentions=self.output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            always_partition=True,
            return_dict=False,
        )

        # Last element: NCHW hidden states (embedding output plus one per encoder stage).
        reshaped_states = encoder_outputs[-1]
        feature_maps = tuple(
            self._normalize_stage(stage, state)
            for stage, state in zip(self.stage_names, reshaped_states)
            if stage in self.out_features
        )

        result = (feature_maps,)
        if self.output_hidden_states:
            result = result + (encoder_outputs[1],)
        if self.output_attentions:
            result = result + (encoder_outputs[2],)
        return result
202
+
203
+
204
class RBLNSwinBackbone(RBLNModel):
    """
    RBLN-compiled Swin backbone.

    The Hugging Face ``SwinBackbone`` is wrapped by ``_SwinBackbone`` and compiled for one
    static input ``pixel_values`` of shape ``[batch_size, 3, image_height, image_width]``
    (float32), taken from ``rbln_config``. Whether hidden states and attentions are produced
    is fixed at compile time by ``rbln_config.output_hidden_states`` and
    ``rbln_config.output_attentions``.
    """

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
        """
        Return the trace-ready wrapper for ``model``.

        Every Swin block's ``get_attn_mask`` is rebound to this module's implementation.
        This patches ``model`` in place.
        """
        for layer in model.encoder.layers:
            for block in layer.blocks:
                block.get_attn_mask = types.MethodType(get_attn_mask, block)

        wrapper_cfg = {
            "output_hidden_states": rbln_config.output_hidden_states,
            "output_attentions": rbln_config.output_attentions,
        }
        return _SwinBackbone(model, **wrapper_cfg).eval()

    @classmethod
    def _update_submodule_config(
        cls,
        model: "PreTrainedModel",
        rbln_config: RBLNModelConfig,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
    ):
        """
        Fill ``rbln_config.image_size`` from the first preprocessor with an ``image_processor``.

        This only acts while ``image_size`` is unset. When ``height``/``width`` are present in
        the image processor's ``size``, a 2-tuple is stored. Otherwise ``longest_edge`` (when
        both edges are given) or ``shortest_edge`` is stored as a single value.
        ``preprocessors`` may be ``None`` or empty, in which case the config is left unchanged.
        """
        for processor in preprocessors or ():
            if rbln_config.image_size is not None or not hasattr(processor, "image_processor"):
                continue

            size = processor.image_processor.size
            if "height" in size and "width" in size:
                rbln_config.image_size = (size["height"], size["width"])
            elif "longest_edge" in size and "shortest_edge" in size:
                rbln_config.image_size = size["longest_edge"]
            elif "shortest_edge" in size:
                rbln_config.image_size = size["shortest_edge"]
            break

        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: Optional["PreTrainedModel"] = None,
        model_config: "SwinConfig" = None,
        rbln_config: Optional[RBLNSwinBackboneConfig] = None,
    ) -> RBLNSwinBackboneConfig:
        """
        Resolve the input resolution and register the single compile config.

        If ``rbln_config.image_size`` is unset, it is looked up on the first preprocessor that
        has a ``size`` attribute, using its ``height``/``width`` entries when both are present.
        ``preprocessors`` may be ``None``. The compiled input is ``pixel_values`` of shape
        ``[batch_size, 3, image_height, image_width]`` in float32.
        """
        if rbln_config.image_size is None:
            for processor in preprocessors or ():
                if hasattr(processor, "size"):
                    if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
                        rbln_config.image_size = (processor.size["height"], processor.size["width"])
                    break

        input_info = [
            (
                "pixel_values",
                [
                    rbln_config.batch_size,
                    3,
                    rbln_config.image_height,
                    rbln_config.image_width,
                ],
                "float32",
            ),
        ]

        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
        return rbln_config

    def _split_runtime_outputs(self, output):
        """
        Regroup the flat runtime outputs into ``(feature_maps, hidden_states, attentions)``.

        The compiled graph emits, in order: one tensor per ``config.out_features`` entry, then
        (if compiled with hidden states) one per ``config.stage_names`` entry, then (if compiled
        with attentions) one per ``config.depths`` entry. Disabled groups come back as ``None``.

        Raises:
            RuntimeError: If the runtime returned fewer tensors than the config implies.
        """
        # NOTE(review): the runtime may hand back a bare tensor when the graph has a single
        # output (one out_feature, no optional outputs); normalize to a list before slicing.
        flat = [output] if isinstance(output, torch.Tensor) else list(output)

        with_hidden = self.rbln_config.output_hidden_states
        with_attn = self.rbln_config.output_attentions
        num_features = len(self.config.out_features)
        num_hidden = len(self.config.stage_names) if with_hidden else 0
        num_attn = len(self.config.depths) if with_attn else 0

        expected = num_features + num_hidden + num_attn
        if len(flat) < expected:
            raise RuntimeError(
                f"Expected at least {expected} outputs from the compiled {self.__class__.__name__} runtime, "
                f"got {len(flat)}."
            )

        feature_maps = tuple(flat[:num_features])
        hidden_start = num_features
        hidden_states = tuple(flat[hidden_start : hidden_start + num_hidden]) if with_hidden else None
        attn_start = hidden_start + num_hidden
        attentions = tuple(flat[attn_start : attn_start + num_attn]) if with_attn else None
        return feature_maps, hidden_states, attentions

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        **kwargs,
    ) -> Union[Tuple, BackboneOutput]:
        """
        Forward pass for the RBLN-optimized Swin backbone model.

        Inputs smaller than the compiled resolution are zero-padded on the bottom and right
        edges before running. The returned feature maps correspond to the padded input and
        are not cropped back.

        Args:
            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.call() for details (processor_class uses ViTImageProcessor for processing images). Height and width must not exceed the compiled image size.
            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. Must match the value used at compile time.
            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. Must match the value used at compile time.

        Returns:
            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BackboneOutput object.

        Raises:
            ValueError: If ``pixel_values`` is missing, exceeds the compiled size, or the
                requested optional outputs differ from the compiled configuration.
        """

        if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
            logger.warning(
                f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
            )

        output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
        )

        # The set of outputs is baked into the compiled graph, so it cannot change per call.
        if output_attentions != self.rbln_config.output_attentions:
            raise ValueError(
                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
                "Please compile again with the correct argument."
            )

        if output_hidden_states != self.rbln_config.output_hidden_states:
            raise ValueError(
                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
                "Please compile again with the correct argument."
            )

        if pixel_values is None:
            raise ValueError(f"`pixel_values` must be provided to {self.__class__.__name__}.forward().")

        _, _, original_h, original_w = pixel_values.shape
        if original_h > self.rbln_config.image_height or original_w > self.rbln_config.image_width:
            raise ValueError(
                f"Input image size ({original_h}x{original_w}) exceeds the configured maximum size"
                f" ({self.rbln_config.image_height}x{self.rbln_config.image_width})."
            )

        # F.pad pads the last dim first: (left, right, top, bottom).
        pad_h = self.rbln_config.image_height - original_h
        pad_w = self.rbln_config.image_width - original_w
        padded_pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))

        output = self.model[0](padded_pixel_values)
        feature_maps, hidden_states, attentions = self._split_runtime_outputs(output)

        if not return_dict:
            return tuple(item for item in (feature_maps, hidden_states, attentions) if item is not None)
        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=hidden_states,
            attentions=attentions,
        )
@@ -32,3 +32,5 @@ class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
32
32
  This configuration class stores the configuration parameters specific to
33
33
  RBLN-optimized T5 models for conditional text generation tasks.
34
34
  """
35
+
36
+ support_paged_attention = False
@@ -68,7 +68,7 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
68
68
  output_class = BaseModelOutputWithPastAndCrossAttentions
69
69
 
70
70
  @classmethod
71
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
71
+ def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
72
72
  return T5EncoderWrapper(model)
73
73
 
74
74
  @classmethod
@@ -113,7 +113,7 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
113
113
  support_causal_attn = False
114
114
 
115
115
  @classmethod
116
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
116
+ def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
117
117
  return T5Wrapper(
118
118
  model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
119
119
  )
@@ -126,7 +126,14 @@ class T5Decoder(Seq2SeqDecoder):
126
126
  b_size = attention_mask.shape[0]
127
127
  batch_decoder_position_bias = []
128
128
  for i in range(b_size):
129
- batch_position_bias = self._dec_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
129
+ if torch.compiler.is_exporting():
130
+ cache_pos = cache_position[i][0].item()
131
+ torch._check_is_size(cache_pos)
132
+ torch._check(cache_pos >= 0)
133
+ torch._check(cache_pos < self._dec_position_bias.shape[2])
134
+ else:
135
+ cache_pos = cache_position[i][0]
136
+ batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
130
137
  batch_decoder_position_bias.append(batch_position_bias)
131
138
  position_bias = torch.cat(batch_decoder_position_bias, dim=0)
132
139
 
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from ....configuration_utils import RBLNModelConfig
4
4
 
@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
17
17
  enc_max_seq_len: Optional[int] = None,
18
18
  dec_max_seq_len: Optional[int] = None,
19
19
  num_parallel_samples: Optional[int] = None,
20
- **kwargs: Dict[str, Any],
20
+ **kwargs: Any,
21
21
  ):
22
22
  """
23
23
  Args:
@@ -25,7 +25,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
25
25
  enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
26
26
  dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
27
27
  num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
28
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
28
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
29
29
 
30
30
  Raises:
31
31
  ValueError: If batch_size is not a positive integer.
@@ -23,24 +23,20 @@
23
23
 
24
24
  import inspect
25
25
  import logging
26
- from dataclasses import dataclass
27
26
  from pathlib import Path
28
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
27
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
29
28
 
30
29
  import rebel
31
30
  import torch
32
31
  from rebel.compile_context import CompileContext
33
- from transformers import (
34
- PretrainedConfig,
35
- TimeSeriesTransformerForPrediction,
36
- TimeSeriesTransformerModel,
37
- )
38
- from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
32
+ from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
33
+ from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
39
34
  from transformers.modeling_utils import no_init_weights
40
35
 
41
36
  from ....configuration_utils import RBLNCompileConfig
42
37
  from ....modeling import RBLNModel
43
38
  from ....utils.runtime_utils import RBLNPytorchRuntime
39
+ from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
44
40
  from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
45
41
  from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
46
42
 
@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
113
109
  )
114
110
 
115
111
 
116
- @dataclass
117
- class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
118
- last_hidden_states: torch.FloatTensor = None
119
- params: Tuple[torch.FloatTensor] = None
120
-
121
-
122
112
  class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
123
113
  """
124
114
  The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
@@ -163,7 +153,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
163
153
  return redirect(val)
164
154
 
165
155
  @classmethod
166
- def wrap_model_if_needed(
156
+ def _wrap_model_if_needed(
167
157
  self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
168
158
  ):
169
159
  return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
@@ -171,7 +161,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
171
161
  @classmethod
172
162
  @torch.inference_mode()
173
163
  def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
174
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
164
+ wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
175
165
 
176
166
  enc_compile_config = rbln_config.compile_cfgs[0]
177
167
  dec_compile_config = rbln_config.compile_cfgs[1]
@@ -331,12 +321,14 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
331
321
  tensor_type="pt",
332
322
  device=rbln_config.device_map["encoder"],
333
323
  activate_profiler=rbln_config.activate_profiler,
324
+ timeout=rbln_config.timeout,
334
325
  ),
335
326
  rebel.Runtime(
336
327
  compiled_models[1],
337
328
  tensor_type="pt",
338
329
  device=rbln_config.device_map["decoder"],
339
330
  activate_profiler=rbln_config.activate_profiler,
331
+ timeout=rbln_config.timeout,
340
332
  ),
341
333
  ]
342
334
 
@@ -361,6 +353,20 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
361
353
  static_real_features: Optional[torch.Tensor] = None,
362
354
  **kwargs,
363
355
  ) -> SampleTSPredictionOutput:
356
+ """
357
+ Generate pass for the RBLN-optimized Time Series Transformer model for time series forecasting.
358
+
359
+ Args:
360
+ past_values (torch.FloatTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size)): Past values of the time series, that serve as context in order to predict the future.
361
+ past_time_features (torch.FloatTensor of shape (batch_size, sequence_length, num_features)): Required time features, which the model internally will add to past_values.
362
+ future_time_features (torch.FloatTensor of shape (batch_size, prediction_length, num_features)): Required time features for the prediction window, which the model internally will add to future_values.
363
+ past_observed_mask (torch.BoolTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size), optional): Boolean mask to indicate which past_values were observed and which were missing.
364
+ static_categorical_features (torch.LongTensor of shape (batch_size, number of static categorical features), optional): Optional static categorical features for which the model will learn an embedding, which it will add to the values of the time series.
365
+ static_real_features (torch.FloatTensor of shape (batch_size, number of static real features), optional): Optional static real features which the model will add to the values of the time series.
366
+
367
+ Returns:
368
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SampleTSPredictionOutput object.
369
+ """
364
370
  self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})
365
371
 
366
372
  outputs = self.encoder(
@@ -162,7 +162,13 @@ class TimeSeriesTransformersDecoder(nn.Module):
162
162
  attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
163
163
 
164
164
  hidden_states = self.value_embedding(inputs_embeds)
165
- embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
165
+ embed_idx = cache_position + self.config.context_length
166
+ if torch.compiler.is_exporting():
167
+ embed_idx = embed_idx.item()
168
+ torch._check_is_size(embed_idx)
169
+ torch._check(embed_idx >= 0)
170
+ torch._check(embed_idx < len(self.embed_positions.weight))
171
+ embed_pos = self.embed_positions.weight[embed_idx]
166
172
  hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
167
173
 
168
174
  # iterate decoder_layer
@@ -12,6 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from typing import Tuple, Union
16
+
17
+ import torch
18
+ from transformers.modeling_outputs import ImageClassifierOutput
19
+
15
20
  from ...modeling_generic import RBLNModelForImageClassification
16
21
 
17
22
 
@@ -23,3 +28,17 @@ class RBLNViTForImageClassification(RBLNModelForImageClassification):
23
28
  on RBLN devices, supporting image classification with transformer-based architectures
24
29
  that process images as sequences of patches.
25
30
  """
31
+
32
+ def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[ImageClassifierOutput, Tuple]:
33
+ """
34
+ Forward pass for the RBLN-optimized Vision Transformer model for image classification.
35
+
36
+ Args:
37
+ pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)):
38
+ The tensors corresponding to the input images.
39
+
40
+ Returns:
41
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutput object.
42
+
43
+ """
44
+ return super().forward(pixel_values, **kwargs)
@@ -12,10 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ...configuration_generic import RBLNModelForMaskedLMConfig
15
+ from typing import Any, Optional
16
16
 
17
+ from ....configuration_utils import RBLNModelConfig
17
18
 
18
- class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
19
+
20
+ class RBLNWav2Vec2ForCTCConfig(RBLNModelConfig):
19
21
  """
20
22
  Configuration class for RBLNWav2Vec2ForCTC.
21
23
 
@@ -23,4 +25,14 @@ class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
23
25
  RBLN-optimized Wav2Vec2 models for Connectionist Temporal Classification (CTC) tasks.
24
26
  """
25
27
 
26
- rbln_model_input_names = ["input_values"]
28
+ def __init__(
29
+ self,
30
+ max_seq_len: Optional[int] = None,
31
+ batch_size: Optional[int] = None,
32
+ **kwargs: Any,
33
+ ):
34
+ super().__init__(**kwargs)
35
+ self.max_seq_len = max_seq_len
36
+ self.batch_size = batch_size or 1
37
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
38
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")