optimum-rbln 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/__init__.py
ADDED
@@ -0,0 +1,115 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
    "modeling_alias": [
        "RBLNASTForAudioClassification",
        "RBLNBertForQuestionAnswering",
        "RBLNResNetForImageClassification",
        "RBLNT5ForConditionalGeneration",
        "RBLNBartForConditionalGeneration",
    ],
    "modeling_base": [
        "RBLNBaseModel",
        "RBLNModel",
        "RBLNModelForQuestionAnswering",
        "RBLNModelForAudioClassification",
        "RBLNModelForImageClassification",
    ],
    "modeling_seq2seq": [
        "RBLNModelForSeq2SeqLM",
    ],
    "transformers": [
        "BatchTextIteratorStreamer",
        "RBLNCLIPTextModel",
        "RBLNCLIPTextModelWithProjection",
        "RBLNGPT2LMHeadModel",
        "RBLNWav2Vec2ForCTC",
        "RBLNLlamaForCausalLM",
        "RBLNWhisperForConditionalGeneration",
    ],
    "diffusers": [
        "RBLNStableDiffusionPipeline",
        "RBLNStableDiffusionXLPipeline",
        "RBLNAutoencoderKL",
        "RBLNUNet2DConditionModel",
        "RBLNControlNetModel",
        "RBLNStableDiffusionImg2ImgPipeline",
        "RBLNStableDiffusionControlNetImg2ImgPipeline",
        "RBLNMultiControlNetModel",
        "RBLNStableDiffusionXLImg2ImgPipeline",
    ],
    "modeling_config": ["RBLNRuntimeConfig", "RBLNConfig"],
}

if TYPE_CHECKING:
    from .diffusers import (
        RBLNAutoencoderKL,
        RBLNControlNetModel,
        RBLNMultiControlNetModel,
        RBLNStableDiffusionControlNetImg2ImgPipeline,
        RBLNStableDiffusionImg2ImgPipeline,
        RBLNStableDiffusionPipeline,
        RBLNStableDiffusionXLImg2ImgPipeline,
        RBLNStableDiffusionXLPipeline,
        RBLNUNet2DConditionModel,
    )
    from .modeling_alias import (
        RBLNASTForAudioClassification,
        RBLNBartForConditionalGeneration,
        RBLNBertForQuestionAnswering,
        RBLNResNetForImageClassification,
        RBLNT5ForConditionalGeneration,
    )
    from .modeling_base import (
        RBLNBaseModel,
        RBLNModel,
        RBLNModelForAudioClassification,
        RBLNModelForImageClassification,
        RBLNModelForQuestionAnswering,
    )
    from .modeling_config import RBLNConfig, RBLNRuntimeConfig
    from .modeling_seq2seq import RBLNModelForSeq2SeqLM
    from .transformers import (
        BatchTextIteratorStreamer,
        RBLNCLIPTextModel,
        RBLNCLIPTextModelWithProjection,
        RBLNGPT2LMHeadModel,
        RBLNLlamaForCausalLM,
        RBLNWav2Vec2ForCTC,
        RBLNWhisperForConditionalGeneration,
    )
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
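Because the package installs itself through transformers' _LazyModule, none of the heavy submodules are imported until an attribute is first accessed. A minimal sketch of the resulting import behavior (the import lines below are illustrative usage, not part of the package):

import optimum.rbln  # cheap: only the lazy-module shim is installed

# First attribute access triggers the real import of optimum.rbln.transformers,
# since RBLNLlamaForCausalLM is declared under "transformers" in _import_structure.
from optimum.rbln import RBLNLlamaForCausalLM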
optimum/rbln/__version__.py
ADDED
@@ -0,0 +1 @@
__version__ = '0.1.0'
optimum/rbln/diffusers/__init__.py
ADDED
@@ -0,0 +1,64 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from typing import TYPE_CHECKING

from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_CLASSES
from transformers.utils import _LazyModule


LOADABLE_CLASSES["optimum.rbln"] = {"RBLNBaseModel": ["save_pretrained", "from_pretrained"]}
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])


_import_structure = {
    "pipelines": [
        "RBLNStableDiffusionPipeline",
        "RBLNStableDiffusionXLPipeline",
        "RBLNStableDiffusionImg2ImgPipeline",
        "RBLNStableDiffusionControlNetImg2ImgPipeline",
        "RBLNMultiControlNetModel",
        "RBLNStableDiffusionXLImg2ImgPipeline",
    ],
    "models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
}

if TYPE_CHECKING:
    from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
    from .pipelines import (
        RBLNMultiControlNetModel,
        RBLNStableDiffusionControlNetImg2ImgPipeline,
        RBLNStableDiffusionImg2ImgPipeline,
        RBLNStableDiffusionPipeline,
        RBLNStableDiffusionXLImg2ImgPipeline,
        RBLNStableDiffusionXLPipeline,
    )
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
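The two module-level statements near the top of this file are what let stock diffusers pipelines carry RBLN-compiled components: diffusers consults LOADABLE_CLASSES and ALL_IMPORTABLE_CLASSES when serializing or restoring sub-models, so registering RBLNBaseModel under the "optimum.rbln" key tells it to call that class's save_pretrained/from_pretrained. A quick sanity-check sketch of the registration's effect:

from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_CLASSES

import optimum.rbln.diffusers  # importing the package executes the registration above

# diffusers now knows which save/load methods to use for RBLN sub-models.
assert LOADABLE_CLASSES["optimum.rbln"] == {"RBLNBaseModel": ["save_pretrained", "from_pretrained"]}
assert "RBLNBaseModel" in ALL_IMPORTABLE_CLASSES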
optimum/rbln/diffusers/models/__init__.py
ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from .autoencoder_kl import RBLNAutoencoderKL
from .controlnet import RBLNControlNetModel
from .unet_2d_condition import RBLNUNet2DConditionModel
optimum/rbln/diffusers/models/autoencoder_kl.py
ADDED
@@ -0,0 +1,313 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import rebel
import torch
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from optimum.exporters import TasksManager
from transformers import AutoConfig, AutoModel, PretrainedConfig

from ...modeling_base import RBLNModel
from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
from ...utils.runtime_utils import RBLNPytorchRuntime
from ...utils.save_utils import maybe_save_preprocessors


logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    import torch
    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        moments = self.forward(x.contiguous())
        posterior = DiagonalGaussianDistribution(moments)
        return AutoencoderKLOutput(latent_dist=posterior)


class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        return (self.forward(z),)


class RBLNAutoencoderKL(RBLNModel):
    model_type = "rbln_model"
    config_name = "config.json"
    auto_model_class = AutoModel  # feature extraction

    def __post_init__(self, **kwargs):
        self.dtype = torch.float32

        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]

        # img2img needs both halves of the VAE; text2img ships only the decoder.
        if self.rbln_use_encode:
            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.runtimes[0], main_input_name="x")
            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.runtimes[1], main_input_name="z")
        else:
            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.runtimes[0], main_input_name="z")

    @classmethod
    @torch.no_grad()
    def _export(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "RBLNAutoencoderKL":
        task = kwargs.pop("task", None)
        if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)

        model: AutoencoderKL = TasksManager.get_model_from_task(
            task=None,
            model_name_or_path=model_id,
            subfolder=subfolder,
            revision=revision,
            framework="pt",
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        if config is None:
            config = model.config

        if not isinstance(config, PretrainedConfig):  # diffusers config
            config = PretrainedConfig(**config)

        config.save_pretrained(save_dir_path)
        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        # Get compilation arguments
        if rbln_config_kwargs.get("rbln_config", None) is None:
            rbln_config = cls.get_rbln_config(
                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
            )

        def compile_img2img():
            encoder_model = _VAEEncoder(model)
            decoder_model = _VAEDecoder(model)
            encoder_model.eval()
            decoder_model.eval()

            enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])

            enc_compiled_model.save(save_dir_path / f"{rbln_config['encoder'][0].compiled_model_name}.rbln")
            dec_compiled_model.save(save_dir_path / f"{rbln_config['decoder'][0].compiled_model_name}.rbln")

        def compile_text2img():
            decoder_model = _VAEDecoder(model)
            decoder_model.eval()

            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])

            dec_compiled_model.save(save_dir_path / f"{rbln_config['compiled_model'][0].compiled_model_name}.rbln")

        if rbln_config_kwargs.get("rbln_use_encode"):
            compile_img2img()
        else:
            compile_text2img()

        rbln_config.save(save_dir_path)

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            **rbln_constructor_kwargs,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        def get_model_from_task(
            task: str,
            model_name_or_path: Union[str, Path],
            **kwargs,
        ):
            return AutoencoderKL.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)

        # Temporarily patch the generic loaders so optimum's loading path
        # resolves to diffusers' AutoencoderKL instead of a transformers AutoModel.
        tasktmp = TasksManager.get_model_from_task
        configtmp = AutoConfig.from_pretrained
        modeltmp = AutoModel.from_pretrained
        TasksManager.get_model_from_task = get_model_from_task

        if kwargs.get("export", None):
            # Ad-hoc workaround to avoid saving null config values: when exporting,
            # plain optimum (not optimum-rbln) loads the config via AutoConfig,
            # and diffusers models cannot be loaded through AutoConfig.
            AutoConfig.from_pretrained = lambda *args, **kwargs: None
        else:
            AutoConfig.from_pretrained = AutoencoderKL.load_config

        AutoModel.from_pretrained = AutoencoderKL.from_pretrained
        rt = super().from_pretrained(*args, **kwargs)
        AutoConfig.from_pretrained = configtmp
        AutoModel.from_pretrained = modeltmp
        TasksManager.get_model_from_task = tasktmp
        return rt

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "PretrainedConfig",
        rbln_unet_sample_size: Optional[int] = None,
        rbln_img_width: Optional[int] = None,
        rbln_img_height: Optional[int] = None,
        rbln_batch_size: Optional[int] = None,
        rbln_use_encode: Optional[bool] = None,
        rbln_vae_scale_factor: Optional[int] = None,
    ) -> RBLNConfig:
        meta = {}
        if rbln_batch_size is None:
            rbln_batch_size = 1

        meta["rbln_use_encode"] = rbln_use_encode
        meta["rbln_batch_size"] = rbln_batch_size

        if rbln_use_encode:
            meta["rbln_img_width"] = rbln_img_width
            meta["rbln_img_height"] = rbln_img_height

            vae_enc_input_info = [
                ("x", [rbln_batch_size, model_config.in_channels, rbln_img_width, rbln_img_height], "float32")
            ]
            vae_dec_input_info = [
                (
                    "z",
                    [
                        rbln_batch_size,
                        model_config.latent_channels,
                        rbln_img_width // rbln_vae_scale_factor,
                        rbln_img_height // rbln_vae_scale_factor,
                    ],
                    "float32",
                )
            ]

            enc_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
            dec_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)

            rbln_config = RBLNConfig.from_rbln_runtime_configs(
                [enc_rbln_runtime_config, dec_rbln_runtime_config],
                _rbln_meta=meta,
            )
            return rbln_config

        if rbln_unet_sample_size is None:
            rbln_unet_sample_size = 64

        meta["rbln_unet_sample_size"] = rbln_unet_sample_size
        vae_config = RBLNRuntimeConfig(
            input_info=[
                (
                    "z",
                    [
                        rbln_batch_size,
                        model_config.latent_channels,
                        rbln_unet_sample_size,
                        rbln_unet_sample_size,
                    ],
                    "float32",
                )
            ],
        )
        rbln_config = RBLNConfig.from_rbln_runtime_configs([vae_config], _rbln_meta=meta)
        return rbln_config

    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
        if len(self.compiled_models) == 1:
            device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
            return [self.compiled_models[0].create_runtime(tensor_type="pt", device=device_val)]

        device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
        return [
            compiled_model.create_runtime(tensor_type="pt", device=device_val)
            for compiled_model, device_val in zip(self.compiled_models, device_vals)
        ]

    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        posterior = self.encoder.encode(x)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        return self.decoder.decode(z)


class _VAEDecoder(torch.nn.Module):
    def __init__(self, vae: "AutoencoderKL"):
        super().__init__()
        self.vae = vae

    def forward(self, z):
        vae_out = self.vae.decode(z, return_dict=False)
        return vae_out


class _VAEEncoder(torch.nn.Module):
    def __init__(self, vae: "AutoencoderKL"):
        super().__init__()
        self.vae = vae

    # A trace-friendly copy of AutoencoderKL.encode that returns the raw
    # moments tensor instead of wrapping it in a distribution object.
    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        return moments

    def forward(self, x):
        # Call encode() unbound with the wrapped AutoencoderKL as `self`,
        # so it reuses the VAE's own tiling/slicing attributes and submodules.
        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
        return vae_out
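Tying the class together: _export() compiles either a decoder-only graph (text2img) or an encoder/decoder pair (img2img), while _get_rbln_config() bakes static tensor shapes into the compiled model. A hypothetical end-to-end call, assuming the usual optimum export=True entry point and a diffusers checkpoint with its VAE under a "vae" subfolder; the checkpoint name and shapes are illustrative only, and the rbln_* kwargs are assumed to be routed to _get_rbln_config via pop_rbln_kwargs_from_kwargs:

import torch

from optimum.rbln import RBLNAutoencoderKL

vae = RBLNAutoencoderKL.from_pretrained(
    "some-org/stable-diffusion-checkpoint",  # hypothetical model ID
    subfolder="vae",
    export=True,              # compile from PyTorch weights via _export()
    rbln_use_encode=True,     # img2img: compile both encoder and decoder
    rbln_img_width=512,       # static input resolution baked into the graph
    rbln_img_height=512,
    rbln_vae_scale_factor=8,  # latent size = pixel size // scale factor
    rbln_batch_size=1,
)

# Shapes must match the input_info fixed at compile time.
latents = vae.encode(torch.randn(1, 3, 512, 512)).latent_dist.sample()
image = vae.decode(latents)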