optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
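
The headline change in this release is visible in the list above: the dict-based modeling_config.py (file 130, removed) is superseded by configuration_utils.py (file 3) plus typed per-model configuration classes in the new configuration_*.py modules. A minimal sketch of the resulting API, assuming the config classes take their fields as keyword arguments and are re-exported from the package root; the constructor signatures themselves are not shown in this diff:

# Hedged sketch of the new typed-config flow in 0.7.4. The keyword arguments
# and top-level exports are inferred from the hunks below, not confirmed by
# this diff; the model id is illustrative.
from optimum.rbln import (
    RBLNWhisperForConditionalGeneration,
    RBLNWhisperForConditionalGenerationConfig,
)

rbln_config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,            # read back as rbln_config.batch_size in __post_init__
    token_timestamps=False,  # replaces rbln_config.model_cfg["token_timestamps"]
)

model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",  # illustrative model id
    export=True,
    rbln_config=rbln_config,
)

The Whisper hunks below show the consuming side: __post_init__ now reads rbln_config.batch_size and friends as attributes instead of model_cfg[...] keys, and _update_rbln_config mutates a config object passed in rather than building an RBLNConfig from rbln_kwargs.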
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -18,19 +18,14 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForSpeechSeq2Seq,
-    AutoProcessor,
-    PretrainedConfig,
-    WhisperForConditionalGeneration,
-    WhisperModel,
-)
+from transformers import AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration, WhisperModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
 from .generation_whisper import RBLNWhisperGenerationMixin
 from .whisper_architecture import WhisperWrapper
 
@@ -38,29 +33,41 @@ from .whisper_architecture import WhisperWrapper
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        GenerationConfig,
+        PretrainedConfig,
+        PreTrainedModel,
+    )
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self, input_features: torch.Tensor = None):
-        # backward compatibility transformers==4.40.2
-        # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
-
-        n_pad_to_batch = self.batch_size - input_features.shape[0]
-        if n_pad_to_batch > 0:
-            input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
-
-        _ = super().forward(input_features=input_features)
-
-        # dummy output for generation
-        return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        output = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=output)
 
 
 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        batch_size: int,
+        dec_max_seq_len: int,
+        use_attention_mask: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.batch_size = batch_size
+        self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+
     def forward(
         self,
         decoder_input_ids: torch.Tensor = None,
@@ -69,13 +76,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     ):
         inputs_bsz = decoder_input_ids.shape[0]
         padded_bsz = self.batch_size - inputs_bsz
+
         if padded_bsz > 0:
             decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
 
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
         outputs = super().forward(
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            cache_position=cache_position,
+            decoder_input_ids,
+            decoder_attention_mask if self.use_attention_mask else None,
+            cache_position,
+            block_tables=self.default_block_tables,
         )
 
         if isinstance(outputs, torch.Tensor):
@@ -101,15 +119,18 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-        self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
+        self.batch_size = self.rbln_config.batch_size
+        self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.rbln_token_timestamps = self.rbln_config.token_timestamps
+        self.use_attention_mask = self.rbln_config.use_attention_mask
 
-        self.encoder = RBLNRuntimeEncoder(
-            runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
-        )
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1], main_input_name="input_ids", batch_size=self.batch_size
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=self.batch_size,
+            dec_max_seq_len=self.dec_max_seq_len,
+            use_attention_mask=self.use_attention_mask,
         )
 
         # skip encoder & first decoder when language detected
@@ -150,13 +171,16 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         raise NotImplementedError
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
-        return WhisperWrapper(model, rbln_token_timestamps)
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
+        return WhisperWrapper(
+            model,
+            use_attention_mask=rbln_config.use_attention_mask,
+            rbln_token_timestamps=rbln_config.token_timestamps,
+        )
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
@@ -196,47 +220,42 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNWhisperForConditionalGenerationConfig] = None,
+    ) -> RBLNWhisperForConditionalGenerationConfig:
         expected_seq_len = model_config.max_source_positions * 2
         num_mel_bins = model_config.num_mel_bins
-        enc_max_seq_len = model_config.max_source_positions
+        rbln_config.enc_max_seq_len = model_config.max_source_positions
 
         # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig have default value for the key 'max_length'
-        rbln_dec_max_seq_len = getattr(model_config, "max_target_positions", None)
-        if rbln_dec_max_seq_len is None:
-            rbln_dec_max_seq_len = model_config.max_length
-
-        # model input info
-        enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
-        enc_input_info.extend(
-            [
-                (
-                    "cross_key_value_states",
-                    [
-                        model_config.decoder_layers * 2,
-                        rbln_batch_size,
-                        model_config.decoder_attention_heads,
-                        enc_max_seq_len,
-                        model_config.d_model // model_config.decoder_attention_heads,
-                    ],
-                    "float32",
-                )
-            ]
-        )
+        rbln_config.dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+        if rbln_config.dec_max_seq_len is None:
+            rbln_config.dec_max_seq_len = model_config.max_length
+
+        enc_input_info = [
+            ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
+            ("block_tables", [1], "int16"),
+            (
+                "cross_key_value_states",
+                [
+                    model_config.decoder_layers * 2,
+                    rbln_config.batch_size,
+                    model_config.decoder_attention_heads,
+                    rbln_config.enc_max_seq_len,
+                    model_config.d_model // model_config.decoder_attention_heads,
+                ],
+                "float32",
+            ),
+        ]
 
         dec_input_info = [
-            ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
-            ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
-            ("cache_position", [], "int32"),
+            ("decoder_input_ids", [rbln_config.batch_size, 1], "int64"),
+            ("cache_position", [rbln_config.batch_size, 1], "int32"),
+            ("block_tables", [rbln_config.batch_size, 1], "int16"),
        ]
        dec_input_info.extend(
            [
@@ -244,9 +263,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 "cross_key_value_states",
                 [
                     model_config.decoder_layers * 2,
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.decoder_attention_heads,
-                    enc_max_seq_len,
+                    rbln_config.enc_max_seq_len,
                     model_config.d_model // model_config.decoder_attention_heads,
                 ],
                 "float32",
@@ -258,9 +277,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             (
                 f"self_key_value_states_{i}",
                 [
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.decoder_attention_heads,
-                    rbln_dec_max_seq_len,
+                    rbln_config.dec_max_seq_len,
                     model_config.d_model // model_config.encoder_attention_heads,
                 ],
                 "float32",
@@ -269,22 +288,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             ]
         )
 
+        if rbln_config.use_attention_mask:
+            dec_input_info.insert(
+                1, ("decoder_attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+            )
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "token_timestamps": rbln_token_timestamps,
-            }
-        )
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
 
         return rbln_config
 
@@ -292,18 +304,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNWhisperForConditionalGenerationConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
             cls._raise_missing_compiled_file_error(["encoder", "decoder"])
 
         return [
-            compiled_models[0].create_runtime(
-                tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-            compiled_models[1].create_runtime(
-                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
            ),
        ]
 
@@ -327,11 +344,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
     # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
     def _prepare_encoder_decoder_kwargs_for_generation(
-        self, inputs_tensor: torch.Tensor, model_kwargs, *args, **kwargs
+        self,
+        inputs_tensor: torch.Tensor,
+        model_kwargs,
+        model_input_name: Optional[str] = None,
+        generation_config: Optional["GenerationConfig"] = None,
+        **kwargs,
     ) -> Dict[str, Any]:
+        batch_size = inputs_tensor.shape[0]
+        n_pad_to_batch = self.batch_size - batch_size
+        if n_pad_to_batch > 0:
+            inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, n_pad_to_batch))
+
         if not self.is_language_detected:
-            model_kwargs["encoder_outputs"] = self.encoder(input_features=inputs_tensor)
-            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+            for b in range(inputs_tensor.shape[0]):
+                block_tables = torch.tensor([b], dtype=torch.int16)
+                model_kwargs["encoder_outputs"] = self.encoder(
+                    input_features=inputs_tensor[b].unsqueeze(0), block_tables=block_tables
+                )
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
         else:
             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
 
@@ -359,7 +390,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             decoder_output = self.decoder(
                 decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=step.to(torch.int32),
+                cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
             )
             cross_attentions.append(decoder_output.cross_attentions)
             lm_logits = decoder_output.logits
@@ -374,15 +405,19 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         # detect language pass
         # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
         else:
+            # for language auto detection (generate with language=None)
             if encoder_outputs is None:
-                self.encoder(input_features=input_features.contiguous())
-                self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+                for b in range(input_features.shape[0]):
+                    block_tables = torch.tensor([b], dtype=torch.int16)
+                    self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)
+
+                self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
                 self.is_language_detected = True
                 self.decoder_attention_mask[:, 0] = 1
             decoder_output = self.decoder(
                 decoder_input_ids=decoder_input_ids.contiguous(),
                 decoder_attention_mask=self.decoder_attention_mask,
-                cache_position=torch.zeros([], dtype=torch.int32),
+                cache_position=torch.zeros([self.rbln_config.batch_size, 1], dtype=torch.int32),
             )
             lm_logits = decoder_output.logits
            self.language_cross = decoder_output.cross_attentions
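
A side effect of the migration above: the decoder runtime's cache_position input is now a [batch_size, 1] int32 tensor rather than a scalar, and a block_tables input maps each batch index to its own cache block. A standalone sketch of the host-side tensors implied by the new dec_input_info (shapes taken from the hunks above; the sizes are illustrative):

import torch

batch_size, dec_max_seq_len = 2, 448  # illustrative values

# One cache block per batch index, as built in RBLNRuntimeDecoder.__init__.
block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)

# cache_position is now [batch_size, 1] int32 instead of a scalar.
step = 5
cache_position = torch.full((batch_size, 1), step, dtype=torch.int32)

# When use_attention_mask is set, the mask is float32 and filled per batch
# row up to the current decoding step, as in RBLNRuntimeDecoder.forward.
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len, dtype=torch.float32)
for b_idx in range(batch_size):
    decoder_attention_mask[b_idx, : int(cache_position[b_idx]) + 1] = 1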
optimum/rbln/transformers/models/whisper/whisper_architecture.py

@@ -16,27 +16,19 @@ from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-    Seq2SeqLMOutput,
-)
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update
-
 
 logger = logging.get_logger(__name__)
 
 
 class WhisperWrapper:
-    def __init__(self, model, rbln_token_timestamps):
-        register_rbln_custom_cache_update()
-        register_rbln_custom_add_softmax_attention()
+    def __init__(self, model, use_attention_mask, rbln_token_timestamps):
         self.encoder = WhisperEncoderWrapper(model)
-        self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
+        self.decoder = WhisperDecoderWrapper(
+            model, use_attention_mask=use_attention_mask, output_attentions=rbln_token_timestamps
+        )
 
 
 class WhisperEncoderWrapper(torch.nn.Module):
@@ -57,6 +49,7 @@ class WhisperEncoderWrapper(torch.nn.Module):
     def forward(
         self,
         input_features: Optional[torch.LongTensor],
+        b_idx: torch.Tensor,
         cross_key_values: torch.Tensor,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         # 1. get encoder last_hidden_states
@@ -76,21 +69,31 @@ class WhisperEncoderWrapper(torch.nn.Module):
         cross_kv = torch.stack(cross_kv, dim=0)
 
         # 3. update cross_attention's past_key_value to the device-dram for optimization.
-        bidx = torch.tensor(0, dtype=torch.int16)
-        axis = torch.tensor(1, dtype=torch.int16)
-        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+        batch_axis = torch.tensor(1, dtype=torch.int16)
+        cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
+            cross_key_values, cross_kv, b_idx[0], batch_axis
+        )
 
-        return enc_output
+        return cross_key_values
 
 
 class WhisperDecoderWrapper(torch.nn.Module):
-    def __init__(self, model, output_attentions: bool = False):
+    def __init__(self, model, use_attention_mask: bool = True, output_attentions: bool = False, **kwargs):
         super().__init__()
         self.config = model.config
-        self.num_layers = self.config.decoder_layers
         self.proj_out = model.proj_out
-        self.decoder = self.convert_to_rbln_conditional_generation(model)
+        self.use_attention_mask = use_attention_mask
         self.output_attentions = output_attentions
+        self.__post_init__(model, **kwargs)
+
+    def __post_init__(self, model: nn.Module, **kwargs):
+        """
+        Post-initialization to extract and configure encoder-related attributes.
+        It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+        by subclasses to modify or add custom attributes as necessary.
+        """
+        self.num_layers = self.config.decoder_layers
+        self.decoder = self.convert_to_rbln_conditional_generation(model)
 
     def convert_to_rbln_conditional_generation(self, model: nn.Module):
         new_layers = []
@@ -105,12 +108,21 @@ class WhisperDecoderWrapper(torch.nn.Module):
 
     def forward(
         self,
-        decoder_input_ids: torch.Tensor,
-        decoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-        *self_kv_cache: torch.Tensor,
+        *args,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        if self.use_attention_mask:
+            (
+                decoder_input_ids,
+                decoder_attention_mask,
+                cache_position,
+                block_tables,
+                cross_kv_cache,
+                *self_kv_cache,
+            ) = args
+        else:
+            decoder_attention_mask = None
+            (decoder_input_ids, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+
         # prepare past_key_values
         self_past_key_values = ()
         cross_past_key_values = ()
@@ -125,6 +137,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
             cache_position=cache_position,
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
+            block_tables=block_tables,
         )
 
         lm_logits = self.proj_out(sequence_output)
@@ -154,17 +167,25 @@ class WhisperDecoder(nn.Module):
         self_past_key_values: Optional[torch.Tensor] = None,
         cross_past_key_values: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
 
         # positional embeding
         inputs_embeds = self.embed_tokens(input_ids)
-        positions = self.embed_positions(input_ids, position_ids=cache_position)
-        hidden_states = inputs_embeds + positions
+        all_hiddens = []
+        for i in range(inputs_embeds.shape[0]):
+            position_id = cache_position[i]
+            position = self.embed_positions.weight[position_id]
+            batch_hidden = position + inputs_embeds[i]
+            all_hiddens.append(batch_hidden)
+
+        hidden_states = torch.cat(all_hiddens, dim=0).unsqueeze(1)
 
-        # prepare casual_attn_mask
-        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
+        # prepare attn mask (normal attention - masked)
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :]
 
         cross_attentions = ()
         # iterate decoder_layer
@@ -177,6 +198,7 @@ class WhisperDecoder(nn.Module):
                 self_past_key_value=self_past_key_value,
                 cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
+                block_tables=block_tables,
             )
             cross_attentions += (cross_attn_weights,)
 
@@ -205,6 +227,7 @@ class WhisperDecoderLayer(nn.Module):
         self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self Attention Block
         residual = hidden_states
@@ -214,6 +237,7 @@ class WhisperDecoderLayer(nn.Module):
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states
 
@@ -263,6 +287,7 @@ class WhisperSelfAttention(WhisperAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, tgt_len, _ = hidden_states.size()
         query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
@@ -270,17 +295,25 @@ class WhisperSelfAttention(WhisperAttention):
 
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        attn_output = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask.unsqueeze(2),
-            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
-            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
-            cache_position.expand(bsz, 1),
-            torch.tensor(1.0, dtype=torch.float32),  # scale
-        )
+        block_size = past_key_value[0].shape[-2]
+
+        args = {
+            "q": query_states,
+            "k": key_states,
+            "v": value_states,
+            "kcache": past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            "vcache": past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            "seq": cache_position.expand(bsz, 1),
+            "scale": torch.tensor(1.0, dtype=torch.float32),
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if attention_mask is not None:
+            args["mask"] = attention_mask.unsqueeze(2)
+            attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(**args)
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(**args)
 
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
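
The positional-embedding change above is easy to sanity-check in plain PyTorch: at decode time each batch row gathers its own row of embed_positions.weight at its cache_position, which is equivalent to a vectorized gather. A self-contained sketch with made-up sizes:

import torch

# Decode-time check of the per-batch gather used in WhisperDecoder.forward
# above (single-token decode); sizes are hypothetical.
batch_size, d_model, max_positions = 2, 8, 16
embed_positions = torch.nn.Embedding(max_positions, d_model)
inputs_embeds = torch.randn(batch_size, 1, d_model)
cache_position = torch.tensor([[3], [7]], dtype=torch.int32)

all_hiddens = []
for i in range(batch_size):
    position = embed_positions.weight[cache_position[i]]  # [1, d_model]
    all_hiddens.append(position + inputs_embeds[i])       # [1, d_model]

hidden_states = torch.cat(all_hiddens, dim=0).unsqueeze(1)  # [batch, 1, d_model]

# Equivalent vectorized gather:
reference = inputs_embeds + embed_positions.weight[cache_position.long().squeeze(1)].unsqueeze(1)
assert torch.allclose(hidden_states, reference)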
optimum/rbln/transformers/models/xlm_roberta/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_xlm_roberta import RBLNXLMRobertaModelConfig
 from .modeling_xlm_roberta import RBLNXLMRobertaModel
optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py (new file)

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+
+
+class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
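
Most of the new configuration_*.py modules in this release are thin subclasses like the one above. A hedged usage sketch, assuming the generic base exposes common fields such as batch_size and max_seq_len (its definition in configuration_generic.py is not part of this excerpt, so the keyword names are assumptions):

# Hypothetical usage of the new XLM-RoBERTa config; keyword names and the
# model id are illustrative, not confirmed by this diff.
from optimum.rbln import RBLNXLMRobertaModel, RBLNXLMRobertaModelConfig

config = RBLNXLMRobertaModelConfig(batch_size=1, max_seq_len=512)
model = RBLNXLMRobertaModel.from_pretrained(
    "FacebookAI/xlm-roberta-base",  # illustrative model id
    export=True,
    rbln_config=config,
)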