optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import rebel
|
|
18
|
+
import torch # noqa: I001
|
|
19
|
+
from diffusers import AutoencoderKLTemporalDecoder
|
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
|
21
|
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
|
22
|
+
from transformers import PretrainedConfig
|
|
23
|
+
|
|
24
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
25
|
+
from ....modeling import RBLNModel
|
|
26
|
+
from ....utils.logging import get_logger
|
|
27
|
+
from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
|
|
28
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
|
29
|
+
from .vae import (
|
|
30
|
+
DiagonalGaussianDistribution,
|
|
31
|
+
RBLNRuntimeVAEDecoder,
|
|
32
|
+
RBLNRuntimeVAEEncoder,
|
|
33
|
+
_VAEEncoder,
|
|
34
|
+
_VAETemporalDecoder,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
40
|
+
|
|
41
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
42
|
+
|
|
43
|
+
# Module-level logger for this file (used e.g. to warn when decode_chunk_size is adjusted).
logger = get_logger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
    """RBLN-compiled counterpart of diffusers' ``AutoencoderKLTemporalDecoder``.

    Up to two graphs are compiled: an ``encoder`` graph (only when
    ``rbln_config.uses_encoder`` is true) and a ``decoder`` graph whose batch
    dimension is ``batch_size * decode_chunk_size``.
    """

    auto_model_class = AutoencoderKLTemporalDecoder
    hf_library_name = "diffusers"
    _rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig

    def __post_init__(self, **kwargs):
        super().__post_init__(**kwargs)

        # self.model is expected to hold the runtimes in compile_cfgs order:
        # [encoder, decoder] when the encoder is compiled, otherwise [decoder].
        # The decoder is therefore always the last runtime.
        if self.rbln_config.uses_encoder:
            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
        self.image_size = self.rbln_config.image_size

    @classmethod
    def _wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
    ) -> Union[torch.nn.Module, Tuple[torch.nn.Module, torch.nn.Module]]:
        """Wrap the VAE into traceable encoder/decoder modules.

        Returns:
            ``(encoder_model, decoder_model)`` when ``rbln_config.uses_encoder`` is
            true, otherwise only ``decoder_model``.
        """
        decoder_model = _VAETemporalDecoder(model)
        # The decoder graph is compiled for a fixed number of frames per call.
        decoder_model.num_frames = rbln_config.decode_chunk_size
        decoder_model.eval()

        if rbln_config.uses_encoder:
            encoder_model = _VAEEncoder(model)
            encoder_model.eval()
            return encoder_model, decoder_model
        else:
            return decoder_model

    @classmethod
    def get_compiled_model(
        cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
    ) -> Dict[str, rebel.RBLNCompiledModel]:
        """Compile the encoder (optional) and decoder graphs.

        Returns:
            A dict keyed by ``"encoder"`` (only when ``uses_encoder``) and ``"decoder"``.
        """
        compiled_models = {}
        if rbln_config.uses_encoder:
            encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
            enc_compiled_model = cls.compile(
                encoder_model,
                rbln_compile_config=rbln_config.compile_cfgs[0],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map["encoder"],
            )
            compiled_models["encoder"] = enc_compiled_model
        else:
            decoder_model = cls._wrap_model_if_needed(model, rbln_config)

        # The decoder compile config is always the last one (see _update_rbln_config).
        dec_compiled_model = cls.compile(
            decoder_model,
            rbln_compile_config=rbln_config.compile_cfgs[-1],
            create_runtimes=rbln_config.create_runtimes,
            device=rbln_config.device_map["decoder"],
        )
        compiled_models["decoder"] = dec_compiled_model

        return compiled_models

    @classmethod
    def get_vae_sample_size(
        cls,
        pipe: "RBLNDiffusionMixin",
        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
        return_vae_scale_factor: bool = False,
    ) -> Union[Tuple[int, int], Tuple[Tuple[int, int], int]]:
        """Resolve the VAE's pixel-space sample size (and optionally its scale factor).

        The scale factor comes from ``pipe.vae_scale_factor`` if present. Otherwise it
        is ``2 ** (len(block_out_channels) - 1)``, falling back to 8.

        When ``rbln_config.sample_size`` is unset, the UNet's ``sample_size`` is
        multiplied by the scale factor.

        Returns:
            ``(height, width)``, or ``((height, width), vae_scale_factor)`` when
            ``return_vae_scale_factor`` is true.
        """
        sample_size = rbln_config.sample_size
        if hasattr(pipe, "vae_scale_factor"):
            vae_scale_factor = pipe.vae_scale_factor
        else:
            if hasattr(pipe.vae.config, "block_out_channels"):
                vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
            else:
                vae_scale_factor = 8  # vae image processor default value 8 (int)

        if sample_size is None:
            # Fall back to the UNet's latent sample size, scaled up to pixel space.
            sample_size = pipe.unet.config.sample_size
            if isinstance(sample_size, int):
                sample_size = (sample_size, sample_size)
            sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
        elif isinstance(sample_size, int):
            # A user-supplied square size is normalised to (height, width).
            sample_size = (sample_size, sample_size)

        if return_vae_scale_factor:
            return sample_size, vae_scale_factor
        else:
            return sample_size

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Fill VAE sample size, scale factor, num_frames and decode_chunk_size from the pipeline.

        Raises:
            ValueError: If ``num_frames`` is unset and the UNet config does not define
                it, or if ``num_frames`` is not a positive integer.
        """
        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
            pipe, rbln_config.vae, return_vae_scale_factor=True
        )

        if rbln_config.vae.num_frames is None:
            if hasattr(pipe.unet.config, "num_frames"):
                rbln_config.vae.num_frames = pipe.unet.config.num_frames
            else:
                raise ValueError("num_frames should be specified in unet config.json")

        if rbln_config.vae.decode_chunk_size is None:
            rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames

        def chunk_frame(num_frames, decode_chunk_size):
            # Pick the proper divisor of num_frames closest to the requested chunk size,
            # so every decoder call receives a full chunk. num_frames itself is excluded
            # (presumably to bound device memory, per the warning below).
            if num_frames < 1:
                raise ValueError(f"num_frames must be a positive integer, but got {num_frames}.")
            divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
            if not divisors:
                # num_frames == 1 has no proper divisor: decode the single frame as one chunk
                # instead of letting min() fail on an empty sequence.
                divisors = [num_frames]
            closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
            if decode_chunk_size != closest:
                logger.warning(
                    f"To ensure successful model compilation and prevent device OOM, {decode_chunk_size} is set to {closest}."
                )
            return closest

        decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
        rbln_config.vae.decode_chunk_size = decode_chunk_size
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
    ) -> RBLNAutoencoderKLTemporalDecoderConfig:
        """Fill defaults from ``model_config`` and register the compile configs.

        The compile configs are registered in the order [encoder (optional), decoder].
        """
        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size

        if rbln_config.vae_scale_factor is None:
            if hasattr(model_config, "block_out_channels"):
                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
            else:
                # vae image processor default value 8 (int)
                rbln_config.vae_scale_factor = 8

        compile_cfgs = []
        if rbln_config.uses_encoder:
            vae_enc_input_info = [
                (
                    "x",
                    [
                        rbln_config.batch_size,
                        model_config.in_channels,
                        rbln_config.sample_size[0],
                        rbln_config.sample_size[1],
                    ],
                    "float32",
                )
            ]
            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))

        # The decoder processes decode_chunk_size frames per batch item in one call.
        decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
        vae_dec_input_info = [
            (
                "z",
                [
                    decode_batch_size,
                    model_config.latent_channels,
                    rbln_config.latent_sample_size[0],
                    rbln_config.latent_sample_size[1],
                ],
                "float32",
            )
        ]
        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
    ) -> List[rebel.Runtime]:
        """Create one ``rebel.Runtime`` per compiled model on the device from ``device_map``."""
        if len(compiled_models) == 1:
            # decoder
            expected_models = ["decoder"]
        else:
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode an input image into a latent representation.

        Only available when the model was compiled with the encoder
        (``rbln_config.uses_encoder``).

        Args:
            x: The input image to encode.
            return_dict:
                Whether to return an ``AutoencoderKLOutput`` instead of a tuple. Defaults to True.

        Returns:
            ``AutoencoderKLOutput(latent_dist=posterior)`` if ``return_dict`` is true,
            otherwise the one-element tuple ``(posterior,)``.
        """
        posterior = self.encoder.encode(x)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        """
        Decode a latent representation into a video.

        Args:
            z: The latent representation to decode.
            return_dict:
                Whether to return a ``DecoderOutput`` instead of a tuple. Defaults to True.

        Returns:
            ``DecoderOutput(sample=decoded)`` if ``return_dict`` is true, otherwise
            the one-element tuple ``(decoded,)``.
        """
        decoded = self.decoder.decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, List, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
|
|
19
|
+
|
|
20
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
    """Runtime wrapper that runs a compiled VAE encoder graph."""

    def encode(self, x: torch.FloatTensor, **kwargs) -> DiagonalGaussianDistribution:
        """Encode ``x`` and wrap the encoder output in a ``DiagonalGaussianDistribution``.

        Args:
            x: Input batch. It is made contiguous before being passed to the runtime.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            The posterior distribution built from the encoder's ``moments`` output.
        """
        moments = self.forward(x.contiguous())
        posterior = DiagonalGaussianDistribution(moments)
        return posterior
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
    """Runtime wrapper around a compiled VAE decoder."""

    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Decode latents ``z`` by running the compiled graph directly.

        Extra keyword arguments are accepted for API compatibility and ignored.
        """
        decoded = self.forward(z)
        return decoded
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RBLNRuntimeCosmosVAEEncoder(RBLNPytorchRuntime):
    """Runtime wrapper around a compiled Cosmos VAE encoder.

    ``self.use_slicing`` is read here but not assigned in this class; presumably
    the owning model sets it (TODO confirm).
    """

    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Encode ``x`` and wrap the latents in an ``IdentityDistribution``.

        When slicing is enabled and the batch has more than one sample, each
        sample goes through the runtime separately. The outputs are then
        concatenated along dim 0.

        Extra keyword arguments are accepted for API compatibility and ignored.
        """
        if self.use_slicing and x.shape[0] > 1:
            latent = torch.cat(list(map(self.forward, x.split(1))))
        else:
            latent = self.forward(x)
        return IdentityDistribution(latent)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RBLNRuntimeCosmosVAEDecoder(RBLNPytorchRuntime):
    """Runtime wrapper around a compiled Cosmos VAE decoder.

    ``self.use_slicing`` is read here but not assigned in this class; presumably
    the owning model sets it (TODO confirm).
    """

    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Decode ``z``, optionally one sample at a time.

        With slicing enabled and more than one sample in the batch, each sample is
        decoded separately and the results are concatenated along dim 0.
        Otherwise the whole batch is decoded in one runtime call.

        Extra keyword arguments are accepted for API compatibility and ignored.
        """
        if not (self.use_slicing and z.shape[0] > 1):
            return self.forward(z)
        pieces = []
        for piece in z.split(1):
            pieces.append(self.forward(piece))
        return torch.cat(pieces)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class _VAEDecoder(torch.nn.Module):
|
|
61
|
+
def __init__(self, vae: "AutoencoderKL"):
|
|
62
|
+
super().__init__()
|
|
63
|
+
self.vae = vae
|
|
64
|
+
|
|
65
|
+
def forward(self, z):
|
|
66
|
+
vae_out = self.vae.decode(z, return_dict=False)
|
|
67
|
+
return vae_out
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class _VAETemporalDecoder(torch.nn.Module):
|
|
71
|
+
def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
|
|
72
|
+
super().__init__()
|
|
73
|
+
self.vae = vae
|
|
74
|
+
self.num_frames = None
|
|
75
|
+
|
|
76
|
+
def forward(self, z):
|
|
77
|
+
vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
|
|
78
|
+
return vae_out
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class _VAEEncoder(torch.nn.Module):
|
|
82
|
+
def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
|
|
83
|
+
super().__init__()
|
|
84
|
+
self.vae = vae
|
|
85
|
+
|
|
86
|
+
def encode(self, x: torch.FloatTensor, return_dict: bool = True):
|
|
87
|
+
if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
|
|
88
|
+
if self.use_tiling and (
|
|
89
|
+
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
|
|
90
|
+
):
|
|
91
|
+
return self.tiled_encode(x, return_dict=return_dict)
|
|
92
|
+
|
|
93
|
+
if self.use_slicing and x.shape[0] > 1:
|
|
94
|
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
|
95
|
+
h = torch.cat(encoded_slices)
|
|
96
|
+
else:
|
|
97
|
+
h = self.encoder(x)
|
|
98
|
+
if self.quant_conv is not None:
|
|
99
|
+
h = self.quant_conv(h)
|
|
100
|
+
|
|
101
|
+
else:
|
|
102
|
+
h = self.encoder(x)
|
|
103
|
+
if self.quant_conv is not None:
|
|
104
|
+
h = self.quant_conv(h)
|
|
105
|
+
return h
|
|
106
|
+
|
|
107
|
+
def forward(self, x):
|
|
108
|
+
vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
|
|
109
|
+
return vae_out
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class _VAECosmosEncoder(torch.nn.Module):
|
|
113
|
+
def __init__(self, vae: "AutoencoderKLCosmos"):
|
|
114
|
+
super().__init__()
|
|
115
|
+
self.vae = vae
|
|
116
|
+
|
|
117
|
+
def forward(self, x):
|
|
118
|
+
vae_out = self.vae._encode(x)
|
|
119
|
+
return vae_out
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class _VAECosmosDecoder(torch.nn.Module):
|
|
123
|
+
def __init__(self, vae: "AutoencoderKLCosmos"):
|
|
124
|
+
super().__init__()
|
|
125
|
+
self.vae = vae
|
|
126
|
+
|
|
127
|
+
def forward(self, z):
|
|
128
|
+
vae_out = self.vae._decode(z, return_dict=False)
|
|
129
|
+
return vae_out
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
    """Runtime wrapper around a compiled VQModel encoder."""

    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Run the compiled encoder on a contiguous copy/view of ``x`` and return its output.

        Extra keyword arguments are accepted for API compatibility and ignored.
        """
        return self.forward(x.contiguous())
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
    """Runtime wrapper around a compiled VQModel decoder.

    ``self.lookup_from_codebook`` is read here but not set in this class;
    presumably the owning model assigns it (TODO confirm).
    """

    def decode(self, h: torch.Tensor, force_not_quantize: bool = False, shape=None, **kwargs) -> List[torch.Tensor]:
        """Decode ``h`` with the compiled graph.

        Only the path that skips quantization is supported. The call is accepted
        only when ``force_not_quantize`` is True and ``lookup_from_codebook`` is False.

        Args:
            h: Latents to decode.
            force_not_quantize: Must be True; otherwise a ``ValueError`` is raised.
            shape: Accepted for API compatibility; unused.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            ``(dec, commit_loss)``, where ``commit_loss`` is a zero tensor with one
            entry per batch element, on ``h``'s device and dtype.

        Raises:
            ValueError: If the requested configuration would require quantization.
        """
        supported = force_not_quantize and not self.lookup_from_codebook
        if not supported:
            raise ValueError(
                "Currently, the `decode` method of the class `RBLNVQModel` is executed successfully only if `force_not_quantize` is True and `config.lookup_from_codebook` is False"
            )
        commit_loss = torch.zeros(h.shape[0], device=h.device, dtype=h.dtype)
        decoded = self.forward(h.contiguous())
        return decoded, commit_loss
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class _VQEncoder(torch.nn.Module):
|
|
150
|
+
def __init__(self, vq_model: "VQModel"):
|
|
151
|
+
super().__init__()
|
|
152
|
+
self.vq_model = vq_model
|
|
153
|
+
|
|
154
|
+
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
|
155
|
+
h = self.vq_model.encoder(x)
|
|
156
|
+
h = self.vq_model.quant_conv(h)
|
|
157
|
+
return h
|
|
158
|
+
|
|
159
|
+
def forward(self, x: torch.Tensor):
|
|
160
|
+
vq_out = self.encode(x)
|
|
161
|
+
return vq_out
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class _VQDecoder(torch.nn.Module):
|
|
165
|
+
def __init__(self, vq_model: "VQModel"):
|
|
166
|
+
super().__init__()
|
|
167
|
+
self.vq_model = vq_model
|
|
168
|
+
|
|
169
|
+
def decode(self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None):
|
|
170
|
+
quant = h
|
|
171
|
+
quant2 = self.vq_model.post_quant_conv(quant)
|
|
172
|
+
quant = quant if self.vq_model.config.norm_type == "spatial" else None
|
|
173
|
+
dec = self.vq_model.decoder(quant2, quant)
|
|
174
|
+
return dec
|
|
175
|
+
|
|
176
|
+
def forward(self, h: torch.Tensor):
|
|
177
|
+
vq_out = self.decode(h)
|
|
178
|
+
return vq_out
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Any, List, Union
|
|
16
|
+
|
|
17
|
+
import rebel
|
|
18
|
+
import torch
|
|
19
|
+
from diffusers import VQModel
|
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
|
21
|
+
from diffusers.models.autoencoders.vq_model import VQEncoderOutput
|
|
22
|
+
|
|
23
|
+
from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
|
|
24
|
+
from ....modeling import RBLNModel
|
|
25
|
+
from ....utils.logging import get_logger
|
|
26
|
+
from ...configurations.models.configuration_vq_model import RBLNVQModelConfig
|
|
27
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
28
|
+
from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
33
|
+
|
|
34
|
+
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class RBLNVQModel(RBLNModel):
    """
    RBLN implementation of VQModel for diffusion models.

    This model is used to accelerate VQModel models from diffusers library on RBLN NPUs.
    It can be configured to include both encoder and decoder, or just the decoder part for latent-to-image
    conversion.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    auto_model_class = VQModel
    config_name = "config.json"
    hf_library_name = "diffusers"

    def __post_init__(self, **kwargs):
        """Wrap the compiled runtimes.

        ``self.model[0]`` is the encoder when one was compiled (``uses_encoder``).
        ``self.model[-1]`` is always the decoder.
        """
        super().__post_init__(**kwargs)

        if self.rbln_config.uses_encoder:
            self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
        else:
            self.encoder = None

        self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[-1], main_input_name="z")
        # RBLNRuntimeVQDecoder.decode checks this flag before running.
        self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
        self.image_size = self.rbln_config.image_size

    @classmethod
    def get_compiled_model(cls, model, rbln_config: RBLNModelConfig):
        """Compile the encoder (when requested) and the decoder of ``model``.

        ``expected_models`` is ordered like the compile configs registered in
        ``_update_rbln_config`` (encoder first, then decoder). That keeps
        ``rbln_config.compile_cfgs[i]`` aligned with ``expected_models[i]``.

        Returns:
            A dict mapping ``"encoder"``/``"decoder"`` to the compiled models.
        """
        if rbln_config.uses_encoder:
            expected_models = ["encoder", "decoder"]
        else:
            expected_models = ["decoder"]

        compiled_models = {}
        for i, model_name in enumerate(expected_models):
            if model_name == "encoder":
                wrapped_model = _VQEncoder(model)
            else:
                wrapped_model = _VQDecoder(model)

            wrapped_model.eval()

            compiled_models[model_name] = cls.compile(
                wrapped_model,
                rbln_compile_config=rbln_config.compile_cfgs[i],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map[model_name],
            )

        return compiled_models

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """No pipeline-derived settings are needed; return ``rbln_config`` unchanged."""
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNVQModelConfig,
    ) -> RBLNVQModelConfig:
        """Fill in the scale factor and the static input shapes used for compilation.

        The encoder config (input ``x``) is registered only when
        ``rbln_config.uses_encoder`` is set. The decoder config (input ``h``) is
        always registered last.
        """
        if hasattr(model_config, "block_out_channels"):
            rbln_config.vqmodel_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
        else:
            # image processor default value 8 (int)
            rbln_config.vqmodel_scale_factor = 8

        compile_cfgs = []
        if rbln_config.uses_encoder:
            enc_input_info = [
                (
                    "x",
                    [
                        rbln_config.batch_size,
                        model_config.in_channels,
                        rbln_config.sample_size[0],
                        rbln_config.sample_size[1],
                    ],
                    "float32",
                )
            ]
            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
            compile_cfgs.append(enc_rbln_compile_config)

        dec_input_info = [
            (
                "h",
                [
                    rbln_config.batch_size,
                    model_config.latent_channels,
                    rbln_config.latent_sample_size[0],
                    rbln_config.latent_sample_size[1],
                ],
                "float32",
            )
        ]
        dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
        compile_cfgs.append(dec_rbln_compile_config)

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNVQModelConfig,
    ) -> List[rebel.Runtime]:
        """Create one ``rebel.Runtime`` per compiled model, on the device from ``device_map``.

        A single compiled model is taken to be the decoder. Two are taken to be
        encoder then decoder.
        """
        if len(compiled_models) == 1:
            # decoder
            expected_models = ["decoder"]
        else:
            # encoder, decoder
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
    ) -> Union[torch.FloatTensor, VQEncoderOutput]:
        """
        Encode an input image into its latent representation.

        The compiled encoder applies the VQ encoder followed by ``quant_conv``.
        No codebook quantization is performed here.

        Args:
            x: The input image to encode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.
            kwargs: Accepted for API compatibility; currently ignored.

        Returns:
            A ``VQEncoderOutput`` holding the latents, or a 1-tuple ``(latents,)``
            when ``return_dict`` is False.

        Raises:
            ValueError: If the model was compiled without an encoder.
        """
        if self.encoder is None:
            # Without this check the caller gets an opaque AttributeError on None.
            raise ValueError(
                "`RBLNVQModel.encode` requires an encoder, but this model was compiled without one "
                "(`rbln_config.uses_encoder` is False)."
            )
        latents = self.encoder.encode(x)
        if not return_dict:
            return (latents,)
        return VQEncoderOutput(latents=latents)

    def decode(
        self, h: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
    ) -> Union[torch.FloatTensor, DecoderOutput]:
        """
        Decode a latent representation back into an image.

        Args:
            h: The latent representation to decode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.
            kwargs: Forwarded to the runtime decoder. Currently decoding only
                succeeds with ``force_not_quantize=True`` and
                ``config.lookup_from_codebook`` False.

        Returns:
            A ``DecoderOutput(sample=..., commit_loss=...)``, or the tuple
            ``(dec, commit_loss)`` when ``return_dict`` is False.

        Raises:
            ValueError: Raised by the runtime decoder when the call would require quantization.
        """
        dec, commit_loss = self.decoder.decode(h, **kwargs)
        if not return_dict:
            return (dec, commit_loss)
        return DecoderOutput(sample=dec, commit_loss=commit_loss)
|