optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,115 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from typing import TYPE_CHECKING
25
+
26
+ from transformers.utils import _LazyModule
27
+
28
+
29
# Lazy-export table handed to ``_LazyModule`` below: each key is a submodule of
# this package and each value lists the public names that submodule provides.
_import_structure = {
    # Architecture-specific aliases.
    "modeling_alias": [
        "RBLNASTForAudioClassification",
        "RBLNBertForQuestionAnswering",
        "RBLNResNetForImageClassification",
        "RBLNT5ForConditionalGeneration",
        "RBLNBartForConditionalGeneration",
    ],
    # Generic base classes.
    "modeling_base": [
        "RBLNBaseModel",
        "RBLNModel",
        "RBLNModelForQuestionAnswering",
        "RBLNModelForAudioClassification",
        "RBLNModelForImageClassification",
    ],
    "modeling_seq2seq": [
        "RBLNModelForSeq2SeqLM",
    ],
    # Names re-exported from the ``transformers`` subpackage.
    "transformers": [
        "BatchTextIteratorStreamer",
        "RBLNCLIPTextModel",
        "RBLNCLIPTextModelWithProjection",
        "RBLNGPT2LMHeadModel",
        "RBLNWav2Vec2ForCTC",
        "RBLNLlamaForCausalLM",
        "RBLNWhisperForConditionalGeneration",
    ],
    # Names re-exported from the ``diffusers`` subpackage.
    "diffusers": [
        "RBLNStableDiffusionPipeline",
        "RBLNStableDiffusionXLPipeline",
        "RBLNAutoencoderKL",
        "RBLNUNet2DConditionModel",
        "RBLNControlNetModel",
        "RBLNStableDiffusionImg2ImgPipeline",
        "RBLNStableDiffusionControlNetImg2ImgPipeline",
        "RBLNMultiControlNetModel",
        "RBLNStableDiffusionXLImg2ImgPipeline",
    ],
    "modeling_config": ["RBLNRuntimeConfig", "RBLNConfig"],
}

if TYPE_CHECKING:
    # Eager imports for static analysers only; this branch never runs at runtime.
    from .diffusers import (
        RBLNAutoencoderKL,
        RBLNControlNetModel,
        RBLNMultiControlNetModel,
        RBLNStableDiffusionControlNetImg2ImgPipeline,
        RBLNStableDiffusionImg2ImgPipeline,
        RBLNStableDiffusionPipeline,
        RBLNStableDiffusionXLImg2ImgPipeline,
        RBLNStableDiffusionXLPipeline,
        RBLNUNet2DConditionModel,
    )
    from .modeling_alias import (
        RBLNASTForAudioClassification,
        RBLNBartForConditionalGeneration,
        RBLNBertForQuestionAnswering,
        RBLNResNetForImageClassification,
        RBLNT5ForConditionalGeneration,
    )
    from .modeling_base import (
        RBLNBaseModel,
        RBLNModel,
        RBLNModelForAudioClassification,
        RBLNModelForImageClassification,
        RBLNModelForQuestionAnswering,
    )
    from .modeling_config import RBLNConfig, RBLNRuntimeConfig
    from .modeling_seq2seq import RBLNModelForSeq2SeqLM
    from .transformers import (
        BatchTextIteratorStreamer,
        RBLNCLIPTextModel,
        RBLNCLIPTextModelWithProjection,
        RBLNGPT2LMHeadModel,
        RBLNLlamaForCausalLM,
        RBLNWav2Vec2ForCTC,
        RBLNWhisperForConditionalGeneration,
    )
else:
    import sys

    # Replace this module object with a lazy proxy built from the table above.
    sys.modules[__name__] = _LazyModule(
        __name__,
        __file__,
        _import_structure,
        module_spec=__spec__,
    )
@@ -0,0 +1 @@
1
+ __version__ = '0.1.0'
@@ -0,0 +1,64 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from typing import TYPE_CHECKING
25
+
26
+ from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_CLASSES
27
+ from transformers.utils import _LazyModule
28
+
29
+
30
# Add an "optimum.rbln" entry to diffusers' loadable-class registry and mirror
# it into the flat lookup table.
# NOTE(review): presumably this lets diffusers pipelines save/load RBLN
# components through ``save_pretrained`` / ``from_pretrained`` -- confirm
# against diffusers' pipeline_utils.
LOADABLE_CLASSES["optimum.rbln"] = {"RBLNBaseModel": ["save_pretrained", "from_pretrained"]}
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])


# Lazy-export table handed to ``_LazyModule`` below: submodule -> public names.
_import_structure = {
    "pipelines": [
        "RBLNStableDiffusionPipeline",
        "RBLNStableDiffusionXLPipeline",
        "RBLNStableDiffusionImg2ImgPipeline",
        "RBLNStableDiffusionControlNetImg2ImgPipeline",
        "RBLNMultiControlNetModel",
        "RBLNStableDiffusionXLImg2ImgPipeline",
    ],
    "models": [
        "RBLNAutoencoderKL",
        "RBLNUNet2DConditionModel",
        "RBLNControlNetModel",
    ],
}

if TYPE_CHECKING:
    # Eager imports for static analysers only; this branch never runs at runtime.
    from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
    from .pipelines import (
        RBLNMultiControlNetModel,
        RBLNStableDiffusionControlNetImg2ImgPipeline,
        RBLNStableDiffusionImg2ImgPipeline,
        RBLNStableDiffusionPipeline,
        RBLNStableDiffusionXLImg2ImgPipeline,
        RBLNStableDiffusionXLPipeline,
    )
else:
    import sys

    # Replace this module object with a lazy proxy built from the table above.
    sys.modules[__name__] = _LazyModule(
        __name__,
        __file__,
        _import_structure,
        module_spec=__spec__,
    )
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .autoencoder_kl import RBLNAutoencoderKL
25
+ from .controlnet import RBLNControlNetModel
26
+ from .unet_2d_condition import RBLNUNet2DConditionModel
@@ -0,0 +1,313 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import logging
25
+ from pathlib import Path
26
+ from tempfile import TemporaryDirectory
27
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
28
+
29
+ import rebel
30
+ import torch
31
+ from diffusers import AutoencoderKL
32
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
33
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
34
+ from optimum.exporters import TasksManager
35
+ from transformers import AutoConfig, AutoModel, PretrainedConfig
36
+
37
+ from ...modeling_base import RBLNModel
38
+ from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
39
+ from ...utils.runtime_utils import RBLNPytorchRuntime
40
+ from ...utils.save_utils import maybe_save_preprocessors
41
+
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ if TYPE_CHECKING:
46
+ import torch
47
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
48
+
49
+
50
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
    """Runtime wrapper that exposes the compiled VAE encoder via ``encode``."""

    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Run the encoder runtime on ``x`` and wrap the result as a latent distribution.

        Args:
            x: Input tensor; it is made contiguous before being passed to the runtime.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            An ``AutoencoderKLOutput`` whose ``latent_dist`` is a
            ``DiagonalGaussianDistribution`` built from the runtime output
            (the annotated return type is kept as in the original signature).
        """
        latent_moments = self.forward(x.contiguous())
        return AutoencoderKLOutput(latent_dist=DiagonalGaussianDistribution(latent_moments))
55
+
56
+
57
class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
    """Runtime wrapper that exposes the compiled VAE decoder via ``decode``."""

    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Run the decoder runtime on ``z``.

        Args:
            z: Latent tensor forwarded to the runtime unchanged.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            A one-element tuple holding the runtime output (the annotated return
            type is kept as in the original signature).
        """
        decoded = self.forward(z)
        return (decoded,)
60
+
61
+
62
class RBLNAutoencoderKL(RBLNModel):
    """RBLN-compiled wrapper around diffusers' ``AutoencoderKL``.

    Depending on the ``rbln_use_encode`` flag stored in the RBLN config meta,
    the model owns either a decoder runtime only, or an encoder runtime
    followed by a decoder runtime.
    """

    model_type = "rbln_model"
    config_name = "config.json"
    auto_model_class = AutoModel  # feature extraction

    def __post_init__(self, **kwargs):
        """Wrap the created runtimes in encoder/decoder helper objects.

        With ``rbln_use_encode`` set, ``self.runtimes`` is read as
        ``[encoder, decoder]``; otherwise its single entry is the decoder.
        """
        self.dtype = torch.float32

        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]

        if self.rbln_use_encode:
            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.runtimes[0], main_input_name="x")
            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.runtimes[1], main_input_name="z")
        else:
            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.runtimes[0], main_input_name="z")

    @classmethod
    @torch.no_grad()
    def _export(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "RBLNAutoencoderKL":
        """Load the torch VAE, compile it, and return an instance backed by the result.

        The compiled artifacts, the model config and the RBLN config are written
        to a temporary directory, which is handed to ``_from_pretrained`` as
        ``model_save_dir``.

        Args:
            model_id: Model identifier or path forwarded to the loader.
            config: Model config; when ``None`` the loaded model's config is used.
                A non-``PretrainedConfig`` (diffusers) config is converted.
            use_auth_token, revision, force_download, cache_dir, subfolder,
            local_files_only, trust_remote_code: Forwarded to the model loader.
            **kwargs: ``task`` is consumed here; RBLN-specific entries are split
                off by ``cls.pop_rbln_kwargs_from_kwargs``; the remainder is
                forwarded to the loader and to ``_from_pretrained``.

        Returns:
            The instance produced by ``cls._from_pretrained``.
        """
        # ``task`` is removed so it is not forwarded twice; the loader below is
        # always called with ``task=None``, so no task inference is needed.
        kwargs.pop("task", None)

        # Deliberately not a ``with`` block: the directory must outlive this call,
        # so ownership is passed on through ``model_save_dir``.
        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)

        model: AutoencoderKL = TasksManager.get_model_from_task(
            task=None,
            model_name_or_path=model_id,
            subfolder=subfolder,
            revision=revision,
            framework="pt",
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        if config is None:
            config = model.config

        if not isinstance(config, PretrainedConfig):  # diffusers config
            config = PretrainedConfig(**config)

        config.save_pretrained(save_dir_path)
        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        # Get compilation arguments. A caller-supplied ``rbln_config`` is used
        # as-is; previously that path left ``rbln_config`` unbound.
        rbln_config = rbln_config_kwargs.get("rbln_config", None)
        if rbln_config is None:
            rbln_config = cls.get_rbln_config(
                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
            )

        def compile_img2img():
            # Encoder + decoder, saved under their own compiled-model names.
            encoder_model = _VAEEncoder(model)
            decoder_model = _VAEDecoder(model)
            encoder_model.eval()
            decoder_model.eval()

            enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])

            enc_compiled_model.save(save_dir_path / f"{rbln_config['encoder'][0].compiled_model_name}.rbln")
            dec_compiled_model.save(save_dir_path / f"{rbln_config['decoder'][0].compiled_model_name}.rbln")

        def compile_text2img():
            # Decoder only.
            decoder_model = _VAEDecoder(model)
            decoder_model.eval()

            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])

            dec_compiled_model.save(save_dir_path / f"{rbln_config['compiled_model'][0].compiled_model_name}.rbln")

        if rbln_config_kwargs.get("rbln_use_encode"):
            compile_img2img()
        else:
            compile_text2img()

        rbln_config.save(save_dir_path)

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            **rbln_constructor_kwargs,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Call the parent ``from_pretrained`` with diffusers loaders temporarily swapped in.

        ``TasksManager.get_model_from_task``, ``AutoConfig.from_pretrained`` and
        ``AutoModel.from_pretrained`` are replaced for the duration of the call
        and always restored afterwards, including when loading raises.
        """

        def get_model_from_task(
            task: str,
            model_name_or_path: Union[str, Path],
            **kwargs,
        ):
            return AutoencoderKL.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)

        tasktmp = TasksManager.get_model_from_task
        configtmp = AutoConfig.from_pretrained
        modeltmp = AutoModel.from_pretrained

        try:
            TasksManager.get_model_from_task = get_model_from_task

            if kwargs.get("export", None):
                # This is an ad-hoc to workaround save null values of the config.
                # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
                # and diffusers model do not support loading by AutoConfig.
                AutoConfig.from_pretrained = lambda *args, **kwargs: None
            else:
                AutoConfig.from_pretrained = AutoencoderKL.load_config

            AutoModel.from_pretrained = AutoencoderKL.from_pretrained
            return super().from_pretrained(*args, **kwargs)
        finally:
            # Undo the process-wide patches even on failure so later, unrelated
            # loads are not affected.
            AutoConfig.from_pretrained = configtmp
            AutoModel.from_pretrained = modeltmp
            TasksManager.get_model_from_task = tasktmp

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "PretrainedConfig",
        rbln_unet_sample_size: Optional[int] = None,
        rbln_img_width: Optional[int] = None,
        rbln_img_height: Optional[int] = None,
        rbln_batch_size: Optional[int] = None,
        rbln_use_encode: Optional[bool] = None,
        rbln_vae_scale_factor: Optional[int] = None,
    ) -> RBLNConfig:
        """Build the RBLN compile configuration for the VAE.

        Args:
            preprocessors: Not used by this implementation.
            model_config: Supplies ``in_channels`` and ``latent_channels``.
            rbln_unet_sample_size: Spatial size of the decoder input in
                decoder-only mode; defaults to 64.
            rbln_img_width: Image width; required when ``rbln_use_encode`` is truthy.
            rbln_img_height: Image height; required when ``rbln_use_encode`` is truthy.
            rbln_batch_size: Batch size; defaults to 1.
            rbln_use_encode: When truthy, configure both an encoder and a decoder.
            rbln_vae_scale_factor: Divisor applied to the image size to get the
                latent size; required when ``rbln_use_encode`` is truthy.

        Returns:
            An ``RBLNConfig`` holding either the "encoder" and "decoder" runtime
            configs, or a single unnamed runtime config for the decoder.

        Raises:
            TypeError: If ``rbln_use_encode`` is truthy and any of the image
                width, image height or VAE scale factor is ``None``.
        """
        meta = {}
        if rbln_batch_size is None:
            rbln_batch_size = 1

        meta["rbln_use_encode"] = rbln_use_encode
        meta["rbln_batch_size"] = rbln_batch_size

        if rbln_use_encode:
            # Fail with a readable message instead of "NoneType // NoneType".
            missing = [
                name
                for name, value in (
                    ("rbln_img_width", rbln_img_width),
                    ("rbln_img_height", rbln_img_height),
                    ("rbln_vae_scale_factor", rbln_vae_scale_factor),
                )
                if value is None
            ]
            if missing:
                raise TypeError(f"rbln_use_encode=True requires {', '.join(missing)} to be set.")

            meta["rbln_img_width"] = rbln_img_width
            meta["rbln_img_height"] = rbln_img_height

            # NOTE(review): the spatial axes are ordered (width, height) here, while
            # torch image tensors are conventionally (height, width). This only
            # matters for non-square images -- confirm against the pipelines.
            vae_enc_input_info = [
                ("x", [rbln_batch_size, model_config.in_channels, rbln_img_width, rbln_img_height], "float32")
            ]
            vae_dec_input_info = [
                (
                    "z",
                    [
                        rbln_batch_size,
                        model_config.latent_channels,
                        rbln_img_width // rbln_vae_scale_factor,
                        rbln_img_height // rbln_vae_scale_factor,
                    ],
                    "float32",
                )
            ]

            enc_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
            dec_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)

            rbln_config = RBLNConfig.from_rbln_runtime_configs(
                [enc_rbln_runtime_config, dec_rbln_runtime_config],
                _rbln_meta=meta,
            )
            return rbln_config

        if rbln_unet_sample_size is None:
            rbln_unet_sample_size = 64

        meta["rbln_unet_sample_size"] = rbln_unet_sample_size
        vae_config = RBLNRuntimeConfig(
            input_info=[
                (
                    "z",
                    [
                        rbln_batch_size,
                        model_config.latent_channels,
                        rbln_unet_sample_size,
                        rbln_unet_sample_size,
                    ],
                    "float32",
                )
            ],
        )
        rbln_config = RBLNConfig.from_rbln_runtime_configs([vae_config], _rbln_meta=meta)
        return rbln_config

    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
        """Create one runtime per compiled model on the device given by ``rbln_device_map``.

        A single compiled model is looked up under ``DEFAULT_COMPILED_MODEL_NAME``;
        otherwise the compiled models are paired, in order, with the "encoder"
        and "decoder" entries of the map.
        """
        if len(self.compiled_models) == 1:
            device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
            return [self.compiled_models[0].create_runtime(tensor_type="pt", device=device_val)]

        device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
        return [
            compiled_model.create_runtime(tensor_type="pt", device=device_val)
            for compiled_model, device_val in zip(self.compiled_models, device_vals)
        ]

    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Encode ``x`` with the encoder runtime and wrap the result in ``AutoencoderKLOutput``.

        ``self.encoder`` only exists when the model was built with
        ``rbln_use_encode`` set. ``**kwargs`` is ignored.
        """
        # NOTE(review): ``self.encoder.encode`` already returns an
        # ``AutoencoderKLOutput``, so ``latent_dist`` here holds that output rather
        # than the distribution itself. Kept unchanged because callers outside this
        # file may rely on it -- confirm against the img2img pipelines.
        posterior = self.encoder.encode(x)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Decode ``z`` with the decoder runtime; ``**kwargs`` is ignored."""
        return self.decoder.decode(z)
281
+
282
+
283
class _VAEDecoder(torch.nn.Module):
    """Module whose ``forward`` is the wrapped VAE's ``decode`` with ``return_dict=False``."""

    def __init__(self, vae: "AutoencoderKL"):
        """Store the VAE to delegate to."""
        super().__init__()
        self.vae = vae

    def forward(self, z):
        """Return ``self.vae.decode(z, return_dict=False)`` unchanged."""
        return self.vae.decode(z, return_dict=False)
291
+
292
+
293
class _VAEEncoder(torch.nn.Module):
    """Module whose ``forward`` returns the wrapped VAE's encoder moments."""

    def __init__(self, vae: "AutoencoderKL"):
        """Store the VAE to delegate to."""
        super().__init__()
        self.vae = vae

    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
        """Compute the moments tensor for ``x``.

        ``forward`` calls this function unbound with the wrapped VAE as ``self``,
        so every attribute read here is looked up on the VAE, not on this wrapper.
        """
        if self.use_tiling:
            tile_limit = self.tile_sample_min_size
            if x.shape[-1] > tile_limit or x.shape[-2] > tile_limit:
                return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            # Encode one sample at a time, then stitch the batch back together.
            hidden = torch.cat([self.encoder(sample) for sample in x.split(1)])
        else:
            hidden = self.encoder(x)

        return self.quant_conv(hidden)

    def forward(self, x):
        """Run ``encode`` against the wrapped VAE with ``return_dict=False``."""
        return _VAEEncoder.encode(self.vae, x, return_dict=False)