optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -18,19 +18,14 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForSpeechSeq2Seq,
-    AutoProcessor,
-    PretrainedConfig,
-    WhisperForConditionalGeneration,
-    WhisperModel,
-)
+from transformers import AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration, WhisperModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
 from .generation_whisper import RBLNWhisperGenerationMixin
 from .whisper_architecture import WhisperWrapper

@@ -38,7 +33,14 @@ from .whisper_architecture import WhisperWrapper
 logger = get_logger(__name__)

 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, GenerationConfig, PretrainedConfig, PreTrainedModel
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        GenerationConfig,
+        PretrainedConfig,
+        PreTrainedModel,
+    )


 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -117,10 +119,10 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-        self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
-        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+        self.batch_size = self.rbln_config.batch_size
+        self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.rbln_token_timestamps = self.rbln_config.token_timestamps
+        self.use_attention_mask = self.rbln_config.use_attention_mask

         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
         self.decoder = RBLNRuntimeDecoder(
@@ -169,16 +171,16 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         raise NotImplementedError

     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
-        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
         return WhisperWrapper(
-            model, use_attention_mask=use_attention_mask, rbln_token_timestamps=rbln_token_timestamps
+            model,
+            use_attention_mask=rbln_config.use_attention_mask,
+            rbln_token_timestamps=rbln_config.token_timestamps,
         )

     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         enc_compile_config = rbln_config.compile_cfgs[0]
@@ -218,32 +220,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}

     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNWhisperForConditionalGenerationConfig] = None,
+    ) -> RBLNWhisperForConditionalGenerationConfig:
         expected_seq_len = model_config.max_source_positions * 2
         num_mel_bins = model_config.num_mel_bins
-        enc_max_seq_len = model_config.max_source_positions
+        rbln_config.enc_max_seq_len = model_config.max_source_positions

         # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig have default value for the key 'max_length'
-        rbln_dec_max_seq_len = getattr(model_config, "max_target_positions", None)
-        if rbln_dec_max_seq_len is None:
-            rbln_dec_max_seq_len = model_config.max_length
-
-        # use_attention_mask conditions
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-            if rbln_npu == "RBLN-CA02":
-                rbln_use_attention_mask = True
+        rbln_config.dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+        if rbln_config.dec_max_seq_len is None:
+            rbln_config.dec_max_seq_len = model_config.max_length

         enc_input_info = [
             ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
@@ -252,9 +243,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 "cross_key_value_states",
                 [
                     model_config.decoder_layers * 2,
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.decoder_attention_heads,
-                    enc_max_seq_len,
+                    rbln_config.enc_max_seq_len,
                     model_config.d_model // model_config.decoder_attention_heads,
                 ],
                 "float32",
@@ -262,9 +253,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         ]

         dec_input_info = [
-            ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
-            ("cache_position", [rbln_batch_size, 1], "int32"),
-            ("block_tables", [rbln_batch_size, 1], "int16"),
+            ("decoder_input_ids", [rbln_config.batch_size, 1], "int64"),
+            ("cache_position", [rbln_config.batch_size, 1], "int32"),
+            ("block_tables", [rbln_config.batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
             [
@@ -272,9 +263,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 "cross_key_value_states",
                 [
                     model_config.decoder_layers * 2,
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.decoder_attention_heads,
-                    enc_max_seq_len,
+                    rbln_config.enc_max_seq_len,
                     model_config.d_model // model_config.decoder_attention_heads,
                 ],
                 "float32",
@@ -286,9 +277,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 (
                     f"self_key_value_states_{i}",
                     [
-                        rbln_batch_size,
+                        rbln_config.batch_size,
                         model_config.decoder_attention_heads,
-                        rbln_dec_max_seq_len,
+                        rbln_config.dec_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
                     ],
                     "float32",
@@ -297,26 +288,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             ]
         )

-        if rbln_use_attention_mask:
-            dec_input_info.insert(1, ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+        if rbln_config.use_attention_mask:
+            dec_input_info.insert(
+                1, ("decoder_attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+            )

         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "token_timestamps": rbln_token_timestamps,
-                "use_attention_mask": rbln_use_attention_mask,
-            }
-        )
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])

         return rbln_config

@@ -324,18 +304,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNWhisperForConditionalGenerationConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
            cls._raise_missing_compiled_file_error(["encoder", "decoder"])

        return [
-            compiled_models[0].create_runtime(
-                tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
            ),
-            compiled_models[1].create_runtime(
-                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
            ),
        ]

@@ -432,7 +417,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         decoder_output = self.decoder(
             decoder_input_ids=decoder_input_ids.contiguous(),
             decoder_attention_mask=self.decoder_attention_mask,
-            cache_position=torch.zeros([self.rbln_config.model_cfg["batch_size"], 1], dtype=torch.int32),
+            cache_position=torch.zeros([self.rbln_config.batch_size, 1], dtype=torch.int32),
         )
         lm_logits = decoder_output.logits
         self.language_cross = decoder_output.cross_attentions
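
The hunks above replace the untyped rbln_config.model_cfg dictionary with a typed RBLNWhisperForConditionalGenerationConfig whose fields (batch_size, dec_max_seq_len, token_timestamps, use_attention_mask) are read as plain attributes, and _get_rbln_config becomes _update_rbln_config, which fills in the config it receives. Below is a minimal usage sketch of the new config object; the top-level import path and the from_pretrained call shape (export=True, rbln_config=...) are assumptions based on typical optimum-rbln usage, not something this diff shows.

# Hedged sketch only: field names come from the diff above; the import path and
# from_pretrained keyword arguments are assumptions, not part of this diff.
from optimum.rbln import RBLNWhisperForConditionalGeneration, RBLNWhisperForConditionalGenerationConfig

rbln_config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,              # drives the decoder input shapes built in _update_rbln_config
    token_timestamps=False,    # forwarded to WhisperWrapper as rbln_token_timestamps
    use_attention_mask=False,  # when True, decoder_attention_mask is added to dec_input_info
)

model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",     # illustrative checkpoint id
    export=True,               # compile from the Hugging Face model
    rbln_config=rbln_config,
)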
optimum/rbln/transformers/models/xlm_roberta/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .configuration_xlm_roberta import RBLNXLMRobertaModelConfig
 from .modeling_xlm_roberta import RBLNXLMRobertaModel
optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+
+
+class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -12,89 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
-from typing import TYPE_CHECKING, Optional, Union

-from transformers import PretrainedConfig
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction

-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....utils.logging import get_logger

-
-logger = get_logger(__name__)
-
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
-
-
-class RBLNXLMRobertaModel(RBLNModel):
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs={},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-            model_config, "max_position_embeddings", None
-        )
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as XLMRobertaModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
-        return rbln_config
+class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
+    pass
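
With the model-specific _get_rbln_config deleted, shape inference (max_seq_len, model_input_names, batch_size) now lives in the shared RBLNTransformerEncoderForFeatureExtraction base and the new RBLNXLMRobertaModelConfig. The sketch below shows roughly what the equivalent explicit configuration looks like under the new scheme; the constructor fields mirror the rbln_kwargs the removed code consumed, while the import path and call shape are assumed rather than taken from this diff.

# Hedged sketch only: field names mirror the removed rbln_kwargs handling
# (max_seq_len, model_input_names, batch_size); the call shape is an assumption.
from optimum.rbln import RBLNXLMRobertaModel, RBLNXLMRobertaModelConfig

rbln_config = RBLNXLMRobertaModelConfig(
    max_seq_len=512,                                     # was rbln_kwargs["max_seq_len"]
    batch_size=1,                                        # was defaulted to 1 when unset
    model_input_names=["input_ids", "attention_mask"],   # was inferred from the tokenizer
)

model = RBLNXLMRobertaModel.from_pretrained(
    "FacebookAI/xlm-roberta-base",  # illustrative checkpoint id
    export=True,
    rbln_config=rbln_config,
)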
optimum/rbln/utils/submodule.py
@@ -13,10 +13,9 @@
 # limitations under the License.

 import importlib
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Type

-from ..modeling_config import RBLNConfig
+from ..configuration_utils import RBLNModelConfig


 if TYPE_CHECKING:
@@ -35,37 +34,32 @@ class SubModulesMixin:

     _rbln_submodules: List[Dict[str, Any]] = []

-    def __init__(
-        self,
-        *,
-        rbln_submodules: List["RBLNBaseModel"] = [],
-        **kwargs,
-    ) -> None:
+    def __init__(self, *, rbln_submodules: List["RBLNBaseModel"] = [], **kwargs) -> None:
         for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
             setattr(self, submodule_meta["name"], submodule)

     @classmethod
     def _export_submodules_from_model(
-        cls,
-        model: "PreTrainedModel",
-        model_save_dir: str,
-        rbln_kwargs: Dict[str, Any],
-        **kwargs,
+        cls, model: "PreTrainedModel", model_save_dir: str, rbln_config: RBLNModelConfig, **kwargs
     ) -> List["RBLNBaseModel"]:
         rbln_submodules = []
         for submodule in cls._rbln_submodules:
             submodule_name = submodule["name"]
             torch_submodule: "PreTrainedModel" = getattr(model, submodule["name"])
             cls_name = torch_submodule.__class__.__name__
-            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+            submodule_cls: Type["RBLNBaseModel"] = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+            submodule_rbln_config = getattr(rbln_config, submodule_name) or {}

-            if submodule_name in rbln_kwargs:
-                kwargs["rbln_config"] = rbln_kwargs[submodule_name]
+            if isinstance(submodule_rbln_config, dict):
+                submodule_rbln_config_class = submodule_cls.get_rbln_config_class()
+                submodule_rbln_config = submodule_rbln_config_class(**submodule_rbln_config)
+                setattr(rbln_config, submodule_name, submodule_rbln_config)

             rbln_submodule = submodule_cls.from_model(
                 model=torch_submodule,
                 subfolder=submodule_name,
                 model_save_dir=model_save_dir,
+                rbln_config=submodule_rbln_config,
                 **kwargs,
             )

@@ -74,55 +68,44 @@ class SubModulesMixin:
         return rbln_submodules

     @classmethod
-    def _load_submodules_from_compiled_models(
-        cls,
-        model_save_dir: str,
-        rbln_kwargs: Dict[str, Any],
-        **kwargs,
-    ):
+    def _load_submodules_from_compiled_models(cls, model_save_dir: str, rbln_config: RBLNModelConfig, **kwargs):
         rbln_submodules = []
         for submodule in cls._rbln_submodules:
             submodule_name = submodule["name"]

-            if submodule_name in rbln_kwargs:
-                kwargs["rbln_config"] = rbln_kwargs[submodule_name]
-
             # Get cls name for call the constructor of the rbln class
-            submodule_rbln_config = RBLNConfig.load(Path(model_save_dir) / submodule_name)
-            submodule_cls_name = submodule_rbln_config.meta["cls"]
-            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), submodule_cls_name)
+            submodule_rbln_config = getattr(rbln_config, submodule_name)
+
+            # RBLNModelConfig -> RBLNModel
+            submodule_cls: "RBLNBaseModel" = getattr(
+                importlib.import_module("optimum.rbln"), submodule_rbln_config.rbln_model_cls_name
+            )

             rbln_submodule = submodule_cls._from_pretrained(
                 model_id=model_save_dir,
                 config=None,
                 subfolder=submodule_name,
+                rbln_config=submodule_rbln_config,
                 **kwargs,
             )
+
+            # update submodule's rbln_config since it is updated in the from_pretrained method
+            setattr(rbln_config, submodule_name, rbln_submodule.rbln_config)
             rbln_submodules.append(rbln_submodule)
+
         return rbln_submodules

     @classmethod
-    def _load_submodules(
-        cls,
-        model_save_dir,
-        rbln_kwargs,
-        model=None,
-        **kwargs,
-    ):
+    def _load_submodules(cls, model_save_dir, rbln_config: RBLNModelConfig, model=None, **kwargs):
         # Two ways :
         # 1. Compile from pytorch object
         # 2. Load from compiled file
         if model is not None:
             return cls._export_submodules_from_model(
-                model=model,
-                model_save_dir=model_save_dir,
-                rbln_kwargs=rbln_kwargs,
-                **kwargs,
+                model=model, model_save_dir=model_save_dir, rbln_config=rbln_config, **kwargs
             )

         else:
             return cls._load_submodules_from_compiled_models(
-                model_save_dir=model_save_dir,
-                rbln_kwargs=rbln_kwargs,
-                **kwargs,
+                model_save_dir=model_save_dir, rbln_config=rbln_config, **kwargs
             )
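
In the rewritten _export_submodules_from_model, a submodule's slot on the parent RBLNModelConfig can arrive either as a plain dict of user kwargs or as an already-typed config; dicts are promoted through the submodule class's get_rbln_config_class() and written back onto the parent so later loads see the typed object. A self-contained toy sketch of that promotion pattern follows (every class below is a stand-in, not the optimum-rbln API):

# Toy stand-ins illustrating the dict -> typed-config promotion used above.
from dataclasses import dataclass
from typing import Any, Dict, Union


@dataclass
class ChildConfig:          # stand-in for a submodule's RBLNModelConfig subclass
    batch_size: int = 1


@dataclass
class ParentConfig:         # stand-in for the parent RBLNModelConfig
    text_encoder: Union[Dict[str, Any], ChildConfig, None] = None


class FakeSubmodule:        # stand-in for an RBLN submodule class
    @classmethod
    def get_rbln_config_class(cls):
        return ChildConfig


def resolve_submodule_config(parent, name, submodule_cls):
    sub_cfg = getattr(parent, name) or {}
    if isinstance(sub_cfg, dict):                        # promote raw kwargs to the typed config
        sub_cfg = submodule_cls.get_rbln_config_class()(**sub_cfg)
        setattr(parent, name, sub_cfg)                   # write the typed config back to the parent
    return sub_cfg


parent = ParentConfig(text_encoder={"batch_size": 4})
cfg = resolve_submodule_config(parent, "text_encoder", FakeSubmodule)
assert isinstance(parent.text_encoder, ChildConfig) and cfg.batch_size == 4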
{optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.4a4
+Version: 0.7.4a5
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai