optimum-rbln 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,106 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
"""RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
|
24
|
+
|
25
|
+
|
26
|
+
from diffusers import StableDiffusionPipeline
|
27
|
+
|
28
|
+
from ....modeling_base import RBLNBaseModel
|
29
|
+
from ....transformers import RBLNCLIPTextModel
|
30
|
+
from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
|
31
|
+
|
32
|
+
|
33
|
+
class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
    """Stable Diffusion text-to-image pipeline whose VAE, text encoder and UNet can be compiled for RBLN devices.

    Inherits from [`StableDiffusionPipeline`]; see the superclass documentation for the generic pipeline
    methods (downloading, saving, running on a particular device, etc.).
    """

    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a Stable Diffusion pipeline and, when `export` is truthy, convert it into an RBLN pipeline.

        Conversion works by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*, passed via kwargs):
                If falsy (the default), the pipeline loaded by the superclass `from_pretrained` is returned
                unchanged. Otherwise its `vae`, `text_encoder` and `unet` are replaced by compiled RBLN models.
            **kwargs:
                Forwarded to the superclass `from_pretrained`. The options extracted by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every RBLN sub-model loader.
                `rbln_batch_size` (default 1) is used as-is for the VAE and text encoder. The UNet is compiled
                with twice that batch size.

        Returns:
            The pipeline instance, with RBLN sub-models installed when exported.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        # Not exporting: behave exactly like the plain diffusers loader.
        if not export:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # compile model, create runtime
        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            rbln_use_encode=False,  # text-to-image only needs the decoder
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only after the VAE and text encoder were compiled, so they still receive any
        # `rbln_batch_size` present in `rbln_config_kwargs`. The UNet batch is doubled,
        # presumably for classifier-free guidance (unconditional + conditional) — TODO confirm.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2
        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            rbln_max_seq_len=text_encoder.config.max_position_embeddings,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=False,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # replace modules
        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet

        # update config to be able to load from file.
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
        }
        model.register_to_config(**update_dict)

        # vae decoder, text_encoder, unet (first entry of each RBLN wrapper's `model` sequence)
        model.models = [vae.model[0], text_encoder.model[0], unet.model[0]]

        return model
|
@@ -0,0 +1,116 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
"""RBLNStableDiffusionImg2ImgPipeline class for inference of diffusion models on rbln devices."""
|
24
|
+
|
25
|
+
from pathlib import Path
|
26
|
+
from tempfile import TemporaryDirectory
|
27
|
+
|
28
|
+
from diffusers import StableDiffusionImg2ImgPipeline
|
29
|
+
|
30
|
+
from ....modeling_base import RBLNBaseModel
|
31
|
+
from ....transformers import RBLNCLIPTextModel
|
32
|
+
from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
|
33
|
+
|
34
|
+
|
35
|
+
class RBLNStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
    """Stable Diffusion image-to-image pipeline whose VAE, text encoder and UNet can be compiled for RBLN devices.

    Inherits from [`StableDiffusionImg2ImgPipeline`]; see the superclass documentation for the generic pipeline
    methods (downloading, saving, running on a particular device, etc.).
    """

    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a Stable Diffusion img2img pipeline and, when `export` is truthy, convert it into an RBLN pipeline.

        Conversion works by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*, passed via kwargs):
                If falsy (the default), the pipeline loaded by the superclass `from_pretrained` is returned
                unchanged. Otherwise its `vae`, `text_encoder` and `unet` are replaced by compiled RBLN models.
            **kwargs:
                Forwarded to the superclass `from_pretrained`. The options extracted by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every RBLN sub-model loader.
                `rbln_batch_size` (default 1) is used as-is for the VAE and text encoder. The UNet is compiled
                with twice that batch size.

        Returns:
            The pipeline instance, with RBLN sub-models installed when exported.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        # Not exporting: behave exactly like the plain diffusers loader.
        if not export:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # NOTE: the sub-models below are compiled straight from `model_id`. The whole pipeline used to be
        # serialized into an unused temporary directory here as well; that wasted work has been removed.

        # compile model, create runtime
        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            # encoder enabled too, presumably to turn the input image into latents — TODO confirm
            rbln_use_encode=True,
            rbln_vae_scale_factor=model.vae_scale_factor,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only after the VAE and text encoder were compiled, so they still receive any
        # `rbln_batch_size` present in `rbln_config_kwargs`. The UNet batch is doubled,
        # presumably for classifier-free guidance (unconditional + conditional) — TODO confirm.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2

        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            rbln_max_seq_len=text_encoder.config.max_position_embeddings,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=True,
            rbln_vae_scale_factor=model.vae_scale_factor,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # replace modules
        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet

        # update config to be able to load from file.
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
        }
        model.register_to_config(**update_dict)

        # vae encoder, vae decoder, text_encoder, unet
        model.models = [vae.model[0], vae.model[1], text_encoder.model[0], unet.model[0]]

        return model
|
@@ -0,0 +1,109 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
"""RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
|
16
|
+
|
17
|
+
|
18
|
+
from diffusers import StableDiffusionXLPipeline
|
19
|
+
|
20
|
+
from ....modeling_base import RBLNBaseModel
|
21
|
+
from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
22
|
+
from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
|
23
|
+
|
24
|
+
|
25
|
+
class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    """Stable Diffusion XL text-to-image pipeline whose VAE, both text encoders and UNet can be compiled for RBLN devices.

    Inherits from [`StableDiffusionXLPipeline`]; see the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
    """

    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a StableDiffusionXL pipeline and, when `export` is truthy, convert it into an RBLN pipeline.

        Conversion works by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*, passed via kwargs):
                If falsy (the default), the pipeline loaded by the superclass `from_pretrained` is returned
                unchanged. Otherwise `vae`, `text_encoder`, `text_encoder_2` and `unet` are replaced by
                compiled RBLN models.
            **kwargs:
                Forwarded to the superclass `from_pretrained`. The options extracted by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every RBLN sub-model loader,
                except for the following two:
                - `rbln_guidance_scale` (default 5.0) is consumed here and not forwarded.
                - `rbln_batch_size` (default 1) is forwarded to the VAE and text encoders but not to the UNet.
                  The UNet batch is doubled when the guidance scale is > 1.0 and
                  `unet.config.time_cond_proj_dim` is None.

        Returns:
            The pipeline instance, with RBLN sub-models installed when exported.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        # Not exporting: behave exactly like the plain diffusers loader.
        if not export:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # The guidance scale is popped before any sub-model is compiled; it only decides the UNet batch size.
        do_classifier_free_guidance = (
            rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
        )

        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            rbln_use_encode=False,  # text-to-image only needs the decoder
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder_2",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only after the VAE and text encoders were compiled, so they still receive any
        # `rbln_batch_size` present in `rbln_config_kwargs`; the UNet gets its own batch size.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size

        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            # `model.text_encoder*` still refer to the original encoders here (replaced below).
            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=False,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet
        model.text_encoder_2 = text_encoder_2

        # Update config to be able to load from file. `text_encoder_2` was compiled as
        # RBLNCLIPTextModelWithProjection, so it must be registered under that class (not RBLNCLIPTextModel).
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
            "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModelWithProjection"),
        }
        model.register_to_config(**update_dict)

        # vae decoder, unet, text_encoder, text_encoder_2
        model.models = [vae.model[0], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]

        return model
|
@@ -0,0 +1,111 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
"""RBLNStableDiffusionXLImg2ImgPipeline class for inference of diffusion models on rbln devices."""
|
16
|
+
|
17
|
+
|
18
|
+
from diffusers import StableDiffusionXLImg2ImgPipeline
|
19
|
+
|
20
|
+
from ....modeling_base import RBLNBaseModel
|
21
|
+
from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
22
|
+
from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
|
23
|
+
|
24
|
+
|
25
|
+
class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
    """Stable Diffusion XL image-to-image pipeline whose VAE, both text encoders and UNet can be compiled for RBLN devices.

    Inherits from [`StableDiffusionXLImg2ImgPipeline`]; see the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
    """

    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load a StableDiffusionXL img2img pipeline and, when `export` is truthy, convert it into an RBLN pipeline.

        Conversion works by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

        Args:
            model_id (`Union[str, Path]`):
                Can be either:
                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`].
            export (`bool`, *optional*, passed via kwargs):
                If falsy (the default), the pipeline loaded by the superclass `from_pretrained` is returned
                unchanged. Otherwise `vae`, `text_encoder`, `text_encoder_2` and `unet` are replaced by
                compiled RBLN models.
            **kwargs:
                Forwarded to the superclass `from_pretrained`. The options extracted by
                `RBLNBaseModel.pop_rbln_kwargs_from_kwargs` are also forwarded to every RBLN sub-model loader,
                except for the following two:
                - `rbln_guidance_scale` (default 5.0) is consumed here and not forwarded.
                - `rbln_batch_size` (default 1) is forwarded to the VAE and text encoders but not to the UNet.
                  The UNet batch is doubled when the guidance scale is > 1.0 and
                  `unet.config.time_cond_proj_dim` is None.

        Returns:
            The pipeline instance, with RBLN sub-models installed when exported.
        """
        export = kwargs.pop("export", None)
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
        # Not exporting: behave exactly like the plain diffusers loader.
        if not export:
            return model

        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

        # The guidance scale is popped before any sub-model is compiled; it only decides the UNet batch size.
        do_classifier_free_guidance = (
            rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
        )

        vae = RBLNAutoencoderKL.from_pretrained(
            model_id=model_id,
            subfolder="vae",
            export=True,
            rbln_unet_sample_size=model.unet.config.sample_size,
            # encoder enabled too, presumably to turn the input image into latents — TODO confirm
            rbln_use_encode=True,
            rbln_vae_scale_factor=model.vae_scale_factor,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder = RBLNCLIPTextModel.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )
        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
            model_id=model_id,
            subfolder="text_encoder_2",
            export=True,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        # Popped only after the VAE and text encoders were compiled, so they still receive any
        # `rbln_batch_size` present in `rbln_config_kwargs`; the UNet gets its own batch size.
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size

        unet = RBLNUNet2DConditionModel.from_pretrained(
            model_id=model_id,
            subfolder="unet",
            export=True,
            # `model.text_encoder*` still refer to the original encoders here (replaced below).
            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
            rbln_batch_size=unet_batch_size,
            rbln_use_encode=True,
            rbln_vae_scale_factor=model.vae_scale_factor,
            rbln_is_controlnet="controlnet" in model.config,
            **rbln_config_kwargs,
            **rbln_constructor_kwargs,
        )

        model.vae = vae
        model.text_encoder = text_encoder
        model.unet = unet
        model.text_encoder_2 = text_encoder_2

        # Update config to be able to load from file. `text_encoder_2` was compiled as
        # RBLNCLIPTextModelWithProjection, so it must be registered under that class (not RBLNCLIPTextModel).
        update_dict = {
            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
            "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModelWithProjection"),
        }
        model.register_to_config(**update_dict)

        # vae encoder, vae decoder, unet, text_encoder, text_encoder_2
        model.models = [vae.model[0], vae.model[1], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]

        return model
|
optimum/rbln/modeling.py
ADDED
File without changes
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_base import (
|
25
|
+
RBLNModelForAudioClassification,
|
26
|
+
RBLNModelForImageClassification,
|
27
|
+
RBLNModelForQuestionAnswering,
|
28
|
+
)
|
29
|
+
from .modeling_seq2seq import RBLNModelForSeq2SeqLM
|
30
|
+
|
31
|
+
|
32
|
+
class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
    """RBLN audio-classification alias for AST checkpoints; all behavior is inherited from
    `RBLNModelForAudioClassification`."""
|
34
|
+
|
35
|
+
|
36
|
+
class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
    """RBLN question-answering alias for BERT checkpoints; all behavior is inherited from
    `RBLNModelForQuestionAnswering`."""
|
38
|
+
|
39
|
+
|
40
|
+
class RBLNResNetForImageClassification(RBLNModelForImageClassification):
    """RBLN image-classification alias for ResNet checkpoints; all behavior is inherited from
    `RBLNModelForImageClassification`."""
|
42
|
+
|
43
|
+
|
44
|
+
class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """RBLN seq2seq alias for T5 checkpoints; all behavior is inherited from `RBLNModelForSeq2SeqLM`."""
|
46
|
+
|
47
|
+
|
48
|
+
class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """RBLN seq2seq alias for BART checkpoints; all behavior is inherited from `RBLNModelForSeq2SeqLM`."""
|