optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/modeling_generic.py

@@ -21,155 +21,242 @@ different model architectures.
 """
 
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from transformers import (
+    AutoModel,
     AutoModelForAudioClassification,
+    AutoModelForDepthEstimation,
     AutoModelForImageClassification,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
+    AutoModelForTextEncoding,
     PretrainedConfig,
 )
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    DepthEstimatorOutput,
+    ImageClassifierOutput,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+)
 
+from ..configuration_utils import RBLNCompileConfig
 from ..modeling import RBLNModel
-from ..modeling_config import RBLNCompileConfig, RBLNConfig
 from ..utils.logging import get_logger
+from .configuration_generic import (
+    RBLNModelForAudioClassificationConfig,
+    _RBLNImageModelConfig,
+    _RBLNTransformerEncoderConfig,
+)
 
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 logger = get_logger()
 
 
-class RBLNModelForQuestionAnswering(RBLNModel):
-    auto_model_class = AutoModelForQuestionAnswering
+class _RBLNTransformerEncoder(RBLNModel):
+    auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+    rbln_dtype = "int64"
+    output_class = BaseModelOutput
+    output_key = "last_hidden_state"
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_config: Optional[_RBLNTransformerEncoderConfig] = None,
+    ) -> _RBLNTransformerEncoderConfig:
+        return cls.update_rbln_config_for_transformers_encoder(
+            preprocessors=preprocessors,
+            model=model,
+            model_config=model_config,
+            rbln_config=rbln_config,
+        )
 
-        if rbln_max_seq_len is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_max_length"):
-                    rbln_max_seq_len = tokenizer.model_max_length
-                    break
-            if rbln_max_seq_len is None:
-                raise ValueError("`rbln_max_seq_len` should be specified!")
+    @classmethod
+    def update_rbln_config_for_transformers_encoder(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[_RBLNTransformerEncoderConfig] = None,
+    ) -> _RBLNTransformerEncoderConfig:
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = max_position_embeddings
+            if rbln_config.max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_config.max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_config.max_seq_len is None:
+                    raise ValueError("`max_seq_len` should be specified!")
 
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        if max_position_embeddings is not None and rbln_config.max_seq_len > max_position_embeddings:
+            raise ValueError("`max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
+        signature_params = inspect.signature(model.forward).parameters.keys()
 
-        if rbln_model_input_names is None:
+        if rbln_config.model_input_names is None:
             for tokenizer in preprocessors:
                 if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
+                    rbln_config.model_input_names = [
+                        name for name in signature_params if name in tokenizer.model_input_names
+                    ]
 
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
+                    invalid_params = set(rbln_config.model_input_names) - set(signature_params)
                     if invalid_params:
                         raise ValueError(f"Invalid model input names: {invalid_params}")
                     break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as QuestionAnswering forward() arguments like ({list(signature_params)})"
-                )
+            if rbln_config.model_input_names is None and cls.rbln_model_input_names is not None:
+                rbln_config.model_input_names = cls.rbln_model_input_names
+
         else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
+            invalid_params = set(rbln_config.model_input_names) - set(signature_params)
             if invalid_params:
                 raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
+            rbln_config.model_input_names = [
+                name for name in signature_params if name in rbln_config.model_input_names
+            ]
+
+        if rbln_config.model_input_names is None or len(rbln_config.model_input_names) == 0:
+            raise ValueError(
+                "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`. "
+                "This is an internal error. Please report it to the developers."
+            )
 
         input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
+            (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
+            for model_input_name in rbln_config.model_input_names
        ]
 
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
 
 
-class RBLNModelForImageClassification(RBLNModel):
-    """
-    This is a generic model class that will be instantiated as one of the model classes of the library (with a image classification head) when created with the from_pretrained() class method
-    """
-
-    auto_model_class = AutoModelForImageClassification
+class _RBLNImageModel(RBLNModel):
+    auto_model_class = AutoModel
+    main_input_name = "pixel_values"
+    output_class = BaseModelOutput
+    output_key = "last_hidden_state"
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_image_size = rbln_kwargs.get("image_size", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_config: Optional[_RBLNImageModelConfig] = None,
+    ) -> _RBLNImageModelConfig:
+        return cls.update_rbln_config_for_image_model(
+            preprocessors=preprocessors,
+            model=model,
+            model_config=model_config,
+            rbln_config=rbln_config,
+        )
 
-        if rbln_image_size is None:
+    @classmethod
+    def update_rbln_config_for_image_model(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[_RBLNImageModelConfig] = None,
+    ) -> _RBLNImageModelConfig:
+        if rbln_config.image_size is None:
             for processor in preprocessors:
                 if hasattr(processor, "size"):
                     if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
-                        rbln_image_size = (processor.size["height"], processor.size["width"])
+                        rbln_config.image_size = (processor.size["height"], processor.size["width"])
                     elif "shortest_edge" in processor.size.keys():
-                        rbln_image_size = (processor.size["shortest_edge"], processor.size["shortest_edge"])
+                        rbln_config.image_size = (processor.size["shortest_edge"], processor.size["shortest_edge"])
                     elif "longest_edge" in processor.size.keys():
-                        rbln_image_size = (processor.size["longest_edge"], processor.size["longest_edge"])
+                        rbln_config.image_size = (processor.size["longest_edge"], processor.size["longest_edge"])
                     break
 
-        if rbln_image_size is None:
-            rbln_image_size = model_config.image_size
-
-        if rbln_image_size is None:
-            raise ValueError("`rbln_image_size` should be specified!")
+            if rbln_config.image_size is None:
+                rbln_config.image_size = model_config.image_size
 
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        if isinstance(rbln_image_size, int):
-            rbln_image_height, rbln_image_width = rbln_image_size, rbln_image_size
-        elif isinstance(rbln_image_size, (list, tuple)):
-            rbln_image_height, rbln_image_width = rbln_image_size[0], rbln_image_size[1]
-        elif isinstance(rbln_image_size, dict):
-            rbln_image_height, rbln_image_width = rbln_image_size["height"], rbln_image_size["width"]
-        else:
-            raise ValueError(
-                "`rbln_image_size` should be `int` (ex. 224), `tuple` (ex. 224, 224), `dict` (ex. {'height': 224, 'width': 224}) format"
-            )
+        if rbln_config.image_size is None:
+            raise ValueError("`image_size` should be specified!")
 
         input_info = [
             (
-                "pixel_values",
-                [rbln_batch_size, 3, rbln_image_height, rbln_image_width],
+                cls.main_input_name,
+                [rbln_config.batch_size, 3, rbln_config.image_height, rbln_config.image_width],
                 "float32",
             )
         ]
 
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        return RBLNConfig(rbln_cls=cls.__name__, compile_cfgs=[rbln_compile_config], rbln_kwargs=rbln_kwargs)
+        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
+        return rbln_config
+
+
+class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
+    auto_model_class = AutoModelForQuestionAnswering
+    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+    output_class = QuestionAnsweringModelOutput
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare QuestionAnswering specific output format.
+        """
+        start_logits, end_logits = output
+
+        if not return_dict:
+            return (start_logits, end_logits)
+        else:
+            return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)
+
+
+class RBLNModelForSequenceClassification(_RBLNTransformerEncoder):
+    auto_model_class = AutoModelForSequenceClassification
+    rbln_model_input_names = ["input_ids", "attention_mask"]
+    output_class = SequenceClassifierOutput
+    output_key = "logits"
+
+
+class RBLNModelForMaskedLM(_RBLNTransformerEncoder):
+    auto_model_class = AutoModelForMaskedLM
+    rbln_model_input_names = ["input_ids", "attention_mask"]
+    output_class = MaskedLMOutput
+    output_key = "logits"
+
+
+class RBLNModelForTextEncoding(_RBLNTransformerEncoder):
+    auto_model_class = AutoModelForTextEncoding
+    rbln_model_input_names = ["input_ids", "attention_mask"]
+
+
+class RBLNTransformerEncoderForFeatureExtraction(_RBLNTransformerEncoder):
+    # TODO: RBLNModel is also for feature extraction.
+    auto_model_class = AutoModel
+    rbln_model_input_names = ["input_ids", "attention_mask"]
+    output_class = BaseModelOutput
+    output_key = "last_hidden_state"
+
+
+class RBLNModelForImageClassification(_RBLNImageModel):
+    auto_model_class = AutoModelForImageClassification
+    output_class = ImageClassifierOutput
+    output_key = "logits"
+
+
+class RBLNModelForDepthEstimation(_RBLNImageModel):
+    auto_model_class = AutoModelForDepthEstimation
+    output_class = DepthEstimatorOutput
+    output_key = "predicted_depth"
 
 
 class RBLNModelForAudioClassification(RBLNModel):
@@ -186,219 +273,45 @@ class RBLNModelForAudioClassification(RBLNModel):
     """
 
     auto_model_class = AutoModelForAudioClassification
+    output_class = SequenceClassifierOutput
+    output_key = "logits"
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors: "AutoFeatureExtractor",
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_max_length = rbln_kwargs.get("max_length", None)
-        rbln_num_mel_bins = rbln_kwargs.get("num_mel_bins", None)
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        if rbln_num_mel_bins is None:
-            rbln_num_mel_bins = getattr(model_config, "num_mel_bins", None)
-            if rbln_num_mel_bins is None:
+        preprocessors: "AutoFeatureExtractor" = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "PretrainedConfig" = None,
+        rbln_config: Optional[RBLNModelForAudioClassificationConfig] = None,
+    ) -> RBLNModelForAudioClassificationConfig:
+        if rbln_config.num_mel_bins is None:
+            rbln_config.num_mel_bins = getattr(model_config, "num_mel_bins", None)
+            if rbln_config.num_mel_bins is None:
                 for feature_extractor in preprocessors:
                     if hasattr(feature_extractor, "num_mel_bins"):
-                        rbln_num_mel_bins = feature_extractor.num_mel_bins
+                        rbln_config.num_mel_bins = feature_extractor.num_mel_bins
                         break
 
-        if rbln_num_mel_bins is None:
-            raise ValueError("`rbln_num_mel_bins` should be specified!")
+        if rbln_config.num_mel_bins is None:
+            raise ValueError("`num_mel_bins` should be specified!")
 
-        if rbln_max_length is None:
-            rbln_max_length = getattr(model_config, "max_length", None)
+        if rbln_config.max_length is None:
+            rbln_config.max_length = getattr(model_config, "max_length", None)
             for feature_extractor in preprocessors:
                 if hasattr(feature_extractor, "max_length"):
-                    rbln_max_length = feature_extractor.max_length
+                    rbln_config.max_length = feature_extractor.max_length
                     break
 
-        if rbln_max_length is None:
-            raise ValueError("`rbln_max_length` should be specified!")
+        if rbln_config.max_length is None:
+            raise ValueError("`max_length` should be specified!")
 
         input_info = [
             (
                 "input_values",
-                [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
+                [rbln_config.batch_size, rbln_config.max_length, rbln_config.num_mel_bins],
                 "float32",
             ),
         ]
 
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "max_length": rbln_max_length,
-                "num_mel_bins": rbln_num_mel_bins,
-            }
-        )
-        return rbln_config
-
-
-class RBLNModelForSequenceClassification(RBLNModel):
-    """
-    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method
-    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    A class to convert and run pre-trained transformers based SequenceClassification models on RBLN devices.
-    It implements the methods to convert a pre-trained transformers SequenceClassification model into a RBLN transformer model by:
-    - transferring the checkpoint weights of the original into an optimized RBLN graph,
-    - compiling the resulting graph using the RBLN compiler.
-
-    Currently, this model class supports the 'XLMRoberta' and 'Roberta' model from the transformers library. Future updates may include support for additional model types.
-    """
-
-    auto_model_class = AutoModelForSequenceClassification
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-            model_config, "max_position_embeddings", None
-        )
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as SequenceClassification forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
-        return rbln_config
-
-
-class RBLNModelForMaskedLM(RBLNModel):
-    auto_model_class = AutoModelForMaskedLM
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-            model_config, "max_position_embeddings", None
-        )
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as MaskedLM forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
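
The diff above replaces the untyped `rbln_kwargs` dictionary with typed configuration objects (`_RBLNTransformerEncoderConfig`, `_RBLNImageModelConfig`, `RBLNModelForAudioClassificationConfig`) that `_update_rbln_config` fills in place before building the static `input_info` passed to `RBLNCompileConfig`. The following is a minimal, standalone sketch of that pattern, not optimum-rbln code: the field names (`batch_size`, `max_seq_len`, `model_input_names`) mirror the diff, while the `EncoderConfig` dataclass and the `resolve_input_info` helper are hypothetical stand-ins.

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class EncoderConfig:
    # Hypothetical stand-in for the diff's _RBLNTransformerEncoderConfig.
    batch_size: int = 1
    max_seq_len: Optional[int] = None
    model_input_names: Optional[List[str]] = None


def resolve_input_info(config: EncoderConfig, tokenizer) -> List[Tuple[str, List[int], str]]:
    # Fill missing fields on the typed config in place, the way _update_rbln_config
    # does in the diff, instead of reading loose values out of an rbln_kwargs dict.
    if config.max_seq_len is None:
        config.max_seq_len = getattr(tokenizer, "model_max_length", None)
    if config.max_seq_len is None:
        raise ValueError("`max_seq_len` should be specified!")
    if config.model_input_names is None:
        config.model_input_names = list(getattr(tokenizer, "model_input_names", []))

    # One (name, static shape, dtype) triple per model input, the same layout the
    # diff passes to RBLNCompileConfig(input_info=...).
    return [
        (name, [config.batch_size, config.max_seq_len], "int64")
        for name in config.model_input_names
    ]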