optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +164 -36
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +107 -78
- optimum/rbln/transformers/__init__.py +87 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +108 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
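The structural change running through this release range is the removal of the dict-based `optimum/rbln/modeling_config.py` (`RBLNConfig` / `rbln_kwargs`) in favor of typed, per-model configuration classes built on the new `optimum/rbln/configuration_utils.py` (`RBLNModelConfig`, `RBLNCompileConfig`) and the `configurations/` packages added under `diffusers` and `transformers`. Below is a minimal sketch of what the new flow might look like for the VAE, based only on the attribute names visible in the diff that follows; the constructor signature and the `rbln_config=` keyword are assumptions, not verified against the released API.

```python
# Hypothetical usage sketch inferred from attribute names in this diff; it assumes
# RBLNAutoencoderKLConfig accepts its attributes as keyword arguments and that
# from_pretrained() forwards an `rbln_config` object to the new config machinery.
from optimum.rbln import RBLNAutoencoderKL
from optimum.rbln.diffusers.configurations import RBLNAutoencoderKLConfig

vae_config = RBLNAutoencoderKLConfig(
    batch_size=1,            # previously passed through rbln_kwargs / RBLNConfig
    sample_size=(512, 512),  # replaces the img_height / img_width style settings
    uses_encoder=True,       # compile both the encoder and decoder graphs
)

vae = RBLNAutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",  # illustrative model id
    export=True,
    rbln_config=vae_config,
)
```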
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py

@@ -12,73 +12,72 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 import rebel
-import torch
+import torch
 from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DecoderOutput
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from transformers import PretrainedConfig
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...
+from ...configurations import RBLNAutoencoderKLConfig
 from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
 
 
 if TYPE_CHECKING:
     import torch
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 logger = get_logger(__name__)
 
 
 class RBLNAutoencoderKL(RBLNModel):
     auto_model_class = AutoencoderKL
-    config_name = "config.json"
     hf_library_name = "diffusers"
+    _rbln_config_class = RBLNAutoencoderKLConfig
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        if self.rbln_config.
+        if self.rbln_config.uses_encoder:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
-            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
-            self.
+            self.encoder = None
 
-        self.
+        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
+        self.image_size = self.rbln_config.image_size
 
     @classmethod
-    def get_compiled_model(cls, model, rbln_config:
-
-
-
-
-        decoder_model.eval()
-
-        enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
-        dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
+    def get_compiled_model(cls, model, rbln_config: RBLNAutoencoderKLConfig) -> Dict[str, rebel.RBLNCompiledModel]:
+        if rbln_config.uses_encoder:
+            expected_models = ["encoder", "decoder"]
+        else:
+            expected_models = ["decoder"]
 
-
+        compiled_models = {}
+        for i, model_name in enumerate(expected_models):
+            if model_name == "encoder":
+                wrapped_model = _VAEEncoder(model)
+            else:
+                wrapped_model = _VAEDecoder(model)
 
-
-        decoder_model = _VAEDecoder(model)
-        decoder_model.eval()
+            wrapped_model.eval()
 
-
+            compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
 
-
-
-        if rbln_config.model_cfg.get("img2img_pipeline") or rbln_config.model_cfg.get("inpaint_pipeline"):
-            return compile_img2img()
-        else:
-            return compile_text2img()
+        return compiled_models
 
     @classmethod
-    def get_vae_sample_size(
-
+    def get_vae_sample_size(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: RBLNAutoencoderKLConfig, return_vae_scale_factor: bool = False
+    ) -> Tuple[int, int]:
+        sample_size = rbln_config.sample_size
         noise_module = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
         vae_scale_factor = (
             pipe.vae_scale_factor
@@ -91,139 +90,121 @@ class RBLNAutoencoderKL(RBLNModel):
                 "Cannot find noise processing or predicting module attributes. ex. U-Net, Transformer, ..."
             )
 
-        if
-
+        if sample_size is None:
+            sample_size = noise_module.config.sample_size
+            if isinstance(sample_size, int):
+                sample_size = (sample_size, sample_size)
+        sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
 
-
-
-            sample_size = noise_module.config.sample_size
-        elif rbln_config["inpaint_pipeline"]:
-            sample_size = noise_module.config.sample_size * vae_scale_factor
-        else:
-            # In case of text2img, sample size of vae decoder is determined by unet.
-            noise_module_sample_size = noise_module.config.sample_size
-            if isinstance(noise_module_sample_size, int):
-                sample_size = noise_module_sample_size * vae_scale_factor
-            else:
-                sample_size = (
-                    noise_module_sample_size[0] * vae_scale_factor,
-                    noise_module_sample_size[1] * vae_scale_factor,
-                )
+        if return_vae_scale_factor:
+            return sample_size, vae_scale_factor
         else:
-            sample_size
-
-        return sample_size
+            return sample_size
 
     @classmethod
-    def update_rbln_config_using_pipe(
-
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
+            pipe, rbln_config.vae, return_vae_scale_factor=True
+        )
         return rbln_config
 
     @classmethod
-    def
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-
-    ) ->
-
-
-        is_img2img = rbln_kwargs.get("img2img_pipeline")
-        is_inpaint = rbln_kwargs.get("inpaint_pipeline")
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        if sample_size is None:
-            sample_size = model_config.sample_size
+        rbln_config: RBLNAutoencoderKLConfig,
+    ) -> RBLNAutoencoderKLConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size
 
-        if isinstance(sample_size, int):
-            sample_size = (sample_size, sample_size)
+        if isinstance(rbln_config.sample_size, int):
+            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)
 
-
+        if rbln_config.in_channels is None:
+            rbln_config.in_channels = model_config.in_channels
 
-        if
-
-        else:
-            # vae image processor default value 8 (int)
-            vae_scale_factor = 8
+        if rbln_config.latent_channels is None:
+            rbln_config.latent_channels = model_config.latent_channels
 
-
-
+        if rbln_config.vae_scale_factor is None:
+            if hasattr(model_config, "block_out_channels"):
+                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+            else:
+                # vae image processor default value 8 (int)
+                rbln_config.vae_scale_factor = 8
 
-
+        compile_cfgs = []
+        if rbln_config.uses_encoder:
             vae_enc_input_info = [
                 (
                     "x",
-                    [
+                    [
+                        rbln_config.batch_size,
+                        rbln_config.in_channels,
+                        rbln_config.sample_size[0],
+                        rbln_config.sample_size[1],
+                    ],
                     "float32",
                 )
             ]
-
-
-
-
-
-
-
-
-
-
-
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=compile_cfgs,
-            rbln_kwargs=rbln_kwargs,
+            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+        vae_dec_input_info = [
+            (
+                "z",
+                [
+                    rbln_config.batch_size,
+                    rbln_config.latent_channels,
+                    rbln_config.latent_sample_size[0],
+                    rbln_config.latent_sample_size[1],
+                ],
+                "float32",
            )
-
+        ]
+        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
 
-
-            input_info=[
-                (
-                    "z",
-                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
-                    "float32",
-                )
-            ]
-        )
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[vae_config],
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs(compile_cfgs)
         return rbln_config
 
     @classmethod
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNAutoencoderKLConfig,
     ) -> List[rebel.Runtime]:
         if len(compiled_models) == 1:
-
-
-
-
-
-                compiled_models[0].create_runtime(
-                    tensor_type="pt", device=device_val, activate_profiler=activate_profiler
-                )
-            ]
+            # decoder
+            expected_models = ["decoder"]
+        else:
+            # encoder, decoder
+            expected_models = ["encoder", "decoder"]
 
-        if any(model_name not in
-            cls._raise_missing_compiled_file_error(
+        if any(model_name not in rbln_config.device_map for model_name in expected_models):
+            cls._raise_missing_compiled_file_error(expected_models)
 
-        device_vals = [
+        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
         return [
-
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=device_val,
+                activate_profiler=rbln_config.activate_profiler,
+            )
             for compiled_model, device_val in zip(compiled_models, device_vals)
         ]
 
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
         posterior = self.encoder.encode(x)
+        if not return_dict:
+            return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+        dec = self.decoder.decode(z)
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
optimum/rbln/diffusers/models/autoencoders/vae.py

@@ -14,11 +14,9 @@
 
 from typing import TYPE_CHECKING, List
 
-import torch
+import torch
 from diffusers import AutoencoderKL, VQModel
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
-from diffusers.models.autoencoders.vq_model import VQEncoderOutput
-from diffusers.models.modeling_outputs import AutoencoderKLOutput
 
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
@@ -34,12 +32,12 @@ class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         moments = self.forward(x.contiguous())
         posterior = DiagonalGaussianDistribution(moments)
-        return
+        return posterior
 
 
 class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        return
+        return self.forward(z)
 
 
 class _VAEDecoder(torch.nn.Module):
@@ -78,7 +76,7 @@ class _VAEEncoder(torch.nn.Module):
 class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         h = self.forward(x.contiguous())
-        return
+        return h
 
 
 class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
optimum/rbln/diffusers/models/autoencoders/vq_model.py

@@ -12,24 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, List, Union
 
 import rebel
 import torch
 from diffusers import VQModel
 from diffusers.models.autoencoders.vae import DecoderOutput
 from diffusers.models.autoencoders.vq_model import VQEncoderOutput
-from transformers import PretrainedConfig
 
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...
+from ...configurations.models.configuration_vq_model import RBLNVQModelConfig
+from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
 
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 
 logger = get_logger(__name__)
 
@@ -42,126 +42,125 @@ class RBLNVQModel(RBLNModel):
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-
-
+        if self.rbln_config.uses_encoder:
+            self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
+        else:
+            self.encoder = None
+
+        self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[-1], main_input_name="z")
         self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
-
-        width = self.rbln_config.model_cfg.get("img_width", 512)
-        self.image_size = [height, width]
+        self.image_size = self.rbln_config.image_size
 
     @classmethod
-    def get_compiled_model(cls, model, rbln_config:
-
-
-
-
+    def get_compiled_model(cls, model, rbln_config: RBLNModelConfig):
+        if rbln_config.uses_encoder:
+            expected_models = ["encoder", "decoder"]
+        else:
+            expected_models = ["decoder"]
 
-
-
+        compiled_models = {}
+        for i, model_name in enumerate(expected_models):
+            if model_name == "encoder":
+                wrapped_model = _VQEncoder(model)
+            else:
+                wrapped_model = _VQDecoder(model)
 
-
+            wrapped_model.eval()
 
-
-
-
-        if batch_size is None:
-            batch_size = 1
-        img_height = rbln_config.get("img_height")
-        if img_height is None:
-            img_height = 512
-        img_width = rbln_config.get("img_width")
-        if img_width is None:
-            img_width = 512
-
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "img_height": img_height,
-                "img_width": img_width,
-            }
-        )
+            compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
+
+        return compiled_models
 
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config
 
     @classmethod
-    def
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-
-    ) ->
-        batch_size = rbln_kwargs.get("batch_size")
-        if batch_size is None:
-            batch_size = 1
-
-        height = rbln_kwargs.get("img_height")
-        if height is None:
-            height = 512
-
-        width = rbln_kwargs.get("img_width")
-        if width is None:
-            width = 512
-
+        rbln_config: RBLNVQModelConfig,
+    ) -> RBLNVQModelConfig:
         if hasattr(model_config, "block_out_channels"):
-
+            rbln_config.vqmodel_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
         else:
             # image processor default value 8 (int)
-
-
-
-
+            rbln_config.vqmodel_scale_factor = 8
+
+        compile_cfgs = []
+        if rbln_config.uses_encoder:
+            enc_input_info = [
+                (
+                    "x",
+                    [
+                        rbln_config.batch_size,
+                        model_config.in_channels,
+                        rbln_config.sample_size[0],
+                        rbln_config.sample_size[1],
+                    ],
+                    "float32",
+                )
+            ]
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+            compile_cfgs.append(enc_rbln_compile_config)
 
-        enc_input_info = [
-            (
-                "x",
-                [batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
-                "float32",
-            )
-        ]
         dec_input_info = [
             (
                 "h",
-                [
+                [
+                    rbln_config.batch_size,
+                    model_config.latent_channels,
+                    rbln_config.latent_sample_size[0],
+                    rbln_config.latent_sample_size[1],
+                ],
                 "float32",
             )
         ]
-
-        enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+        compile_cfgs.append(dec_rbln_compile_config)
 
-        compile_cfgs
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=compile_cfgs,
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs(compile_cfgs)
         return rbln_config
 
     @classmethod
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNVQModelConfig,
    ) -> List[rebel.Runtime]:
        if len(compiled_models) == 1:
-
-
-
-
-
-            ]
+            # decoder
+            expected_models = ["decoder"]
+        else:
+            # encoder, decoder
+            expected_models = ["encoder", "decoder"]
 
-
+        if any(model_name not in rbln_config.device_map for model_name in expected_models):
+            cls._raise_missing_compiled_file_error(expected_models)
+
+        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
         return [
-
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=device_val,
+                activate_profiler=rbln_config.activate_profiler,
+            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]
 
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
         posterior = self.encoder.encode(x)
+        if not return_dict:
+            return (posterior,)
         return VQEncoderOutput(latents=posterior)
 
-    def decode(self, h: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def decode(self, h: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
         dec, commit_loss = self.decoder.decode(h, **kwargs)
+        if not return_dict:
+            return (dec, commit_loss)
         return DecoderOutput(sample=dec, commit_loss=commit_loss)