optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import rebel
|
|
18
|
+
import torch
|
|
19
|
+
from diffusers import AutoencoderKL
|
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
|
21
|
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
|
22
|
+
from transformers import PretrainedConfig
|
|
23
|
+
|
|
24
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
25
|
+
from ....modeling import RBLNModel
|
|
26
|
+
from ....utils.logging import get_logger
|
|
27
|
+
from ...configurations import RBLNAutoencoderKLConfig
|
|
28
|
+
from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
import torch
|
|
33
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
34
|
+
|
|
35
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
36
|
+
|
|
37
|
+
# Module-level logger obtained from the package's logging utilities (`....utils.logging.get_logger`).
logger = get_logger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RBLNAutoencoderKL(RBLNModel):
    """
    RBLN implementation of AutoencoderKL (VAE) for diffusion models.

    This model is used to accelerate AutoencoderKL (VAE) models from the diffusers library on RBLN NPUs.
    It can be configured to include both encoder and decoder, or just the decoder part for latent-to-image
    conversion.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    auto_model_class = AutoencoderKL
    hf_library_name = "diffusers"
    _rbln_config_class = RBLNAutoencoderKLConfig

    def __post_init__(self, **kwargs):
        """Attach runtime wrappers for the compiled encoder (optional) and decoder."""
        super().__post_init__(**kwargs)

        # `self.model` is assumed to hold the runtimes in compile order (see `_update_rbln_config`):
        # the encoder, when compiled, comes first and the decoder is always last.
        if self.rbln_config.uses_encoder:
            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
        else:
            self.encoder = None

        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
        self.image_size = self.rbln_config.image_size

    @classmethod
    def get_compiled_model(cls, model, rbln_config: RBLNAutoencoderKLConfig) -> Dict[str, rebel.RBLNCompiledModel]:
        """
        Compile the VAE into one RBLN model per sub-module.

        Args:
            model: The diffusers `AutoencoderKL` to compile.
            rbln_config: RBLN configuration. Its `compile_cfgs` are ordered encoder-then-decoder
                (or decoder only when the encoder is not used).

        Returns:
            A dict mapping "encoder" / "decoder" to the corresponding compiled model.
        """
        if rbln_config.uses_encoder:
            expected_models = ["encoder", "decoder"]
        else:
            expected_models = ["decoder"]

        compiled_models = {}
        # `compile_cfgs[i]` lines up with `expected_models[i]` because `_update_rbln_config`
        # appends the encoder config (if any) before the decoder config.
        for i, model_name in enumerate(expected_models):
            if model_name == "encoder":
                wrapped_model = _VAEEncoder(model)
            else:
                wrapped_model = _VAEDecoder(model)

            wrapped_model.eval()

            compiled_models[model_name] = cls.compile(
                wrapped_model,
                rbln_compile_config=rbln_config.compile_cfgs[i],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map[model_name],
            )

        return compiled_models

    @classmethod
    def get_vae_sample_size(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: RBLNAutoencoderKLConfig, return_vae_scale_factor: bool = False
    ) -> Union[Tuple[int, int], Tuple[Tuple[int, int], int]]:
        """
        Determine the image-space sample size the VAE should be compiled for.

        If `rbln_config.sample_size` is set it is returned unchanged (an int stays an int).
        Otherwise the noise module's latent `sample_size` is multiplied by the VAE scale factor.

        Args:
            pipe: The pipeline being compiled. It must expose a `unet` or `transformer` attribute.
            rbln_config: The VAE's RBLN config.
            return_vae_scale_factor: When True, also return the scale factor that was used.

        Returns:
            `sample_size`, or `(sample_size, vae_scale_factor)` when `return_vae_scale_factor` is True.

        Raises:
            AttributeError: If the pipeline has neither a `unet` nor a `transformer`.
        """
        sample_size = rbln_config.sample_size

        # Prefer the UNet and fall back to a transformer backbone. Explicit `None` checks avoid
        # skipping a module that merely happens to be falsy (e.g. one that defines `__len__`).
        noise_module = getattr(pipe, "unet", None)
        if noise_module is None:
            noise_module = getattr(pipe, "transformer", None)

        if hasattr(pipe, "vae_scale_factor"):
            vae_scale_factor = pipe.vae_scale_factor
        else:
            vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)

        if noise_module is None:
            raise AttributeError(
                "Cannot find noise processing or predicting module attributes. ex. U-Net, Transformer, ..."
            )

        if sample_size is None:
            # The noise module's sample size is in latent space; scale it up to pixel space.
            sample_size = noise_module.config.sample_size
            if isinstance(sample_size, int):
                sample_size = (sample_size, sample_size)
            sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)

        if return_vae_scale_factor:
            return sample_size, vae_scale_factor
        return sample_size

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """
        Fill the pipeline config's `vae` entry with the sample size and VAE scale factor derived from `pipe`.

        Args:
            pipe: The pipeline being compiled.
            rbln_config: The pipeline-level RBLN config; its `vae` sub-config is updated in place.
            submodule_name: Name of this submodule in the pipeline. Currently unused; the `vae`
                attribute is always the one updated.

        Returns:
            The same `rbln_config` object, updated.
        """
        # NOTE(review): `submodule_name` is ignored and `rbln_config.vae` is hard-coded; confirm that no
        # pipeline registers an AutoencoderKL under a different submodule name.
        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
            pipe, rbln_config.vae, return_vae_scale_factor=True
        )
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNAutoencoderKLConfig,
    ) -> RBLNAutoencoderKLConfig:
        """
        Fill unset fields of `rbln_config` from `model_config` and build the compile configs.

        Missing `sample_size`, `in_channels` and `latent_channels` are copied from the model config.
        An int `sample_size` is expanded to a square `(h, w)` tuple. A missing `vae_scale_factor` is
        derived from `block_out_channels`, falling back to 8.

        Returns:
            The updated `rbln_config`. Its compile configs are ordered encoder (if used) then decoder.
        """
        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size

        if isinstance(rbln_config.sample_size, int):
            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)

        if rbln_config.in_channels is None:
            rbln_config.in_channels = model_config.in_channels

        if rbln_config.latent_channels is None:
            rbln_config.latent_channels = model_config.latent_channels

        if rbln_config.vae_scale_factor is None:
            if hasattr(model_config, "block_out_channels"):
                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
            else:
                # vae image processor default value 8 (int)
                rbln_config.vae_scale_factor = 8

        compile_cfgs = []
        if rbln_config.uses_encoder:
            # Encoder consumes an image: (batch, in_channels, height, width).
            vae_enc_input_info = [
                (
                    "x",
                    [
                        rbln_config.batch_size,
                        rbln_config.in_channels,
                        rbln_config.sample_size[0],
                        rbln_config.sample_size[1],
                    ],
                    "float32",
                )
            ]
            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))

        # Decoder consumes a latent: (batch, latent_channels, latent_height, latent_width).
        vae_dec_input_info = [
            (
                "z",
                [
                    rbln_config.batch_size,
                    rbln_config.latent_channels,
                    rbln_config.latent_sample_size[0],
                    rbln_config.latent_sample_size[1],
                ],
                "float32",
            )
        ]
        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNAutoencoderKLConfig,
    ) -> List[rebel.Runtime]:
        """
        Create one `rebel.Runtime` per compiled model, on the device from `rbln_config.device_map`.

        A single compiled model is treated as decoder-only; two are treated as (encoder, decoder).
        """
        if len(compiled_models) == 1:
            # decoder
            expected_models = ["decoder"]
        else:
            # encoder, decoder
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
    ) -> Union[Tuple[Any, ...], AutoencoderKLOutput]:
        """
        Encode an input image into a latent representation.

        Args:
            x: The input image to encode.
            return_dict:
                Whether to return an `AutoencoderKLOutput`. Defaults to True.
            kwargs: Accepted for API compatibility and ignored.

        Returns:
            `AutoencoderKLOutput(latent_dist=posterior)` if `return_dict` is True, else `(posterior,)`.

        Raises:
            AttributeError: If the model was compiled without an encoder.
        """
        if getattr(self, "encoder", None) is None:
            # Decoder-only compilation sets `self.encoder = None`; fail with a clear message instead of
            # an opaque "'NoneType' object has no attribute 'encode'".
            raise AttributeError(
                "This RBLNAutoencoderKL was compiled without an encoder (rbln_config.uses_encoder is False), "
                "so `encode` is not available."
            )
        posterior = self.encoder.encode(x)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
    ) -> Union[Tuple[torch.FloatTensor], DecoderOutput]:
        """
        Decode a latent representation into an image.

        Args:
            z: The latent representation to decode.
            return_dict:
                Whether to return a `DecoderOutput`. Defaults to True.
            kwargs: Accepted for API compatibility and ignored.

        Returns:
            `DecoderOutput(sample=image)` if `return_dict` is True, else `(image,)`.
        """
        dec = self.decoder.decode(z)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
|
16
|
+
|
|
17
|
+
import rebel
|
|
18
|
+
import torch
|
|
19
|
+
from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos, CosmosCausalConv3d
|
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
|
21
|
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
|
22
|
+
from torch.nn import functional as F
|
|
23
|
+
from transformers import PretrainedConfig
|
|
24
|
+
|
|
25
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
26
|
+
from ....modeling import RBLNModel
|
|
27
|
+
from ....utils.logging import get_logger
|
|
28
|
+
from ...configurations import RBLNAutoencoderKLCosmosConfig
|
|
29
|
+
from .vae import RBLNRuntimeCosmosVAEDecoder, RBLNRuntimeCosmosVAEEncoder, _VAECosmosDecoder, _VAECosmosEncoder
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
import torch
|
|
34
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
35
|
+
|
|
36
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
37
|
+
|
|
38
|
+
# Module-level logger obtained from the package's logging utilities (`....utils.logging.get_logger`).
logger = get_logger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RBLNAutoencoderKLCosmos(RBLNModel):
    """
    RBLN implementation of AutoencoderKLCosmos for diffusion models.

    This model is used to accelerate AutoencoderKLCosmos models from diffusers library on RBLN NPUs.
    It can be configured to include both encoder and decoder, or just the decoder part for latent-to-video
    conversion.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    auto_model_class = AutoencoderKLCosmos
    hf_library_name = "diffusers"
    _rbln_config_class = RBLNAutoencoderKLCosmosConfig

    def __post_init__(self, **kwargs):
        """Wrap the loaded runtimes in encoder/decoder helpers.

        The encoder helper reads runtime index 0 and is only created when
        ``rbln_config.uses_encoder`` is set. The decoder helper always reads the
        *last* runtime, so it works both with ``[encoder, decoder]`` and with a
        decoder-only ``[decoder]`` layout.
        """
        super().__post_init__(**kwargs)

        if self.rbln_config.uses_encoder:
            self.encoder = RBLNRuntimeCosmosVAEEncoder(
                runtime=self.model[0], main_input_name="x", use_slicing=self.rbln_config.use_slicing
            )

        self.decoder = RBLNRuntimeCosmosVAEDecoder(
            runtime=self.model[-1], main_input_name="z", use_slicing=self.rbln_config.use_slicing
        )
        self.image_size = self.rbln_config.image_size

    @classmethod
    def _wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
    ) -> Union[torch.nn.Module, tuple]:
        """Wrap the diffusers VAE into separate encoder/decoder modules for compilation.

        Args:
            model: The diffusers ``AutoencoderKLCosmos`` instance.
            rbln_config: Configuration deciding whether the encoder is needed.

        Returns:
            ``(encoder_model, decoder_model)`` when ``rbln_config.uses_encoder`` is set,
            otherwise only ``decoder_model``. Every returned module is switched to eval mode.
        """
        decoder_model = _VAECosmosDecoder(model)
        decoder_model.eval()

        if rbln_config.uses_encoder:
            encoder_model = _VAECosmosEncoder(model)
            encoder_model.eval()
            return encoder_model, decoder_model
        return decoder_model

    @classmethod
    def get_compiled_model(
        cls, model, rbln_config: RBLNAutoencoderKLCosmosConfig
    ) -> Dict[str, rebel.RBLNCompiledModel]:
        """Compile the VAE into RBLN models.

        While compiling, ``CosmosCausalConv3d.forward`` is temporarily replaced at class
        level and restored afterwards, even if compilation raises.

        Args:
            model: The diffusers ``AutoencoderKLCosmos`` instance.
            rbln_config: Provides the compile configs (encoder at index 0 when present,
                decoder always last) and the per-submodule device map.

        Returns:
            A dict with key ``"decoder"`` and, when ``rbln_config.uses_encoder`` is set,
            also ``"encoder"``.
        """

        def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Causal temporal padding (repeat the first frame `temporal_pad` times) is only
            # applied when it is non-zero. NOTE(review): presumably this avoids emitting a
            # zero-length repeat/concat into the traced graph — confirm against diffusers.
            if self.temporal_pad != 0:
                hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
                hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
            hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
            # Skip CosmosCausalConv3d.forward itself (it is the patched method) and call
            # the next forward in its MRO.
            return super(CosmosCausalConv3d, self).forward(hidden_states)

        # Patch before entering `try` so `finally` only restores a patch that was applied.
        original_forward = CosmosCausalConv3d.forward
        CosmosCausalConv3d.forward = replaced_forward
        try:
            compiled_models = {}
            if rbln_config.uses_encoder:
                encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
                compiled_models["encoder"] = cls.compile(
                    encoder_model,
                    rbln_compile_config=rbln_config.compile_cfgs[0],
                    create_runtimes=rbln_config.create_runtimes,
                    device=rbln_config.device_map["encoder"],
                )
            else:
                decoder_model = cls._wrap_model_if_needed(model, rbln_config)

            # The decoder is compiled in both modes; its compile config is always the last one.
            compiled_models["decoder"] = cls.compile(
                decoder_model,
                rbln_compile_config=rbln_config.compile_cfgs[-1],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map["decoder"],
            )
        finally:
            CosmosCausalConv3d.forward = original_forward

        return compiled_models

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Copy latent-channel and VAE scale-factor settings from the pipeline into ``rbln_config.vae``.

        Args:
            pipe: The parent pipeline; its transformer's ``out_channels`` and its
                temporal/spatial VAE scale factors are read.
            rbln_config: Pipeline-level config whose ``vae`` sub-config is updated in place.
            submodule_name: Accepted for interface compatibility; not used here.

        Returns:
            The same ``rbln_config`` object.
        """
        rbln_config.vae.num_channels_latents = pipe.transformer.config.out_channels
        rbln_config.vae.vae_scale_factor_temporal = pipe.vae_scale_factor_temporal
        rbln_config.vae.vae_scale_factor_spatial = pipe.vae_scale_factor_spatial
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNAutoencoderKLCosmosConfig,
    ) -> RBLNAutoencoderKLCosmosConfig:
        """Build the static-shape compile configs for the encoder (optional) and decoder.

        Compile configs are appended in the order encoder (if enabled), then decoder.
        ``get_compiled_model`` relies on this order (index 0 / index -1).

        Returns:
            ``rbln_config`` with its compile configs set.
        """
        # With slicing, the runtimes process one sample at a time, so compile for batch 1.
        batch_size = 1 if rbln_config.use_slicing else rbln_config.batch_size
        compile_cfgs = []

        if rbln_config.uses_encoder:
            # Encoder input "x": (batch, in_channels, num_frames, height, width), float32.
            vae_enc_input_info = [
                (
                    "x",
                    [
                        batch_size,
                        model_config.in_channels,
                        rbln_config.num_frames,
                        rbln_config.height,
                        rbln_config.width,
                    ],
                    "float32",
                ),
            ]
            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))

        # Latent frame count keeps the first frame and compresses the remaining ones by the
        # temporal scale factor. NOTE(review): this matches a causal temporal VAE layout — confirm.
        num_latent_frames = (rbln_config.num_frames - 1) // rbln_config.vae_scale_factor_temporal + 1
        latent_height = rbln_config.height // rbln_config.vae_scale_factor_spatial
        latent_width = rbln_config.width // rbln_config.vae_scale_factor_spatial

        # Decoder input "z": (batch, latent_channels, latent_frames, latent_h, latent_w), float32.
        vae_dec_input_info = [
            (
                "z",
                [
                    batch_size,
                    rbln_config.num_channels_latents,
                    num_latent_frames,
                    latent_height,
                    latent_width,
                ],
                "float32",
            ),
        ]
        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNAutoencoderKLCosmosConfig,
    ) -> List[rebel.Runtime]:
        """Create one ``rebel.Runtime`` per compiled model, on the device from ``device_map``.

        A single compiled model is treated as decoder-only. Otherwise the models are taken
        to be ``[encoder, decoder]``, in that order.
        """
        if len(compiled_models) == 1:
            # Decoder-only configuration.
            expected_models = ["decoder"]
        else:
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
    ) -> Union[tuple, AutoencoderKLOutput]:
        """
        Encode an input video into a latent representation.

        Only available when the model was compiled with ``uses_encoder``; otherwise
        ``self.encoder`` is never created in ``__post_init__``.

        Args:
            x: The input video to encode.
            return_dict:
                Whether to return an ``AutoencoderKLOutput`` instead of a plain tuple. Defaults to True.
            kwargs: Accepted for API compatibility with diffusers; currently ignored.

        Returns:
            ``AutoencoderKLOutput(latent_dist=...)`` if ``return_dict`` is True, otherwise a
            1-tuple holding the encoder output.
        """
        posterior = self.encoder.encode(x)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[tuple, DecoderOutput]:
        """
        Decode a latent representation into a video.

        Args:
            z: The latent representation to decode.
            return_dict:
                Whether to return a ``DecoderOutput`` instead of a plain tuple. Defaults to True.

        Returns:
            ``DecoderOutput(sample=...)`` if ``return_dict`` is True, otherwise a 1-tuple
            holding the decoded video.
        """
        decoded = self.decoder.decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
|