optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
- optimum/rbln/diffusers/modeling_diffusers.py +16 -26
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
- optimum/rbln/diffusers/models/controlnet.py +13 -7
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +33 -35
- optimum/rbln/modeling_base.py +45 -107
- optimum/rbln/transformers/__init__.py +39 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +18 -19
- optimum/rbln/transformers/modeling_rope_utils.py +1 -1
- optimum/rbln/transformers/models/__init__.py +46 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +231 -175
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +87 -236
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +46 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
- optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/vit/__init__.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_vit import RBLNViTForImageClassificationConfig
+from .modeling_vit import RBLNViTForImageClassification
+
+
+__all__ = ["RBLNViTForImageClassificationConfig", "RBLNViTForImageClassification"]
optimum/rbln/transformers/models/vit/configuration_vit.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
+    ""
optimum/rbln/transformers/models/vit/modeling_vit.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...modeling_generic import RBLNModelForImageClassification
+
+
+class RBLNViTForImageClassification(RBLNModelForImageClassification):
+    ""
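The two new ViT files above only alias the generic image-classification wrappers, so usage follows the generic RBLN flow. A minimal sketch, assuming the `from_pretrained(..., export=True)` convention shown in the Whisper docstring later in this diff also applies here; the checkpoint name, the input image, and the `.logits`/`config.id2label` access are illustrative and not taken from this diff.

```python
from PIL import Image
from transformers import AutoImageProcessor
from optimum.rbln import RBLNViTForImageClassification

# Illustrative checkpoint; any ViT image-classification checkpoint should work.
model_id = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(model_id)

# Compile for the RBLN NPU on first load (export=True), or load a pre-compiled model.
model = RBLNViTForImageClassification.from_pretrained(model_id, export=True)

inputs = processor(images=Image.open("cat.png"), return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
print(model.config.id2label[logits.argmax(-1).item()])
```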
optimum/rbln/transformers/models/wav2vec2/__init__.py
CHANGED
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .
+from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
 from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
CHANGED
@@ -17,7 +17,7 @@ import torch
 from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
 
 from ...modeling_generic import RBLNModelForMaskedLM
-from .
+from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
 
 
 class _Wav2Vec2(torch.nn.Module):
optimum/rbln/transformers/models/whisper/configuration_whisper.py
CHANGED
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict
+
 import rebel
 
 from ....configuration_utils import RBLNModelConfig
@@ -29,7 +31,7 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
         use_attention_mask: bool = None,
         enc_max_seq_len: int = None,
         dec_max_seq_len: int = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
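The `enc_max_seq_len`, `dec_max_seq_len`, and `use_attention_mask` arguments visible in this hunk are compile-time settings of the Whisper config class. A hedged sketch of how a caller would pass them, assuming they follow the same `rbln_`-prefixed keyword convention that `rbln_batch_size` uses in the Whisper docstring below:

```python
from optimum.rbln import RBLNWhisperForConditionalGeneration

# Assumption: each config argument maps onto an rbln_-prefixed from_pretrained kwarg.
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    export=True,
    rbln_batch_size=1,
    rbln_dec_max_seq_len=128,  # illustrative value
)
```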
optimum/rbln/transformers/models/whisper/modeling_whisper.py
CHANGED
@@ -104,13 +104,44 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
 
 class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
-
-    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    Whisper model for speech recognition and transcription optimized for RBLN NPU.
 
-
-
+    This model inherits from [`RBLNModel`]. It implements the methods to convert and run
+    pre-trained transformers based Whisper model on RBLN devices by:
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
+
+    Example (Short form):
+    ```python
+    import torch
+    from transformers import AutoProcessor
+    from datasets import load_dataset
+    from optimum.rbln import RBLNWhisperForConditionalGeneration
+
+    # Load processor and dataset
+    model_id = "openai/whisper-tiny"
+    processor = AutoProcessor.from_pretrained(model_id)
+    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+    # Prepare input features
+    input_features = processor(
+        ds[0]["audio"]["array"],
+        sampling_rate=ds[0]["audio"]["sampling_rate"],
+        return_tensors="pt"
+    ).input_features
+
+    # Load and compile model (or load pre-compiled model)
+    model = RBLNWhisperForConditionalGeneration.from_pretrained(
+        model_id=model_id,
+        export=True,
+        rbln_batch_size=1
+    )
+
+    # Generate transcription
+    outputs = model.generate(input_features=input_features, return_timestamps=True)
+    transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+    print(f"Transcription: {transcription}")
+    ```
     """
 
     auto_model_class = AutoModelForSpeechSeq2Seq
@@ -153,11 +184,6 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         return self.decoder
 
     def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Whisper.
-        Returns:
-            Any: Whisper's corresponding method
-        """
-
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
 
@@ -331,12 +357,6 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         attention_mask: Optional[torch.Tensor] = None, # need for support transformers>=4.45.0
         **kwargs,
     ):
-        """
-        whisper don't use attention_mask,
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
-            but it is not used. By default the silence in the input log mel spectrogram are ignored.
-        """
         return {
             "input_ids": input_ids,
             "cache_position": cache_position,
optimum/rbln/transformers/models/xlm_roberta/__init__.py
CHANGED
@@ -12,5 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_xlm_roberta import
-
+from .configuration_xlm_roberta import (
+    RBLNXLMRobertaForSequenceClassificationConfig,
+    RBLNXLMRobertaModelConfig,
+)
+from .modeling_xlm_roberta import (
+    RBLNXLMRobertaForSequenceClassification,
+    RBLNXLMRobertaModel,
+)
+
+
+__all__ = [
+    "RBLNXLMRobertaModelConfig",
+    "RBLNXLMRobertaForSequenceClassificationConfig",
+    "RBLNXLMRobertaModel",
+    "RBLNXLMRobertaForSequenceClassification",
+]
optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py
CHANGED
@@ -12,8 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ...configuration_generic import
+from ...configuration_generic import (
+    RBLNModelForSequenceClassificationConfig,
+    RBLNTransformerEncoderForFeatureExtractionConfig,
+)
 
 
 class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
-
+    """
+    Configuration class for XLM-RoBERTa model.
+    Inherits from RBLNTransformerEncoderForFeatureExtractionConfig with no additional parameters.
+    """
+
+
+class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
+    """
+    Configuration class for XLM-RoBERTa sequence classification model.
+    Inherits from RBLNModelForSequenceClassificationConfig with no additional parameters.
+    """
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
CHANGED
@@ -12,9 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
+from ...modeling_generic import RBLNModelForSequenceClassification, RBLNTransformerEncoderForFeatureExtraction
 
 
 class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
-
+    """
+    XLM-RoBERTa base model optimized for RBLN NPU.
+    """
+
+
+class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    """
+    XLM-RoBERTa model for sequence classification tasks optimized for RBLN NPU.
+    """
+
+    rbln_model_input_names = ["input_ids", "attention_mask"]
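`rbln_model_input_names` above pins the compiled graph to `input_ids` and `attention_mask`. A minimal usage sketch, assuming the generic `export=True` flow from the Whisper docstring earlier in this diff; the checkpoint and the padded sequence length are illustrative.

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNXLMRobertaForSequenceClassification

model_id = "cardiffnlp/twitter-xlm-roberta-base-sentiment"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNXLMRobertaForSequenceClassification.from_pretrained(model_id, export=True)

# The compiled graph has static shapes, so pad to the sequence length it was built with.
inputs = tokenizer(
    "optimum-rbln now wraps XLM-RoBERTa directly.",
    return_tensors="pt",
    padding="max_length",
    max_length=512,  # assumed compile-time sequence length
)
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
```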
optimum/rbln/utils/model_utils.py
CHANGED
@@ -12,10 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
+from typing import TYPE_CHECKING, Type
+
+
+if TYPE_CHECKING:
+    from ..modeling import RBLNModel
+
 # Prefix used for RBLN model class names
 RBLN_PREFIX = "RBLN"
 
 
+MODEL_MAPPING = {}
+
+
 def convert_hf_to_rbln_model_name(hf_model_name: str):
     """
     Convert HuggingFace model name to RBLN model name.
@@ -41,3 +51,13 @@ def convert_rbln_to_hf_model_name(rbln_model_name: str):
     """
 
     return rbln_model_name.removeprefix(RBLN_PREFIX)
+
+
+def get_rbln_model_cls(cls_name: str) -> Type["RBLNModel"]:
+    cls = getattr(importlib.import_module("optimum.rbln"), cls_name, None)
+    if cls is None:
+        if cls_name in MODEL_MAPPING:
+            cls = MODEL_MAPPING[cls_name]
+        else:
+            raise AttributeError(f"RBLNModel for {cls_name} not found.")
+    return cls
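The new `get_rbln_model_cls` helper looks the name up in the `optimum.rbln` namespace first and only then falls back to `MODEL_MAPPING`. A small sketch of the lookup, pairing it with the existing `convert_hf_to_rbln_model_name`; the conversion result shown in the comment is inferred from that function's docstring and `RBLN_PREFIX`:

```python
from optimum.rbln.utils.model_utils import (
    convert_hf_to_rbln_model_name,
    get_rbln_model_cls,
)

# "WhisperForConditionalGeneration" -> "RBLNWhisperForConditionalGeneration"
rbln_name = convert_hf_to_rbln_model_name("WhisperForConditionalGeneration")

# Resolves to the class exported from optimum.rbln; raises AttributeError if the
# name is neither exported nor registered in MODEL_MAPPING.
cls = get_rbln_model_cls(rbln_name)
print(cls.__name__)
```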
optimum/rbln/utils/submodule.py
CHANGED
@@ -12,19 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import importlib
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Type
 
 from transformers import PretrainedConfig
 
 from ..configuration_utils import RBLNModelConfig
+from ..utils.model_utils import get_rbln_model_cls
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
 
-    from ..
+    from ..modeling import RBLNModel
 
 
 class SubModulesMixin:
@@ -37,7 +37,7 @@ class SubModulesMixin:
 
     _rbln_submodules: List[Dict[str, Any]] = []
 
-    def __init__(self, *, rbln_submodules: List["
+    def __init__(self, *, rbln_submodules: List["RBLNModel"] = [], **kwargs) -> None:
         for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
             setattr(self, submodule_meta["name"], submodule)
 
@@ -48,7 +48,7 @@ class SubModulesMixin:
     @classmethod
     def _export_submodules_from_model(
         cls, model: "PreTrainedModel", model_save_dir: str, rbln_config: RBLNModelConfig, **kwargs
-    ) -> List["
+    ) -> List["RBLNModel"]:
         rbln_submodules = []
         submodule_prefix = getattr(cls, "_rbln_submodule_prefix", None)
 
@@ -61,7 +61,7 @@ class SubModulesMixin:
             torch_submodule: PreTrainedModel = getattr(model, submodule_name)
 
             cls_name = torch_submodule.__class__.__name__
-            submodule_cls: Type["
+            submodule_cls: Type["RBLNModel"] = get_rbln_model_cls(f"RBLN{cls_name}")
             submodule_rbln_config = getattr(rbln_config, submodule_name) or {}
 
             if isinstance(submodule_rbln_config, dict):
@@ -95,9 +95,7 @@ class SubModulesMixin:
             submodule_rbln_config = getattr(rbln_config, submodule_name)
 
             # RBLNModelConfig -> RBLNModel
-            submodule_cls
-                importlib.import_module("optimum.rbln"), submodule_rbln_config.rbln_model_cls_name
-            )
+            submodule_cls = get_rbln_model_cls(submodule_rbln_config.rbln_model_cls_name)
 
             json_file_path = Path(model_save_dir) / submodule_name / "config.json"
             config = PretrainedConfig.from_json_file(json_file_path)
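Since both the submodule export path and the config-based reload path above now resolve classes through `get_rbln_model_cls`, a wrapper class that is not exported from the `optimum.rbln` namespace can presumably be made discoverable via the `MODEL_MAPPING` fallback added in `model_utils.py`. A hedged sketch with a hypothetical class name:

```python
from optimum.rbln.utils.model_utils import MODEL_MAPPING, get_rbln_model_cls


class RBLNMyCustomModel:  # hypothetical stand-in for a real RBLNModel subclass
    pass


# Names that getattr(optimum.rbln, ...) cannot resolve fall back to MODEL_MAPPING.
MODEL_MAPPING["RBLNMyCustomModel"] = RBLNMyCustomModel
assert get_rbln_model_cls("RBLNMyCustomModel") is RBLNMyCustomModel
```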
{optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.8.
+Version: 0.8.1a1
 Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai