optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
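The bulk of this release replaces the dictionary-based `RBLNConfig` (the removed `optimum/rbln/modeling_config.py`) with typed, per-model configuration classes built on the new `optimum/rbln/configuration_utils.py`, as the hunks below illustrate. A rough sketch of the new style follows; it assumes the config constructor accepts the fields the hunks read back (`enc_max_seq_len`, `dec_max_seq_len`) and that `from_pretrained` takes `export`/`rbln_config` keywords, none of which is shown directly in this diff.

```python
# Hedged sketch only: constructor fields and from_pretrained keywords are assumptions.
from optimum.rbln import RBLNT5ForConditionalGeneration
from optimum.rbln.transformers.models.t5.configuration_t5 import RBLNT5ForConditionalGenerationConfig

# Typed, validated fields replace the old rbln_kwargs dict / rbln_config.model_cfg lookups.
rbln_config = RBLNT5ForConditionalGenerationConfig(enc_max_seq_len=512, dec_max_seq_len=512)

model = RBLNT5ForConditionalGeneration.from_pretrained(
    "t5-small",
    export=True,              # assumed: compile from the Hugging Face checkpoint
    rbln_config=rbln_config,  # assumed keyword; values were previously passed as rbln_* kwargs
)
```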
optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -13,48 +13,21 @@
  # limitations under the License.

  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable

  import torch
- from transformers import (
-     AutoModelForTextEncoding,
-     PretrainedConfig,
-     T5EncoderModel,
-     T5ForConditionalGeneration,
- )
- from transformers.modeling_outputs import BaseModelOutput
-
- from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
- from ....modeling import RBLNModel
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
- from ....utils.logging import get_logger
- from ....utils.runtime_utils import RBLNPytorchRuntime
+ from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
+
+ from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
  from ...models.seq2seq import RBLNModelForSeq2SeqLM
+ from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
  from .t5_architecture import T5Wrapper


- logger = get_logger()
-
  if TYPE_CHECKING:
-     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
-
-
- class RBLNRuntimeModel(RBLNPytorchRuntime):
-     def forward(
-         self,
-         input_ids: torch.LongTensor,
-         attention_mask: torch.FloatTensor,
-         head_mask: torch.FloatTensor,
-         inputs_embeds: torch.FloatTensor,
-         **kwargs,
-     ):
-         return super().forward(
-             input_ids,
-             attention_mask,
-             head_mask,
-             inputs_embeds,
-             **kwargs,
-         )
+     from transformers import PreTrainedModel
+
+     from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig


  class T5EncoderWrapper(torch.nn.Module):
@@ -67,136 +40,35 @@ class T5EncoderWrapper(torch.nn.Module):
          return self.model(*args, **kwargs, return_dict=False)


- class RBLNT5EncoderModel(RBLNModel):
+ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
      auto_model_class = AutoModelForTextEncoding
      rbln_model_input_names = ["input_ids", "attention_mask"]

-     def __post_init__(self, **kwargs):
-         self.model = RBLNRuntimeModel(runtime=self.model[0])
-
      @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
          return T5EncoderWrapper(model)

      @classmethod
-     def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-         batch_size = rbln_config.get("batch_size", 1)
-         max_sequence_length = rbln_config.get("max_sequence_length", 256)
-         model_input_names = ["input_ids"]
-
-         rbln_config.update(
-             {
-                 "batch_size": batch_size,
-                 "max_seq_len": max_sequence_length,
-                 "model_input_names": model_input_names,
-             }
-         )
-
-         return rbln_config
-
-     @classmethod
-     def _get_rbln_config(
+     def update_rbln_config_using_pipe(
          cls,
-         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-         model_config: Optional["PretrainedConfig"] = None,
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-         rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-         max_position_embeddings = getattr(model_config, "n_positions", None)
-
-         if rbln_max_seq_len is None:
-             rbln_max_seq_len = max_position_embeddings
-             if rbln_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_max_seq_len is None:
-                     raise ValueError("`rbln_max_seq_len` should be specified!")
-
-         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-         signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-         if rbln_model_input_names is None:
-             for tokenizer in preprocessors:
-                 if hasattr(tokenizer, "model_input_names"):
-                     rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                     invalid_params = set(rbln_model_input_names) - set(signature_params)
-                     if invalid_params:
-                         raise ValueError(f"Invalid model input names: {invalid_params}")
-                     break
-             if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                 rbln_model_input_names = cls.rbln_model_input_names
-             elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                 raise ValueError(
-                     "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                     f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(signature_params)})"
-                 )
-         else:
-             invalid_params = set(rbln_model_input_names) - set(signature_params)
-             if invalid_params:
-                 raise ValueError(f"Invalid model input names: {invalid_params}")
-             rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-         if rbln_batch_size is None:
-             rbln_batch_size = 1
-
-         input_info = [
-             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-             for model_input_name in rbln_model_input_names
-         ]
-
-         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[rbln_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+         pipe: "RBLNDiffusionMixin",
+         rbln_config: "RBLNDiffusionMixinConfig",
+         submodule_name: str,
+     ) -> "RBLNDiffusionMixinConfig":
+         submodule_config = getattr(rbln_config, submodule_name)
+         submodule_config.max_seq_len = rbln_config.max_seq_len or 256
+         submodule_config.model_input_names = ["input_ids"]
          return rbln_config

-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
-         encoder_outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             inputs_embeds=inputs_embeds,
-             head_mask=head_mask,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         if not return_dict:
-             return (encoder_outputs,)
-         else:
-             return BaseModelOutput(last_hidden_state=encoder_outputs)
-

  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
      support_causal_attn = False

      @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
-         dec_max_seq_len = rbln_config.model_cfg["dec_max_seq_len"]
-
-         return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
+         return T5Wrapper(
+             model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
+         )

      def __getattr__(self, __name: str) -> Any:
          def redirect(func):
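For the standalone encoder, the diffusion-pipeline hook above now mutates a typed submodule config instead of a plain dict. A minimal sketch of compiling the encoder directly with the new config class, assuming its constructor accepts the fields the hook assigns (`max_seq_len`, `model_input_names`) and that `from_pretrained` takes `export`/`rbln_config` keywords:

```python
# Sketch under the stated assumptions; this hunk only confirms the class names and fields.
from optimum.rbln import RBLNT5EncoderModel
from optimum.rbln.transformers.models.t5.configuration_t5 import RBLNT5EncoderModelConfig

config = RBLNT5EncoderModelConfig(max_seq_len=256, model_input_names=["input_ids"])
encoder = RBLNT5EncoderModel.from_pretrained("t5-small", export=True, rbln_config=config)
```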
optimum/rbln/transformers/models/time_series_transformers/__init__.py

@@ -22,4 +22,5 @@
  # from Rebellions Inc.

  from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
+ from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
  from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py (new file)

@@ -0,0 +1,34 @@
+ from typing import Optional
+
+ from ....configuration_utils import RBLNModelConfig
+
+
+ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         enc_max_seq_len: Optional[int] = None,
+         dec_max_seq_len: Optional[int] = None,
+         num_parallel_samples: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+             num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size <= 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.enc_max_seq_len = enc_max_seq_len
+         self.dec_max_seq_len = dec_max_seq_len
+         self.num_parallel_samples = num_parallel_samples
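A construction sketch based on the constructor above; the defaulting behaviour in the comments comes from the `_update_rbln_config` hunk further down, while handing the object to `from_pretrained` via an `rbln_config` keyword is an assumption about the surrounding API and not shown here.

```python
from optimum.rbln.transformers.models.time_series_transformers.configuration_time_series_transformer import (
    RBLNTimeSeriesTransformerForPredictionConfig,
)

cfg = RBLNTimeSeriesTransformerForPredictionConfig(
    batch_size=64,              # validated: must be a positive int
    dec_max_seq_len=None,       # None -> derived from prediction_length, rounded up to a multiple of 64
    num_parallel_samples=None,  # None -> taken from the Hugging Face model config
)
```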
optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py

@@ -25,7 +25,7 @@ import inspect
  import logging
  from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

  import rebel
  import torch
@@ -38,9 +38,10 @@ from transformers import (
  from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
  from transformers.modeling_utils import no_init_weights

+ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
  from .time_series_transformers_architecture import TimeSeriesTransformersWrapper


@@ -124,9 +125,9 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):

      def __post_init__(self, **kwargs):
          super().__post_init__(**kwargs)
-         self.batch_size = self.rbln_config.model_cfg["batch_size"]
-         self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-         self.num_parallel_samples = self.rbln_config.model_cfg["num_parallel_samples"]
+         self.batch_size = self.rbln_config.batch_size
+         self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+         self.num_parallel_samples = self.rbln_config.num_parallel_samples

          with no_init_weights():
              self._origin_model = TimeSeriesTransformerForPrediction._from_config(self.config)
@@ -156,12 +157,14 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
              return redirect(val)

      @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         return TimeSeriesTransformersWrapper(model, rbln_config.model_cfg["num_parallel_samples"])
+     def wrap_model_if_needed(
+         self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
+     ):
+         return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)

      @classmethod
      @torch.inference_mode()
-     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+     def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
          wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

          enc_compile_config = rbln_config.compile_cfgs[0]
@@ -206,7 +209,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
          model: "PreTrainedModel",
          save_dir_path: Path,
          subfolder: str,
-         rbln_config: RBLNConfig,
+         rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
      ):
          """
          If you are unavoidably running on a CPU rather than an RBLN device,
@@ -217,31 +220,28 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
          torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

      @classmethod
-     def _get_rbln_config(
+     def _update_rbln_config(
          cls,
-         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-         model_config: "PretrainedConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_batch_size = rbln_kwargs.get("batch_size", 1)
-         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-         rbln_num_parallel_samples = rbln_kwargs.get("num_parallel_samples", None)
-
-         if not isinstance(rbln_batch_size, int):
-             raise TypeError(f"Expected rbln_batch_size to be an int, but got {type(rbln_batch_size)}")
-
-         rbln_num_parallel_samples = (
-             model_config.num_parallel_samples if rbln_num_parallel_samples is None else rbln_num_parallel_samples
-         )
-         if rbln_dec_max_seq_len is None:
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNTimeSeriesTransformerForPredictionConfig] = None,
+     ) -> RBLNTimeSeriesTransformerForPredictionConfig:
+         rbln_config.num_parallel_samples = rbln_config.num_parallel_samples or model_config.num_parallel_samples
+
+         if rbln_config.dec_max_seq_len is None:
              predict_length = model_config.prediction_length
-             rbln_dec_max_seq_len = (
+             rbln_config.dec_max_seq_len = (
                  predict_length if predict_length % 64 == 0 else predict_length + (64 - predict_length % 64)
              )

          # model input info
          enc_input_info = [
-             ("inputs_embeds", [rbln_batch_size, model_config.context_length, model_config.feature_size], "float32"),
+             (
+                 "inputs_embeds",
+                 [rbln_config.batch_size, model_config.context_length, model_config.feature_size],
+                 "float32",
+             ),
          ]
          enc_input_info.extend(
              [
@@ -249,7 +249,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
                      "cross_key_value_states",
                      [
                          model_config.decoder_layers * 2,
-                         rbln_batch_size,
+                         rbln_config.batch_size,
                          model_config.decoder_attention_heads,
                          model_config.context_length,
                          model_config.d_model // model_config.decoder_attention_heads,
@@ -260,8 +260,12 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
          )

          dec_input_info = [
-             ("inputs_embeds", [rbln_batch_size * rbln_num_parallel_samples, 1, model_config.feature_size], "float32"),
-             ("attention_mask", [1, rbln_dec_max_seq_len], "float32"),
+             (
+                 "inputs_embeds",
+                 [rbln_config.batch_size * rbln_config.num_parallel_samples, 1, model_config.feature_size],
+                 "float32",
+             ),
+             ("attention_mask", [1, rbln_config.dec_max_seq_len], "float32"),
              ("cache_position", [], "int32"),
              ("block_tables", [1, 1], "int16"),
          ]
@@ -271,7 +275,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
                      "cross_key_value_states",
                      [
                          model_config.decoder_layers * 2, # 4
-                         rbln_batch_size, # 64
+                         rbln_config.batch_size, # 64
                          model_config.decoder_attention_heads, # 2
                          model_config.context_length, # 24
                          model_config.d_model // model_config.decoder_attention_heads, # 13
@@ -286,8 +290,10 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
                      f"self_key_value_states_{i}",
                      [
                          1,
-                         model_config.decoder_attention_heads * rbln_num_parallel_samples * rbln_batch_size,
-                         rbln_dec_max_seq_len,
+                         model_config.decoder_attention_heads
+                         * rbln_config.num_parallel_samples
+                         * rbln_config.batch_size,
+                         rbln_config.dec_max_seq_len,
                          model_config.d_model // model_config.encoder_attention_heads,
                      ],
                      "float32",
@@ -298,38 +304,30 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
          enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
          dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[enc_compile_config, dec_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update(
-             {
-                 "batch_size": rbln_batch_size,
-                 "num_parallel_samples": rbln_num_parallel_samples,
-                 "dec_max_seq_len": rbln_dec_max_seq_len,
-             }
-         )
-
+         rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
          return rbln_config

      @classmethod
      def _create_runtimes(
          cls,
          compiled_models: List[rebel.RBLNCompiledModel],
-         rbln_device_map: Dict[str, int],
-         activate_profiler: Optional[bool] = None,
+         rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
      ) -> List[rebel.Runtime]:
-         if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+         if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
              cls._raise_missing_compiled_file_error(["encoder", "decoder"])

          return [
-             compiled_models[0].create_runtime(
-                 tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["encoder"],
+                 activate_profiler=rbln_config.activate_profiler,
              ),
-             compiled_models[1].create_runtime(
-                 tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+             rebel.Runtime(
+                 compiled_models[1],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["decoder"],
+                 activate_profiler=rbln_config.activate_profiler,
              ),
          ]

optimum/rbln/transformers/models/wav2vec2/__init__.py

@@ -12,4 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
  from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py (new file)

@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNModelForMaskedLMConfig
+
+
+ class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
+     rbln_model_input_names = ["input_values"]
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -12,26 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Any, Dict, Union

  import torch
- from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
+ from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
  from transformers.modeling_outputs import CausalLMOutput

- from ....modeling import RBLNModel
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
- from ....utils.logging import get_logger
-
-
- logger = get_logger(__name__)
-
- if TYPE_CHECKING:
-     from transformers import (
-         AutoFeatureExtractor,
-         AutoProcessor,
-         AutoTokenizer,
-         PretrainedConfig,
-     )
+ from ...modeling_generic import RBLNModelForMaskedLM
+ from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig


  class _Wav2Vec2(torch.nn.Module):
@@ -44,11 +31,11 @@ class _Wav2Vec2(torch.nn.Module):
          return self.model.lm_head(output[0])


- class RBLNWav2Vec2ForCTC(RBLNModel):
+ class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
      """
      Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).

-     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the
+     This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
      library implements for all its model.

      It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
@@ -58,60 +45,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):

      main_input_name = "input_values"
      auto_model_class = AutoModelForMaskedLM
+     rbln_dtype = "float32"
+     output_class = CausalLMOutput
+     output_key = "logits"

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
          return _Wav2Vec2(model).eval()
-
-     @classmethod
-     def _get_rbln_config(
-         cls,
-         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "PretrainedConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-         if rbln_max_seq_len is None:
-             for tokenizer in preprocessors:
-                 if hasattr(tokenizer, "model_max_length"):
-                     rbln_max_seq_len = tokenizer.model_max_length
-                     break
-             if rbln_max_seq_len is None:
-                 raise ValueError("`rbln_max_seq_len` should be specified!")
-
-         if rbln_batch_size is None:
-             rbln_batch_size = 1
-
-         input_info = [
-             (
-                 "input_values",
-                 [
-                     rbln_batch_size,
-                     rbln_max_seq_len,
-                 ],
-                 "float32",
-             ),
-         ]
-
-         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[rbln_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update(
-             {
-                 "max_seq_len": rbln_max_seq_len,
-                 "batch_size": rbln_batch_size,
-             }
-         )
-
-         return rbln_config
-
-     def forward(self, input_values: "torch.Tensor", **kwargs):
-         outputs = super().forward(input_values, **kwargs)
-         return CausalLMOutput(logits=outputs)
optimum/rbln/transformers/models/whisper/__init__.py

@@ -13,4 +13,5 @@
  # limitations under the License.

  from ....ops import paged_add_softmax_attn_decode
+ from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
  from .modeling_whisper import RBLNWhisperForConditionalGeneration
optimum/rbln/transformers/models/whisper/configuration_whisper.py (new file)

@@ -0,0 +1,64 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import rebel
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger()
+
+
+ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         batch_size: int = None,
+         token_timestamps: bool = None,
+         use_attention_mask: bool = None,
+         enc_max_seq_len: int = None,
+         dec_max_seq_len: int = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (int, optional): The batch size for inference. Defaults to 1.
+             token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
+             use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
+                 set to True for RBLN-CA02 devices.
+             enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
+             dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.token_timestamps = token_timestamps or False
+         self.enc_max_seq_len = enc_max_seq_len
+         self.dec_max_seq_len = dec_max_seq_len
+
+         self.use_attention_mask = use_attention_mask
+         npu = self.npu or rebel.get_npu_name()
+         if npu == "RBLN-CA02":
+             if self.use_attention_mask is False:
+                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+             self.use_attention_mask = True
+         else:
+             self.use_attention_mask = self.use_attention_mask or False
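A usage sketch of the class defined above. The constructor arguments are taken from this file; loading a model with it through `export`/`rbln_config` keywords is an assumption about the surrounding API rather than something this diff shows.

```python
from optimum.rbln import RBLNWhisperForConditionalGeneration
from optimum.rbln.transformers.models.whisper.configuration_whisper import (
    RBLNWhisperForConditionalGenerationConfig,
)

cfg = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,
    token_timestamps=True,    # emit token-level timestamps during generation
    use_attention_mask=None,  # forced to True on RBLN-CA02 devices, per the logic above
)
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", export=True, rbln_config=cfg  # assumed keywords
)
```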