optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,106 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+ """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
24
+
25
+
26
+ from diffusers import StableDiffusionPipeline
27
+
28
+ from ....modeling_base import RBLNBaseModel
29
+ from ....transformers import RBLNCLIPTextModel
30
+ from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
31
+
32
+
33
class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a Stable Diffusion text-to-image pipeline, optionally compiling it for RBLN devices.

        This class inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic
        methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

        When export is requested, the VAE, the CLIP text encoder and the UNet of the loaded pipeline are replaced
        by RBLN models built by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*):
                Passed through `kwargs`. If `None` (the default) or `False`, the plain diffusers pipeline is
                returned without any compilation.
            rbln_guidance_scale (`float`, *optional*, defaults to 7.5):
                Passed through `kwargs`. Guidance scale the pipeline will be run with. The UNet is compiled for
                twice the batch size only when classifier-free guidance is active (scale > 1.0 and the UNet config
                has no `time_cond_proj_dim`). Consumed here; it is not forwarded to the sub-model compilations.
            rbln_batch_size (`int`, *optional*):
                Passed through `kwargs`. Forwarded to the VAE and text encoder compilations. The UNet batch size
                is derived from it (treated as 1 when omitted).
            kwargs:
                All keyword arguments other than `export` are forwarded to
                [`StableDiffusionPipeline.from_pretrained`]. When exporting, the RBLN options split out by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every sub-model compilation.

        Returns:
            The loaded pipeline, with RBLN sub-models swapped in when exporting.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        if export is None or export is False:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # The UNet is compiled for a specific batch size, which must match what diffusers feeds it at run time:
        # the unconditional + conditional batch under classifier-free guidance, the plain batch otherwise.
        # This mirrors the SDXL pipelines. The option is popped first so it never reaches a sub-model compiler.
        do_classifier_free_guidance = (
            rbln_config_kwargs.pop("rbln_guidance_scale", 7.5) > 1.0 and model.unet.config.time_cond_proj_dim is None
        )

        # compile model, create runtime
        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            # presumably skips the VAE encoder, which text-to-image does not need -- TODO confirm
            rbln_use_encode=False,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only now: the VAE and text encoder above were compiled with the user's batch size,
        # while the UNet gets its own (possibly doubled) batch size below.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            rbln_max_seq_len=text_encoder.config.max_position_embeddings,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=False,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # replace modules
        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet

        # update config to be able to load from file (records the RBLN classes for each component).
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
        }
        model.register_to_config(**update_dict)

        # vae, text_encoder, unet
        model.models = [vae.model[0], text_encoder.model[0], unet.model[0]]

        return model
@@ -0,0 +1,116 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+ """RBLNStableDiffusionImg2ImgPipeline class for inference of diffusion models on rbln devices."""
24
+
25
+ from pathlib import Path
26
+ from tempfile import TemporaryDirectory
27
+
28
+ from diffusers import StableDiffusionImg2ImgPipeline
29
+
30
+ from ....modeling_base import RBLNBaseModel
31
+ from ....transformers import RBLNCLIPTextModel
32
+ from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
33
+
34
+
35
class RBLNStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a Stable Diffusion image-to-image pipeline, optionally compiling it for RBLN devices.

        This class inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the
        generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

        When export is requested, the VAE (encoder and decoder), the CLIP text encoder and the UNet of the loaded
        pipeline are replaced by RBLN models built by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*):
                Passed through `kwargs`. If `None` (the default) or `False`, the plain diffusers pipeline is
                returned without any compilation.
            rbln_guidance_scale (`float`, *optional*, defaults to 7.5):
                Passed through `kwargs`. Guidance scale the pipeline will be run with. The UNet is compiled for
                twice the batch size only when classifier-free guidance is active (scale > 1.0 and the UNet config
                has no `time_cond_proj_dim`). Consumed here; it is not forwarded to the sub-model compilations.
            rbln_batch_size (`int`, *optional*):
                Passed through `kwargs`. Forwarded to the VAE and text encoder compilations. The UNet batch size
                is derived from it (treated as 1 when omitted).
            kwargs:
                All keyword arguments other than `export` are forwarded to
                [`StableDiffusionImg2ImgPipeline.from_pretrained`]. When exporting, the RBLN options split out by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every sub-model compilation.

        Returns:
            The loaded pipeline, with RBLN sub-models swapped in when exporting.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        if export is None or export is False:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # NOTE: an earlier version wrote the whole pipeline to a TemporaryDirectory here and never read it back.
        # Every sub-model below is compiled straight from `model_id`, so that copy was pure wasted disk I/O.

        # The UNet is compiled for a specific batch size, which must match what diffusers feeds it at run time:
        # the unconditional + conditional batch under classifier-free guidance, the plain batch otherwise.
        # This mirrors the SDXL pipelines. The option is popped first so it never reaches a sub-model compiler.
        do_classifier_free_guidance = (
            rbln_config_kwargs.pop("rbln_guidance_scale", 7.5) > 1.0 and model.unet.config.time_cond_proj_dim is None
        )

        # compile model, create runtime
        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            rbln_use_encode=True,  # img2img encodes the input image, so the VAE encoder is compiled too
            rbln_vae_scale_factor=model.vae_scale_factor,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only now: the VAE and text encoder above were compiled with the user's batch size,
        # while the UNet gets its own (possibly doubled) batch size below.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size

        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            rbln_max_seq_len=text_encoder.config.max_position_embeddings,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=True,
            rbln_vae_scale_factor=model.vae_scale_factor,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # replace modules
        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet

        # update config to be able to load from file (records the RBLN classes for each component).
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
        }
        model.register_to_config(**update_dict)

        # vae encoder, vae decoder, text_encoder, unet
        model.models = [vae.model[0], vae.model[1], text_encoder.model[0], unet.model[0]]

        return model
@@ -0,0 +1,2 @@
1
+ from .pipeline_stable_diffusion_xl import RBLNStableDiffusionXLPipeline
2
+ from .pipeline_stable_diffusion_xl_img2img import RBLNStableDiffusionXLImg2ImgPipeline
@@ -0,0 +1,109 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
16
+
17
+
18
+ from diffusers import StableDiffusionXLPipeline
19
+
20
+ from ....modeling_base import RBLNBaseModel
21
+ from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
22
+ from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
23
+
24
+
25
class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a Stable Diffusion XL text-to-image pipeline, optionally compiling it for RBLN devices.

        This class inherits from [`StableDiffusionXLPipeline`]. Check the superclass documentation for the generic
        methods the library implements for all the pipelines (such as downloading or saving, running on a particular
        device, etc.)

        When export is requested, the VAE, both CLIP text encoders and the UNet of the loaded pipeline are replaced
        by RBLN models built by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*):
                Passed through `kwargs`. If `None` (the default) or `False`, the plain diffusers pipeline is
                returned without any compilation.
            rbln_guidance_scale (`float`, *optional*, defaults to 5.0):
                Passed through `kwargs`. Guidance scale the pipeline will be run with. The UNet is compiled for
                twice the batch size only when classifier-free guidance is active (scale > 1.0 and the UNet config
                has no `time_cond_proj_dim`). Consumed here; it is not forwarded to the sub-model compilations.
            rbln_batch_size (`int`, *optional*):
                Passed through `kwargs`. Forwarded to the VAE and text encoder compilations. The UNet batch size
                is derived from it (treated as 1 when omitted).
            kwargs:
                All keyword arguments other than `export` are forwarded to
                [`StableDiffusionXLPipeline.from_pretrained`]. When exporting, the RBLN options split out by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every sub-model compilation.

        Returns:
            The loaded pipeline, with RBLN sub-models swapped in when exporting.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        if export is None or export is False:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # Decides whether the UNet must be compiled for the doubled (unconditional + conditional) batch.
        # Popped before any compilation so the option is not forwarded to the sub-models.
        do_classifier_free_guidance = (
            rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
        )

        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            rbln_use_encode=False,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder_2",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only now: the VAE and text encoders above were compiled with the user's batch size,
        # while the UNet gets its own (possibly doubled) batch size below.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size

        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            # model.text_encoder / model.text_encoder_2 are still the modules loaded by the superclass here;
            # they are replaced only below.
            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=False,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet
        model.text_encoder_2 = text_encoder_2

        # Update config to be able to load from file. text_encoder_2 was compiled as
        # RBLNCLIPTextModelWithProjection and must be registered as such; otherwise a reload of the saved pipeline
        # would wrap it in the plain RBLNCLIPTextModel class.
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
            "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModelWithProjection"),
        }
        model.register_to_config(**update_dict)

        # vae, unet, text_encoder, text_encoder_2
        model.models = [vae.model[0], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]

        return model
@@ -0,0 +1,111 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RBLNStableDiffusionXLImg2ImgPipeline class for inference of diffusion models on rbln devices."""
16
+
17
+
18
+ from diffusers import StableDiffusionXLImg2ImgPipeline
19
+
20
+ from ....modeling_base import RBLNBaseModel
21
+ from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
22
+ from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
23
+
24
+
25
class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a Stable Diffusion XL image-to-image pipeline, optionally compiling it for RBLN devices.

        This class inherits from [`StableDiffusionXLImg2ImgPipeline`]. Check the superclass documentation for the
        generic methods the library implements for all the pipelines (such as downloading or saving, running on a
        particular device, etc.)

        When export is requested, the VAE (encoder and decoder), both CLIP text encoders and the UNet of the loaded
        pipeline are replaced by RBLN models built by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*):
                Passed through `kwargs`. If `None` (the default) or `False`, the plain diffusers pipeline is
                returned without any compilation.
            rbln_guidance_scale (`float`, *optional*, defaults to 5.0):
                Passed through `kwargs`. Guidance scale the pipeline will be run with. The UNet is compiled for
                twice the batch size only when classifier-free guidance is active (scale > 1.0 and the UNet config
                has no `time_cond_proj_dim`). Consumed here; it is not forwarded to the sub-model compilations.
            rbln_batch_size (`int`, *optional*):
                Passed through `kwargs`. Forwarded to the VAE and text encoder compilations. The UNet batch size
                is derived from it (treated as 1 when omitted).
            kwargs:
                All keyword arguments other than `export` are forwarded to
                [`StableDiffusionXLImg2ImgPipeline.from_pretrained`]. When exporting, the RBLN options split out by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every sub-model compilation.

        Returns:
            The loaded pipeline, with RBLN sub-models swapped in when exporting.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        if export is None or export is False:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # Decides whether the UNet must be compiled for the doubled (unconditional + conditional) batch.
        # Popped before any compilation so the option is not forwarded to the sub-models.
        do_classifier_free_guidance = (
            rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
        )

        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            rbln_use_encode=True,  # img2img encodes the input image, so the VAE encoder is compiled too
            rbln_vae_scale_factor=model.vae_scale_factor,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder_2",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only now: the VAE and text encoders above were compiled with the user's batch size,
        # while the UNet gets its own (possibly doubled) batch size below.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size

        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            # model.text_encoder / model.text_encoder_2 are still the modules loaded by the superclass here;
            # they are replaced only below.
            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=True,
            rbln_vae_scale_factor=model.vae_scale_factor,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet
        model.text_encoder_2 = text_encoder_2

        # Update config to be able to load from file. text_encoder_2 was compiled as
        # RBLNCLIPTextModelWithProjection and must be registered as such; otherwise a reload of the saved pipeline
        # would wrap it in the plain RBLNCLIPTextModel class.
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
            "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModelWithProjection"),
        }
        model.register_to_config(**update_dict)

        # vae encoder, vae decoder, unet, text_encoder, text_encoder_2
        model.models = [vae.model[0], vae.model[1], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]

        return model
File without changes
@@ -0,0 +1,49 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .modeling_base import (
25
+ RBLNModelForAudioClassification,
26
+ RBLNModelForImageClassification,
27
+ RBLNModelForQuestionAnswering,
28
+ )
29
+ from .modeling_seq2seq import RBLNModelForSeq2SeqLM
30
+
31
+
32
class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
    """AST-named alias of `RBLNModelForAudioClassification`; it adds no behavior of its own."""
34
+
35
+
36
class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
    """BERT-named alias of `RBLNModelForQuestionAnswering`; it adds no behavior of its own."""
38
+
39
+
40
class RBLNResNetForImageClassification(RBLNModelForImageClassification):
    """ResNet-named alias of `RBLNModelForImageClassification`; it adds no behavior of its own."""
42
+
43
+
44
class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """T5-named alias of `RBLNModelForSeq2SeqLM`; it adds no behavior of its own."""
46
+
47
+
48
class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """BART-named alias of `RBLNModelForSeq2SeqLM`; it adds no behavior of its own."""