optimum-rbln 0.8.2a7 (py3-none-any.whl) → 0.8.3 (py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (105)
  1. optimum/rbln/__init__.py +36 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +28 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +35 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +64 -258
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
38
38
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
39
39
  mandatory_members = ["main_input_name"]
40
40
 
41
- def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
41
+ def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
42
42
  output = super().forward(*args, **kwargs)
43
43
  return BaseModelOutput(last_hidden_state=output)
44
44
 
@@ -181,6 +181,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
181
181
 
182
182
  return {"encoder": compiled_encoder, "decoder": compiled_decoder}
183
183
 
184
+ @classmethod
185
+ def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
186
+ rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
187
+ rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
188
+
189
+ if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
190
+ raise NotImplementedError(
191
+ f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
192
+ )
193
+
194
+ if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
195
+ raise NotImplementedError(
196
+ f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
197
+ )
198
+
184
199
  @classmethod
185
200
  def _update_rbln_config(
186
201
  cls,
@@ -238,6 +253,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
238
253
  if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
239
254
  raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
240
255
 
256
+ if rbln_config.support_paged_attention:
257
+ cls._update_paged_attention_config(model_config, rbln_config)
258
+
241
259
  # model input info
242
260
  enc_input_info = [
243
261
  ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
@@ -310,6 +328,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
310
328
  dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
311
329
 
312
330
  rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
331
+
313
332
  return rbln_config
314
333
 
315
334
  @classmethod
@@ -12,9 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_siglip import (
16
- RBLNSiglipVisionModelConfig,
17
- )
18
- from .modeling_siglip import (
19
- RBLNSiglipVisionModel,
20
- )
15
+ from .configuration_siglip import RBLNSiglipVisionModelConfig
16
+ from .modeling_siglip import RBLNSiglipVisionModel
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
15
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
16
16
 
17
17
  import torch
18
18
  from transformers import SiglipVisionConfig, SiglipVisionModel
@@ -126,7 +126,7 @@ class RBLNSiglipVisionModel(RBLNModel):
126
126
  output_attentions: bool = None,
127
127
  output_hidden_states: bool = None,
128
128
  interpolate_pos_encoding: bool = False,
129
- **kwargs: Dict[str, Any],
129
+ **kwargs: Any,
130
130
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
131
131
  if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
132
132
  logger.warning(
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_swin import RBLNSwinBackboneConfig
16
+ from .modeling_swin import RBLNSwinBackbone
@@ -0,0 +1,42 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at:
4
+
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from typing import Any, Optional, Tuple, Union
14
+
15
+ from ...configuration_generic import RBLNModelForImageClassificationConfig
16
+
17
+
18
+ class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
19
+ def __init__(
20
+ self,
21
+ image_size: Optional[Union[int, Tuple[int, int]]] = None,
22
+ batch_size: Optional[int] = None,
23
+ output_hidden_states: Optional[bool] = None,
24
+ output_attentions: Optional[bool] = None,
25
+ **kwargs: Any,
26
+ ):
27
+ """
28
+ Args:
29
+ batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
30
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
31
+
32
+ Raises:
33
+ ValueError: If batch_size is not a positive integer.
34
+ """
35
+ super().__init__(**kwargs)
36
+ self.batch_size = batch_size or 1
37
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
38
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
39
+
40
+ self.image_size = image_size
41
+ self.output_hidden_states = output_hidden_states
42
+ self.output_attentions = output_attentions
@@ -0,0 +1,341 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import types
16
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from transformers import SwinConfig
21
+ from transformers.models.swin.modeling_swin import BackboneOutput
22
+
23
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
24
+ from ....modeling import RBLNModel
25
+ from ....utils.logging import get_logger
26
+ from .configuration_swin import RBLNSwinBackboneConfig
27
+
28
+
29
+ logger = get_logger(__name__)
30
+
31
+ if TYPE_CHECKING:
32
+ from transformers import (
33
+ AutoFeatureExtractor,
34
+ AutoProcessor,
35
+ AutoTokenizer,
36
+ PreTrainedModel,
37
+ SwinBackbone,
38
+ SwinEncoder,
39
+ )
40
+
41
+
42
+ def window_partition(input_feature, window_size):
43
+ """
44
+ Partitions the given input into windows.
45
+ """
46
+ batch_size, height, width, num_channels = input_feature.shape
47
+ input_feature = input_feature.view(
48
+ batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
49
+ )
50
+ windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
51
+ return windows
52
+
53
+
54
+ def get_attn_mask(self, height, width, dtype, device):
55
+ if self.shift_size > 0:
56
+ # calculate attention mask for SW-MSA
57
+ img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
58
+ height_slices = (
59
+ slice(0, -self.window_size),
60
+ slice(-self.window_size, -self.shift_size),
61
+ slice(-self.shift_size, None),
62
+ )
63
+ width_slices = (
64
+ slice(0, -self.window_size),
65
+ slice(-self.window_size, -self.shift_size),
66
+ slice(-self.shift_size, None),
67
+ )
68
+ count = torch.zeros(1)
69
+ for height_slice in height_slices:
70
+ for width_slice in width_slices:
71
+ img_mask[:, height_slice, width_slice, :] = count
72
+ count += 1
73
+
74
+ mask_windows = window_partition(img_mask, self.window_size)
75
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
76
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
77
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
78
+ else:
79
+ attn_mask = None
80
+ return attn_mask
81
+
82
+
83
+ class _SwinEncoder(torch.nn.Module):
84
+ def __init__(self, model: "SwinEncoder"):
85
+ super().__init__()
86
+ self.layers = model.layers
87
+
88
+ def forward(
89
+ self,
90
+ hidden_states: torch.Tensor,
91
+ input_dimensions: Tuple[int, int],
92
+ head_mask: Optional[torch.FloatTensor] = None,
93
+ output_attentions: Optional[bool] = False,
94
+ output_hidden_states: Optional[bool] = False,
95
+ output_hidden_states_before_downsampling: Optional[bool] = False,
96
+ always_partition: Optional[bool] = False,
97
+ return_dict: Optional[bool] = True,
98
+ ):
99
+ all_hidden_states = () if output_hidden_states else None
100
+ all_reshaped_hidden_states = () if output_hidden_states else None
101
+ all_self_attentions = () if output_attentions else None
102
+
103
+ if output_hidden_states:
104
+ batch_size, _, hidden_size = hidden_states.shape
105
+ # rearrange b (h w) c -> b c h w
106
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
107
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
108
+ all_hidden_states += (hidden_states,)
109
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
110
+
111
+ for i, layer_module in enumerate(self.layers):
112
+ layer_head_mask = head_mask[i] if head_mask is not None else None
113
+
114
+ layer_outputs = layer_module(
115
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
116
+ )
117
+
118
+ hidden_states = layer_outputs[0]
119
+ hidden_states_before_downsampling = layer_outputs[1]
120
+ output_dimensions = layer_outputs[2]
121
+
122
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
123
+
124
+ if output_hidden_states and output_hidden_states_before_downsampling:
125
+ batch_size, _, hidden_size = hidden_states_before_downsampling.shape
126
+ # rearrange b (h w) c -> b c h w
127
+ # here we use the original (not downsampled) height and width
128
+ reshaped_hidden_state = hidden_states_before_downsampling.view(
129
+ batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
130
+ )
131
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
132
+ all_hidden_states += (hidden_states_before_downsampling,)
133
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
134
+ elif output_hidden_states and not output_hidden_states_before_downsampling:
135
+ batch_size, _, hidden_size = hidden_states.shape
136
+ # rearrange b (h w) c -> b c h w
137
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
138
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
139
+ all_hidden_states += (hidden_states,)
140
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
141
+
142
+ if output_attentions:
143
+ all_self_attentions += layer_outputs[3:]
144
+
145
+ return tuple(
146
+ v
147
+ for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
148
+ if v is not None
149
+ )
150
+
151
+
152
+ class _SwinBackbone(torch.nn.Module):
153
+ def __init__(self, model: "SwinBackbone", output_hidden_states: bool, output_attentions: bool):
154
+ super().__init__()
155
+ self.model = model
156
+ self.embeddings = model.embeddings
157
+ self.encoder = model.encoder
158
+ self.stage_names = model.stage_names
159
+ self.out_features = model.out_features
160
+ self.hidden_states_norms = model.hidden_states_norms
161
+ self.output_hidden_states = output_hidden_states
162
+ self.output_attentions = output_attentions
163
+
164
+ def forward(
165
+ self,
166
+ pixel_values: torch.Tensor,
167
+ ):
168
+ embedding_output, input_dimensions = self.embeddings(pixel_values)
169
+ outputs = _SwinEncoder(self.encoder)(
170
+ embedding_output,
171
+ input_dimensions,
172
+ head_mask=None,
173
+ output_attentions=self.output_attentions,
174
+ output_hidden_states=True,
175
+ output_hidden_states_before_downsampling=True,
176
+ always_partition=True,
177
+ return_dict=False,
178
+ )
179
+
180
+ hidden_states = outputs[-1]
181
+
182
+ feature_maps = ()
183
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
184
+ if stage in self.out_features:
185
+ batch_size, num_channels, height, width = hidden_state.shape
186
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
187
+ hidden_state = hidden_state.view(batch_size, height * width, num_channels)
188
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
189
+ hidden_state = hidden_state.view(batch_size, height, width, num_channels)
190
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
191
+ feature_maps += (hidden_state,)
192
+
193
+ output = (feature_maps,)
194
+
195
+ if self.output_hidden_states:
196
+ output += (outputs[1],)
197
+
198
+ if self.output_attentions:
199
+ output += (outputs[2],)
200
+
201
+ return output
202
+
203
+
204
+ class RBLNSwinBackbone(RBLNModel):
205
+ @classmethod
206
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
207
+ for layer in model.encoder.layers:
208
+ for block in layer.blocks:
209
+ block.get_attn_mask = types.MethodType(get_attn_mask, block)
210
+
211
+ wrapper_cfg = {
212
+ "output_hidden_states": rbln_config.output_hidden_states,
213
+ "output_attentions": rbln_config.output_attentions,
214
+ }
215
+ return _SwinBackbone(model, **wrapper_cfg).eval()
216
+
217
+ @classmethod
218
+ def _update_submodule_config(
219
+ cls,
220
+ model: "PreTrainedModel",
221
+ rbln_config: RBLNModelConfig,
222
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
223
+ ):
224
+ for processor in preprocessors:
225
+ if rbln_config.image_size is None and hasattr(processor, "image_processor"):
226
+ if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
227
+ rbln_config.image_size = (
228
+ processor.image_processor.size["height"],
229
+ processor.image_processor.size["width"],
230
+ )
231
+ elif (
232
+ "longest_edge" in processor.image_processor.size
233
+ and "shortest_edge" in processor.image_processor.size
234
+ ):
235
+ rbln_config.image_size = processor.image_processor.size["longest_edge"]
236
+ elif "shortest_edge" in processor.image_processor.size:
237
+ rbln_config.image_size = processor.image_processor.size["shortest_edge"]
238
+ break
239
+
240
+ return rbln_config
241
+
242
+ @classmethod
243
+ def _update_rbln_config(
244
+ cls,
245
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
246
+ model: Optional["PreTrainedModel"] = None,
247
+ model_config: "SwinConfig" = None,
248
+ rbln_config: Optional[RBLNSwinBackboneConfig] = None,
249
+ ) -> RBLNSwinBackboneConfig:
250
+ if rbln_config.image_size is None:
251
+ for processor in preprocessors:
252
+ if hasattr(processor, "size"):
253
+ if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
254
+ rbln_config.image_size = (processor.size["height"], processor.size["width"])
255
+ break
256
+
257
+ input_info = [
258
+ (
259
+ "pixel_values",
260
+ [
261
+ rbln_config.batch_size,
262
+ 3,
263
+ rbln_config.image_height,
264
+ rbln_config.image_width,
265
+ ],
266
+ "float32",
267
+ ),
268
+ ]
269
+
270
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
271
+ return rbln_config
272
+
273
+ def forward(
274
+ self,
275
+ pixel_values: Optional[torch.FloatTensor] = None,
276
+ return_dict: bool = True,
277
+ output_attentions: bool = None,
278
+ output_hidden_states: bool = None,
279
+ **kwargs,
280
+ ) -> Union[Tuple, BackboneOutput]:
281
+ if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
282
+ logger.warning(
283
+ f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
284
+ )
285
+
286
+ output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
287
+ output_hidden_states = (
288
+ output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
289
+ )
290
+
291
+ if output_attentions != self.rbln_config.output_attentions:
292
+ raise ValueError(
293
+ f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
294
+ f"Please compile again with the correct argument."
295
+ )
296
+
297
+ if output_hidden_states != self.rbln_config.output_hidden_states:
298
+ raise ValueError(
299
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
300
+ f"Please compile again with the correct argument."
301
+ )
302
+
303
+ _, _, original_h, original_w = pixel_values.shape
304
+ if original_h > self.rbln_config.image_height or original_w > self.rbln_config.image_width:
305
+ raise ValueError(
306
+ f"Input image size ({original_h}x{original_w}) exceeds the configured maximum size"
307
+ f" ({self.rbln_config.image_height}x{self.rbln_config.image_width})."
308
+ )
309
+
310
+ pad_h = self.rbln_config.image_height - original_h
311
+ pad_w = self.rbln_config.image_width - original_w
312
+ padded_pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
313
+
314
+ output = self.model[0](padded_pixel_values)
315
+
316
+ feature_maps = ()
317
+ for i in range(len(self.config.out_features)):
318
+ feature_maps += (output.pop(0),)
319
+
320
+ if self.rbln_config.output_hidden_states:
321
+ hidden_states = ()
322
+ for i in range(len(self.config.stage_names)):
323
+ hidden_states += (output.pop(0),)
324
+ else:
325
+ hidden_states = None
326
+
327
+ if self.rbln_config.output_attentions:
328
+ attentions = ()
329
+ for i in range(len(self.config.depths)):
330
+ attentions += (output.pop(0),)
331
+ else:
332
+ attentions = None
333
+
334
+ if not return_dict:
335
+ return tuple(item for item in (feature_maps, hidden_states, attentions) if item is not None)
336
+ else:
337
+ return BackboneOutput(
338
+ feature_maps=feature_maps,
339
+ hidden_states=hidden_states,
340
+ attentions=attentions,
341
+ )
@@ -32,3 +32,5 @@ class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
32
32
  This configuration class stores the configuration parameters specific to
33
33
  RBLN-optimized T5 models for conditional text generation tasks.
34
34
  """
35
+
36
+ support_paged_attention = False
@@ -126,7 +126,14 @@ class T5Decoder(Seq2SeqDecoder):
126
126
  b_size = attention_mask.shape[0]
127
127
  batch_decoder_position_bias = []
128
128
  for i in range(b_size):
129
- batch_position_bias = self._dec_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
129
+ if torch.compiler.is_exporting():
130
+ cache_pos = cache_position[i][0].item()
131
+ torch._check_is_size(cache_pos)
132
+ torch._check(cache_pos >= 0)
133
+ torch._check(cache_pos < self._dec_position_bias.shape[2])
134
+ else:
135
+ cache_pos = cache_position[i][0]
136
+ batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
130
137
  batch_decoder_position_bias.append(batch_position_bias)
131
138
  position_bias = torch.cat(batch_decoder_position_bias, dim=0)
132
139
 
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from ....configuration_utils import RBLNModelConfig
4
4
 
@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
17
17
  enc_max_seq_len: Optional[int] = None,
18
18
  dec_max_seq_len: Optional[int] = None,
19
19
  num_parallel_samples: Optional[int] = None,
20
- **kwargs: Dict[str, Any],
20
+ **kwargs: Any,
21
21
  ):
22
22
  """
23
23
  Args:
@@ -23,24 +23,20 @@
23
23
 
24
24
  import inspect
25
25
  import logging
26
- from dataclasses import dataclass
27
26
  from pathlib import Path
28
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
27
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
29
28
 
30
29
  import rebel
31
30
  import torch
32
31
  from rebel.compile_context import CompileContext
33
- from transformers import (
34
- PretrainedConfig,
35
- TimeSeriesTransformerForPrediction,
36
- TimeSeriesTransformerModel,
37
- )
38
- from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
32
+ from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
33
+ from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
39
34
  from transformers.modeling_utils import no_init_weights
40
35
 
41
36
  from ....configuration_utils import RBLNCompileConfig
42
37
  from ....modeling import RBLNModel
43
38
  from ....utils.runtime_utils import RBLNPytorchRuntime
39
+ from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
44
40
  from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
45
41
  from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
46
42
 
@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
113
109
  )
114
110
 
115
111
 
116
- @dataclass
117
- class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
118
- last_hidden_states: torch.FloatTensor = None
119
- params: Tuple[torch.FloatTensor] = None
120
-
121
-
122
112
  class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
123
113
  """
124
114
  The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict
15
+ from typing import Any
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
  from ....utils.logging import get_logger
@@ -36,7 +36,9 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
36
36
  use_attention_mask: bool = None,
37
37
  enc_max_seq_len: int = None,
38
38
  dec_max_seq_len: int = None,
39
- **kwargs: Dict[str, Any],
39
+ kvcache_num_blocks: int = None,
40
+ kvcache_block_size: int = None,
41
+ **kwargs: Any,
40
42
  ):
41
43
  """
42
44
  Args:
@@ -45,6 +47,10 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
45
47
  use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
46
48
  enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
47
49
  dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
50
+ kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
51
+ PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
52
+ kvcache_block_size (int, optional): Sets the size (in number of tokens) of each block
53
+ in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
48
54
  **kwargs: Additional arguments passed to the parent RBLNModelConfig.
49
55
 
50
56
  Raises:
@@ -62,3 +68,5 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
62
68
 
63
69
  self.use_attention_mask = use_attention_mask
64
70
  self.use_attention_mask = self.use_attention_mask or False
71
+ self.kvcache_num_blocks = kvcache_num_blocks
72
+ self.kvcache_block_size = kvcache_block_size
@@ -46,7 +46,7 @@ if TYPE_CHECKING:
46
46
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
47
47
  mandatory_members = ["main_input_name"]
48
48
 
49
- def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
49
+ def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
50
50
  output = super().forward(*args, **kwargs)
51
51
  return BaseModelOutput(last_hidden_state=output)
52
52
 
@@ -253,6 +253,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
253
253
 
254
254
  return {"encoder": compiled_encoder, "decoder": compiled_decoder}
255
255
 
256
+ @classmethod
257
+ def _update_paged_attention_config(
258
+ cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
259
+ ):
260
+ rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
261
+ rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
262
+
263
+ if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
264
+ raise NotImplementedError(
265
+ f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
266
+ )
267
+
268
+ if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
269
+ raise NotImplementedError(
270
+ f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
271
+ )
272
+
256
273
  @classmethod
257
274
  def _update_rbln_config(
258
275
  cls,
@@ -270,6 +287,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
270
287
  if rbln_config.dec_max_seq_len is None:
271
288
  rbln_config.dec_max_seq_len = model_config.max_length
272
289
 
290
+ cls._update_paged_attention_config(model_config, rbln_config)
291
+
273
292
  enc_input_info = [
274
293
  ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
275
294
  ("block_tables", [1], "int16"),
@@ -12,14 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_xlm_roberta import (
16
- RBLNXLMRobertaForSequenceClassificationConfig,
17
- RBLNXLMRobertaModelConfig,
18
- )
19
- from .modeling_xlm_roberta import (
20
- RBLNXLMRobertaForSequenceClassification,
21
- RBLNXLMRobertaModel,
22
- )
15
+ from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
16
+ from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel
23
17
 
24
18
 
25
19
  __all__ = [