optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py (new file)
@@ -0,0 +1,331 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimeSeriesTransformersWrapper:
+    def __init__(self, model, num_parallel_samples):
+        self.encoder = TimeSeriesTransformersEncoderWrapper(model)
+        self.decoder = TimeSeriesTransformersDecoderWrapper(model, num_parallel_samples)
+
+
+class TimeSeriesTransformersEncoderWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.config = model.config
+        self.encoder = model.get_encoder()
+        self.num_heads = self.config.decoder_attention_heads
+        self.d_kv = self.config.d_model // self.num_heads
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+
+    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+        return (
+            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
+            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        cross_key_values: torch.Tensor,  # n_layers, batch_size, num_heads, context_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        # 1. get encoder last_hidden_states
+        encoder_outputs = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, return_dict=False)
+        last_hidden_states = encoder_outputs[0]
+
+        # 2. pre-compute cross_attention's past_key_value which used in decoder phase.
+        cross_kv = []
+        batch_size = inputs_embeds.shape[0]
+        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+
+            cross_kv.append(past_k)
+            cross_kv.append(past_v)
+
+        cross_kv = torch.stack(cross_kv, dim=0)
+
+        # 3. update cross_attention's past_key_value to the device-dram for optimization.
+        bidx = torch.tensor(0, dtype=torch.int16)
+        axis = torch.tensor(1, dtype=torch.int16)
+        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+
+        return enc_output
+
+
+class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
+    def __init__(self, model, num_parallel_samples):
+        super().__init__()
+        self.config = model.config
+        self.num_layers = self.config.decoder_layers
+        self.decoder = self.convert_to_rbln_tst_decoder(model, num_parallel_samples)
+        self.parameter_projection = model.parameter_projection
+
+    def convert_to_rbln_tst_decoder(self, model: nn.Module, num_parallel_samples: int):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = TimeSeriesTransformersSelfAttention(layer.self_attn, num_parallel_samples)
+            cross_attn = TimeSeriesTransformersCrossAttention(layer.encoder_attn, num_parallel_samples)
+            new_layers.append(TimeSeriesTransformersDecoderLayer(layer, self_attn, cross_attn))
+
+        decoder_model = TimeSeriesTransformersDecoder(model.get_decoder(), new_layers)
+
+        return decoder_model
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        decoder_attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        block_tables: torch.Tensor,
+        cross_kv_cache: torch.Tensor,  # batch_size, num_heads, context_length, d_kv
+        *self_kv_cache: torch.Tensor,  # batch_size * num_parallel_samples, num_heads, prediction_length, d_kv
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        # prepare past_key_values
+        self_past_key_values = ()
+        cross_past_key_values = ()
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+        # Decode
+        last_hidden_states = self.decoder(
+            inputs_embeds=inputs_embeds,
+            attention_mask=decoder_attention_mask,
+            cache_position=cache_position,
+            block_tables=block_tables,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
+        )
+
+        params = self.parameter_projection(last_hidden_states[:, -1:])
+
+        outputs = ()
+        outputs += (params,)
+        outputs += (last_hidden_states,)
+
+        return outputs
+
+
+class TimeSeriesTransformersDecoder(nn.Module):
+    def __init__(self, model, layers, **kwargs):
+        super().__init__()
+        self._original_mod = model
+        self.config = model.config
+        self.layers = nn.ModuleList(layers)
+        self.value_embedding = model.value_embedding
+        self.embed_positions = model.embed_positions
+        self.layernorm_embedding = model.layernorm_embedding
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_values: Optional[torch.Tensor] = None,
+        cross_past_key_values: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: torch.Tensor = None,
+    ):
+        input_shape = inputs_embeds.size()[:-1]
+
+        # prepare casual_attn_mask
+        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
+
+        hidden_states = self.value_embedding(inputs_embeds)
+        embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
+        hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
+
+        # iterate decoder_layer
+        for self_past_key_value, cross_past_key_value, decoder_layer in zip(
+            self_past_key_values, cross_past_key_values, self.layers
+        ):
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                self_past_key_value=self_past_key_value,
+                cross_past_key_value=cross_past_key_value,
+                cache_position=cache_position,
+                block_tables=block_tables,
+            )
+
+        return hidden_states
+
+
+class TimeSeriesTransformersDecoderLayer(nn.Module):
+    def __init__(self, decoder_layer, self_attn, cross_attn):
+        super().__init__()
+        self._original_mod = decoder_layer
+        self.self_attn = self_attn
+        self.encoder_attn = cross_attn
+        self.embed_dim = decoder_layer.embed_dim
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.final_layer_norm = decoder_layer.final_layer_norm
+        self.activation_fn = decoder_layer.activation_fn
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # Self Attention Block
+        residual = hidden_states
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_past_key_value,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            block_tables=block_tables,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+        residual = hidden_states
+        hidden_states = self.encoder_attn(
+            hidden_states=hidden_states,
+            past_key_value=cross_past_key_value,
+            # attention_mask=encoder_attention_mask,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+        # Fully Connected Block
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states
+
+
+class TimeSeriesTransformersAttention(nn.Module):
+    def __init__(self, attn, num_parallel_samples):
+        super().__init__()
+        self._original_mod = attn
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.embed_dim = attn.embed_dim
+        self.head_dim = attn.head_dim
+        self.scaling = attn.scaling
+        self.num_parallel_samples = num_parallel_samples
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+
+class TimeSeriesTransformersSelfAttention(TimeSeriesTransformersAttention):
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(1, seq_len, 1, bsz * self.num_heads, self.head_dim).transpose(1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = query_states * self.scaling
+
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        block_size = past_key_value[0].shape[-2]
+        attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
+            q=query_states,
+            k=key_states,
+            v=value_states,
+            mask=attention_mask.unsqueeze(2),
+            kcache=past_key_value[0].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            vcache=past_key_value[1].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+            seq=cache_position.expand(bsz, 1),
+            scale=torch.tensor(1.0, dtype=torch.float32),  # scale
+            block_table=block_tables,
+            block_size=block_size,
+        )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class TimeSeriesTransformersCrossAttention(TimeSeriesTransformersSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        batch_size, query_len, _ = hidden_states.size()
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(
+                batch_size // self.num_parallel_samples,
+                self.num_parallel_samples,
+                query_len,
+                self.num_heads,
+                self.head_dim,
+            )
+            .transpose(2, 3)
+        )
+        query_states = query_states * self.scaling
+
+        key_states = past_key_value[0].unsqueeze(1)
+        value_states = past_key_value[1].unsqueeze(1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4))
+        attn_weights = attn_weights
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)

+        return attn_output
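Note on the decoder wrapper above: its self-attention KV caches arrive as one flat varargs tuple laid out as [k_0, v_0, k_1, v_1, ...], and the forward pass re-pairs them two at a time. A minimal standalone sketch of that re-pairing, with illustrative sizes (the shapes follow the inline comments in the diff; none of the concrete values below come from the package):

import torch

# Illustrative sizes only; the real values come from the model config.
num_layers, batch, heads, pred_len, d_kv = 2, 4, 8, 24, 64

# Flat layout: one (key, value) pair per decoder layer, interleaved.
self_kv_cache = tuple(torch.zeros(batch, heads, pred_len, d_kv) for _ in range(num_layers * 2))

# Same re-pairing loop as TimeSeriesTransformersDecoderWrapper.forward:
self_past_key_values = ()
for i in range(0, num_layers * 2, 2):
    self_past_key_values += ((self_kv_cache[i], self_kv_cache[i + 1]),)

assert len(self_past_key_values) == num_layers
assert self_past_key_values[0][0].shape == (batch, heads, pred_len, d_kv)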
optimum/rbln/transformers/models/wav2vec2/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
 from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py (new file)
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNModelForMaskedLMConfig
+
+
+class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
+    rbln_model_input_names = ["input_values"]
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -12,26 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Any, Dict, Union

 import torch
-from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
+from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput

-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....utils.logging import get_logger
-
-
-logger = get_logger(__name__)
-
-if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+from ...modeling_generic import RBLNModelForMaskedLM
+from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig


 class _Wav2Vec2(torch.nn.Module):
@@ -44,11 +31,11 @@ class _Wav2Vec2(torch.nn.Module):
         return self.model.lm_head(output[0])


-class RBLNWav2Vec2ForCTC(RBLNModel):
+class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
     """
     Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).

-    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.

     It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
@@ -58,60 +45,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):

     main_input_name = "input_values"
     auto_model_class = AutoModelForMaskedLM
+    rbln_dtype = "float32"
+    output_class = CausalLMOutput
+    output_key = "logits"

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
         return _Wav2Vec2(model).eval()
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        if rbln_max_seq_len is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_max_length"):
-                    rbln_max_seq_len = tokenizer.model_max_length
-                    break
-            if rbln_max_seq_len is None:
-                raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (
-                "input_values",
-                [
-                    rbln_batch_size,
-                    rbln_max_seq_len,
-                ],
-                "float32",
-            ),
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-            }
-        )
-
-        return rbln_config
-
-    def forward(self, input_values: "torch.Tensor", **kwargs):
-        outputs = super().forward(input_values, **kwargs)
-        return CausalLMOutput(logits=outputs)
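The net effect of the three wav2vec2 hunks above is that static shape information moves out of a per-model `_get_rbln_config` override and into the dedicated `RBLNWav2Vec2ForCTCConfig` class. A hedged sketch of compiling under the new API (the `export`/`rbln_config` pattern follows optimum-rbln's usual `from_pretrained` usage; the exact kwargs the config accepts, and the checkpoint name, are assumptions here):

from optimum.rbln import RBLNWav2Vec2ForCTC, RBLNWav2Vec2ForCTCConfig

# Assumed kwargs: max_seq_len/batch_size were the knobs the deleted
# _get_rbln_config read from rbln_kwargs, so the new config class
# presumably accepts them via its RBLNModelForMaskedLMConfig parent.
config = RBLNWav2Vec2ForCTCConfig(max_seq_len=160000, batch_size=1)

model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",  # any Wav2Vec2 CTC checkpoint
    export=True,  # trace and compile the PyTorch model for the RBLN NPU
    rbln_config=config,
)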
optimum/rbln/transformers/models/whisper/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from ....ops import paged_add_softmax_attn_decode
+from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
 from .modeling_whisper import RBLNWhisperForConditionalGeneration
optimum/rbln/transformers/models/whisper/configuration_whisper.py (new file)
@@ -0,0 +1,64 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import rebel
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger()
+
+
+class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: int = None,
+        token_timestamps: bool = None,
+        use_attention_mask: bool = None,
+        enc_max_seq_len: int = None,
+        dec_max_seq_len: int = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (int, optional): The batch size for inference. Defaults to 1.
+            token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
+            use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
+                set to True for RBLN-CA02 devices.
+            enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
+            dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.token_timestamps = token_timestamps or False
+        self.enc_max_seq_len = enc_max_seq_len
+        self.dec_max_seq_len = dec_max_seq_len
+
+        self.use_attention_mask = use_attention_mask
+        npu = self.npu or rebel.get_npu_name()
+        if npu == "RBLN-CA02":
+            if self.use_attention_mask is False:
+                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+            self.use_attention_mask = True
+        else:
+            self.use_attention_mask = self.use_attention_mask or False
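Since the constructor above appears in full, a short usage sketch is low-risk; it assumes only that the class is re-exported from the package root (the expanded optimum/rbln/__init__.py suggests it is) and that the rebel SDK is installed so `self.npu` / `rebel.get_npu_name()` resolve:

from optimum.rbln import RBLNWhisperForConditionalGenerationConfig

config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,
    token_timestamps=True,  # emit per-token timestamps during generation
    dec_max_seq_len=448,    # illustrative; Whisper's usual max_target_positions
)
# On an RBLN-CA02 device the constructor forces use_attention_mask=True,
# logging a warning if it was explicitly set to False.
assert config.batch_size == 1 and config.token_timestamps is True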