optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +1 -1
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
  53. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  54. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  55. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  56. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  57. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +231 -175
  59. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  60. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  63. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  64. optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
  65. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  66. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  67. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  68. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  69. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  70. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +87 -236
  71. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  72. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  73. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  74. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  75. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  76. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  77. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  78. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
  79. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  80. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  81. optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
  82. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  83. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  84. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  85. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  86. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  87. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  88. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  89. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  90. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  91. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  92. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  93. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +46 -25
  94. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
  95. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  96. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  97. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  98. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  99. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  100. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  102. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  104. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  105. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  106. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  107. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  108. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  109. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  110. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  111. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  112. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  114. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  115. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  116. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  117. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  118. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  119. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  120. optimum/rbln/utils/model_utils.py +20 -0
  121. optimum/rbln/utils/submodule.py +6 -8
  122. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
  123. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
  124. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  125. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  126. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
  127. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/vit/__init__.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_vit import RBLNViTForImageClassificationConfig
+ from .modeling_vit import RBLNViTForImageClassification
+
+
+ __all__ = ["RBLNViTForImageClassificationConfig", "RBLNViTForImageClassification"]
optimum/rbln/transformers/models/vit/configuration_vit.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+ class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
+ ""
optimum/rbln/transformers/models/vit/modeling_vit.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...modeling_generic import RBLNModelForImageClassification
+
+
+ class RBLNViTForImageClassification(RBLNModelForImageClassification):
+ ""
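The three new ViT files above wire RBLNViTForImageClassification into the package on top of the generic image-classification base. Below is a minimal usage sketch, not taken from this diff: the checkpoint id and image URL are illustrative assumptions, and the export=True flow mirrors the Whisper example shown later in this diff.

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor

from optimum.rbln import RBLNViTForImageClassification

# Illustrative checkpoint; any ViT image-classification checkpoint should work.
model_id = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(model_id)

# Compile for the RBLN NPU (export=True), or load an already-compiled model.
model = RBLNViTForImageClassification.from_pretrained(model_id, export=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
# The wrapper is assumed to expose `.logits` like its transformers counterpart;
# fall back to the raw output otherwise.
logits = outputs.logits if hasattr(outputs, "logits") else outputs
print("Predicted class index:", int(torch.argmax(logits, dim=-1)))
```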
optimum/rbln/transformers/models/wav2vec2/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
+ from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
  from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -17,7 +17,7 @@ import torch
  from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC

  from ...modeling_generic import RBLNModelForMaskedLM
- from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
+ from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig


  class _Wav2Vec2(torch.nn.Module):
optimum/rbln/transformers/models/whisper/configuration_whisper.py
@@ -12,6 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from typing import Any, Dict
+
  import rebel

  from ....configuration_utils import RBLNModelConfig
@@ -29,7 +31,7 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
  use_attention_mask: bool = None,
  enc_max_seq_len: int = None,
  dec_max_seq_len: int = None,
- **kwargs,
+ **kwargs: Dict[str, Any],
  ):
  """
  Args:
optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -104,13 +104,44 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):

  class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
  """
- The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
- This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+ Whisper model for speech recognition and transcription optimized for RBLN NPU.

- A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
- It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
+ This model inherits from [`RBLNModel`]. It implements the methods to convert and run
+ pre-trained transformers based Whisper model on RBLN devices by:
  - transferring the checkpoint weights of the original into an optimized RBLN graph,
  - compiling the resulting graph using the RBLN compiler.
+
+ Example (Short form):
+ ```python
+ import torch
+ from transformers import AutoProcessor
+ from datasets import load_dataset
+ from optimum.rbln import RBLNWhisperForConditionalGeneration
+
+ # Load processor and dataset
+ model_id = "openai/whisper-tiny"
+ processor = AutoProcessor.from_pretrained(model_id)
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+ # Prepare input features
+ input_features = processor(
+ ds[0]["audio"]["array"],
+ sampling_rate=ds[0]["audio"]["sampling_rate"],
+ return_tensors="pt"
+ ).input_features
+
+ # Load and compile model (or load pre-compiled model)
+ model = RBLNWhisperForConditionalGeneration.from_pretrained(
+ model_id=model_id,
+ export=True,
+ rbln_batch_size=1
+ )
+
+ # Generate transcription
+ outputs = model.generate(input_features=input_features, return_timestamps=True)
+ transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+ print(f"Transcription: {transcription}")
+ ```
  """

  auto_model_class = AutoModelForSpeechSeq2Seq
@@ -153,11 +184,6 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
  return self.decoder

  def __getattr__(self, __name: str) -> Any:
- """This is the key method to implement RBLN-Whisper.
- Returns:
- Any: Whisper's corresponding method
- """
-
  def redirect(func):
  return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

@@ -331,12 +357,6 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
  attention_mask: Optional[torch.Tensor] = None, # need for support transformers>=4.45.0
  **kwargs,
  ):
- """
- whisper don't use attention_mask,
- attention_mask (`torch.Tensor`)`, *optional*):
- Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
- but it is not used. By default the silence in the input log mel spectrogram are ignored.
- """
  return {
  "input_ids": input_ids,
  "cache_position": cache_position,
optimum/rbln/transformers/models/xlm_roberta/__init__.py
@@ -12,5 +12,19 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_xlm_roberta import RBLNXLMRobertaModelConfig
- from .modeling_xlm_roberta import RBLNXLMRobertaModel
+ from .configuration_xlm_roberta import (
+ RBLNXLMRobertaForSequenceClassificationConfig,
+ RBLNXLMRobertaModelConfig,
+ )
+ from .modeling_xlm_roberta import (
+ RBLNXLMRobertaForSequenceClassification,
+ RBLNXLMRobertaModel,
+ )
+
+
+ __all__ = [
+ "RBLNXLMRobertaModelConfig",
+ "RBLNXLMRobertaForSequenceClassificationConfig",
+ "RBLNXLMRobertaModel",
+ "RBLNXLMRobertaForSequenceClassification",
+ ]
optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py
@@ -12,8 +12,21 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+ from ...configuration_generic import (
+ RBLNModelForSequenceClassificationConfig,
+ RBLNTransformerEncoderForFeatureExtractionConfig,
+ )


  class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
- pass
+ """
+ Configuration class for XLM-RoBERTa model.
+ Inherits from RBLNTransformerEncoderForFeatureExtractionConfig with no additional parameters.
+ """
+
+
+ class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
+ """
+ Configuration class for XLM-RoBERTa sequence classification model.
+ Inherits from RBLNModelForSequenceClassificationConfig with no additional parameters.
+ """
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -12,9 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

-
- from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
+ from ...modeling_generic import RBLNModelForSequenceClassification, RBLNTransformerEncoderForFeatureExtraction


  class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
- pass
+ """
+ XLM-RoBERTa base model optimized for RBLN NPU.
+ """
+
+
+ class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+ """
+ XLM-RoBERTa model for sequence classification tasks optimized for RBLN NPU.
+ """
+
+ rbln_model_input_names = ["input_ids", "attention_mask"]
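With the configuration and modeling additions above, XLM-RoBERTa gains a sequence-classification wrapper next to the existing feature-extraction model. A minimal usage sketch follows; the checkpoint id and padding length are illustrative assumptions, and only the declared rbln_model_input_names are passed at inference time.

```python
from transformers import AutoTokenizer

from optimum.rbln import RBLNXLMRobertaForSequenceClassification

# Illustrative checkpoint; any XLM-RoBERTa sequence-classification checkpoint should work.
model_id = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile for the RBLN NPU (export=True), or load an already-compiled model.
model = RBLNXLMRobertaForSequenceClassification.from_pretrained(model_id, export=True)

# The class declares rbln_model_input_names = ["input_ids", "attention_mask"].
# Compiled graphs use static shapes, so inputs are padded to a fixed length here
# (512 is an assumption; match the sequence length the model was compiled with).
inputs = tokenizer("Great library!", return_tensors="pt", padding="max_length", max_length=512)
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(outputs)
```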
optimum/rbln/utils/model_utils.py
@@ -12,10 +12,20 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import importlib
+ from typing import TYPE_CHECKING, Type
+
+
+ if TYPE_CHECKING:
+ from ..modeling import RBLNModel
+
  # Prefix used for RBLN model class names
  RBLN_PREFIX = "RBLN"


+ MODEL_MAPPING = {}
+
+
  def convert_hf_to_rbln_model_name(hf_model_name: str):
  """
  Convert HuggingFace model name to RBLN model name.
@@ -41,3 +51,13 @@ def convert_rbln_to_hf_model_name(rbln_model_name: str):
  """

  return rbln_model_name.removeprefix(RBLN_PREFIX)
+
+
+ def get_rbln_model_cls(cls_name: str) -> Type["RBLNModel"]:
+ cls = getattr(importlib.import_module("optimum.rbln"), cls_name, None)
+ if cls is None:
+ if cls_name in MODEL_MAPPING:
+ cls = MODEL_MAPPING[cls_name]
+ else:
+ raise AttributeError(f"RBLNModel for {cls_name} not found.")
+ return cls
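The new get_rbln_model_cls helper centralizes the class lookup that submodule.py previously did with importlib by hand: it first checks the optimum.rbln top-level namespace, then the new MODEL_MAPPING registry, and raises AttributeError otherwise. A small sketch of that behavior:

```python
from optimum.rbln.utils.model_utils import get_rbln_model_cls

# Resolves a wrapper class exported from optimum.rbln's top-level namespace.
whisper_cls = get_rbln_model_cls("RBLNWhisperForConditionalGeneration")
print(whisper_cls.__name__)  # RBLNWhisperForConditionalGeneration

# Names that are neither top-level exports nor registered in MODEL_MAPPING
# raise AttributeError.
try:
    get_rbln_model_cls("RBLNNotARealModel")
except AttributeError as err:
    print(err)  # RBLNModel for RBLNNotARealModel not found.
```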
optimum/rbln/utils/submodule.py
@@ -12,19 +12,19 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import importlib
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, Dict, List, Type

  from transformers import PretrainedConfig

  from ..configuration_utils import RBLNModelConfig
+ from ..utils.model_utils import get_rbln_model_cls


  if TYPE_CHECKING:
  from transformers import PreTrainedModel

- from ..modeling_base import RBLNBaseModel
+ from ..modeling import RBLNModel


  class SubModulesMixin:
@@ -37,7 +37,7 @@ class SubModulesMixin:

  _rbln_submodules: List[Dict[str, Any]] = []

- def __init__(self, *, rbln_submodules: List["RBLNBaseModel"] = [], **kwargs) -> None:
+ def __init__(self, *, rbln_submodules: List["RBLNModel"] = [], **kwargs) -> None:
  for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
  setattr(self, submodule_meta["name"], submodule)

@@ -48,7 +48,7 @@ class SubModulesMixin:
  @classmethod
  def _export_submodules_from_model(
  cls, model: "PreTrainedModel", model_save_dir: str, rbln_config: RBLNModelConfig, **kwargs
- ) -> List["RBLNBaseModel"]:
+ ) -> List["RBLNModel"]:
  rbln_submodules = []
  submodule_prefix = getattr(cls, "_rbln_submodule_prefix", None)

@@ -61,7 +61,7 @@ class SubModulesMixin:
  torch_submodule: PreTrainedModel = getattr(model, submodule_name)

  cls_name = torch_submodule.__class__.__name__
- submodule_cls: Type["RBLNBaseModel"] = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+ submodule_cls: Type["RBLNModel"] = get_rbln_model_cls(f"RBLN{cls_name}")
  submodule_rbln_config = getattr(rbln_config, submodule_name) or {}

  if isinstance(submodule_rbln_config, dict):
@@ -95,9 +95,7 @@ class SubModulesMixin:
  submodule_rbln_config = getattr(rbln_config, submodule_name)

  # RBLNModelConfig -> RBLNModel
- submodule_cls: "RBLNBaseModel" = getattr(
- importlib.import_module("optimum.rbln"), submodule_rbln_config.rbln_model_cls_name
- )
+ submodule_cls = get_rbln_model_cls(submodule_rbln_config.rbln_model_cls_name)

  json_file_path = Path(model_save_dir) / submodule_name / "config.json"
  config = PretrainedConfig.from_json_file(json_file_path)
{optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.8.0.post2
+ Version: 0.8.1a1
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai