optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -38,9 +38,9 @@ _import_structure = {
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNRobertaForSequenceClassification",
         "RBLNRobertaForMaskedLM",
-        "RBLNViTForImageClassification"
+        "RBLNViTForImageClassification",
     ],
-    "
+    "modeling": [
         "RBLNBaseModel",
         "RBLNModel",
         "RBLNModelForQuestionAnswering",
@@ -50,7 +50,6 @@ _import_structure = {
         "RBLNModelForMaskedLM",
     ],
     "transformers": [
-        "BatchTextIteratorStreamer",
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
         "RBLNAutoModelForCausalLM",
@@ -76,6 +75,7 @@ _import_structure = {
         "RBLNQwen2ForCausalLM",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNT5EncoderModel",
         "RBLNT5ForConditionalGeneration",
         "RBLNPhiForCausalLM",
         "RBLNLlavaNextForConditionalGeneration",
@@ -91,14 +91,21 @@ _import_structure = {
         "RBLNUNet2DConditionModel",
         "RBLNControlNetModel",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNSD3Transformer2DModel",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+        "RBLNStableDiffusion3Pipeline",
     ],
     "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
+    "modeling_diffusers": ["RBLNDiffusionMixin"],
 }

 if TYPE_CHECKING:
@@ -106,16 +113,31 @@ if TYPE_CHECKING:
         RBLNAutoencoderKL,
         RBLNControlNetModel,
         RBLNMultiControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
         RBLNUNet2DConditionModel,
     )
+    from .modeling import (
+        RBLNBaseModel,
+        RBLNModel,
+        RBLNModelForAudioClassification,
+        RBLNModelForImageClassification,
+        RBLNModelForMaskedLM,
+        RBLNModelForQuestionAnswering,
+        RBLNModelForSequenceClassification,
+    )
     from .modeling_alias import (
         RBLNASTForAudioClassification,
         RBLNBertForQuestionAnswering,
@@ -126,18 +148,9 @@ if TYPE_CHECKING:
         RBLNViTForImageClassification,
         RBLNXLMRobertaForSequenceClassification,
     )
-    from .modeling_base import (
-        RBLNBaseModel,
-        RBLNModel,
-        RBLNModelForAudioClassification,
-        RBLNModelForImageClassification,
-        RBLNModelForMaskedLM,
-        RBLNModelForQuestionAnswering,
-        RBLNModelForSequenceClassification,
-    )
     from .modeling_config import RBLNCompileConfig, RBLNConfig
+    from .modeling_diffusers import RBLNDiffusionMixin
     from .transformers import (
-        BatchTextIteratorStreamer,
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
         RBLNAutoModelForCausalLM,
@@ -166,6 +179,7 @@ if TYPE_CHECKING:
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
         RBLNQwen2ForCausalLM,
+        RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
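For reference, a minimal import sketch of what the reworked top-level exports imply for downstream code (illustrative only; it assumes optimum-rbln 0.1.15 is installed and that the `_import_structure` above resolves these names lazily from the package root):

```python
# Illustrative import sketch based on the _import_structure changes above.
from optimum.rbln import (
    RBLNStableDiffusionInpaintPipeline,  # new inpaint pipeline export
    RBLNStableDiffusion3Pipeline,        # new Stable Diffusion 3 export
    RBLNT5EncoderModel,                  # new T5 encoder export
    RBLNBaseModel,                       # now provided by optimum.rbln.modeling (was modeling_base)
)
```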
optimum/rbln/__version__.py
CHANGED
@@ -1 +1,16 @@
-
+# file generated by setuptools_scm
+# don't change, don't track in version control
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple, Union
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.1.15'
+__version_tuple__ = version_tuple = (0, 1, 15)
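The version module is now generated by setuptools_scm. A quick sketch of reading it programmatically (assuming the 0.1.15 wheel is installed):

```python
# Reads the generated version metadata shown in the diff above.
from optimum.rbln.__version__ import __version__, version_tuple

print(__version__)    # '0.1.15'
print(version_tuple)  # (0, 1, 15)
```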
optimum/rbln/diffusers/__init__.py
CHANGED
@@ -36,27 +36,47 @@ _import_structure = {
         "RBLNStableDiffusionPipeline",
         "RBLNStableDiffusionXLPipeline",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNStableDiffusion3Pipeline",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+    ],
+    "models": [
+        "RBLNAutoencoderKL",
+        "RBLNUNet2DConditionModel",
+        "RBLNControlNetModel",
+        "RBLNSD3Transformer2DModel",
     ],
-    "models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
 }

 if TYPE_CHECKING:
-    from .models import
+    from .models import (
+        RBLNAutoencoderKL,
+        RBLNControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNUNet2DConditionModel,
+    )
     from .pipelines import (
         RBLNMultiControlNetModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
     )
 else:
optimum/rbln/diffusers/models/__init__.py
CHANGED
@@ -20,7 +20,38 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING

-from .
-
-
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "autoencoders": [
+        "RBLNAutoencoderKL",
+    ],
+    "unets": [
+        "RBLNUNet2DConditionModel",
+    ],
+    "controlnet": ["RBLNControlNetModel"],
+    "transformers": ["RBLNSD3Transformer2DModel"],
+}
+if TYPE_CHECKING:
+    from .autoencoders import (
+        RBLNAutoencoderKL,
+    )
+    from .controlnet import RBLNControlNetModel
+    from .transformers import (
+        RBLNSD3Transformer2DModel,
+    )
+    from .unets import (
+        RBLNUNet2DConditionModel,
+    )
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
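The `else` branch swaps the module object in `sys.modules` for transformers' `_LazyModule`, so submodules such as `autoencoders` and `unets` are only imported when one of their exported names is first accessed. A rough, standalone sketch of that pattern (illustrative only; this is not the `_LazyModule` implementation):

```python
# Simplified lazy-module sketch: attribute access triggers the real submodule import.
import importlib
import types


class LazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to the submodule that defines it.
        self._name_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(f"{self.__name__}.{self._name_to_module[attr]}")
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value
```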
optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py}
CHANGED
@@ -22,20 +22,18 @@
 # from Rebellions Inc.

 import logging
-from
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

 import rebel
 import torch  # noqa: I001
 from diffusers import AutoencoderKL
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
-from
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig

-from
-from
-from
+from ....modeling import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
+from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder


 if TYPE_CHECKING:
@@ -45,31 +43,22 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        moments = self.forward(x.contiguous())
-        posterior = DiagonalGaussianDistribution(moments)
-        return AutoencoderKLOutput(latent_dist=posterior)
-
-
-class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
-    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        return (self.forward(z),)
-
-
 class RBLNAutoencoderKL(RBLNModel):
+    auto_model_class = AutoencoderKL
     config_name = "config.json"
+    hf_library_name = "diffusers"

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

-        self.
-        if self.rbln_use_encode:
+        if self.rbln_config.model_cfg.get("img2img_pipeline") or self.rbln_config.model_cfg.get("inpaint_pipeline"):
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")

+        self.image_size = self.rbln_config.model_cfg["sample_size"]
+
     @classmethod
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
         def compile_img2img():
@@ -91,39 +80,40 @@ class RBLNAutoencoderKL(RBLNModel):

             return dec_compiled_model

-        if rbln_config.model_cfg.get("
+        if rbln_config.model_cfg.get("img2img_pipeline") or rbln_config.model_cfg.get("inpaint_pipeline"):
             return compile_img2img()
         else:
             return compile_text2img()

     @classmethod
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def get_vae_sample_size(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
+        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        if (image_size[0] is None) != (image_size[1] is None):
+            raise ValueError("Both image height and image width must be given or not given")
+        elif image_size[0] is None and image_size[1] is None:
+            if rbln_config["img2img_pipeline"]:
+                sample_size = pipe.vae.config.sample_size
+            elif rbln_config["inpaint_pipeline"]:
+                sample_size = pipe.unet.config.sample_size * pipe.vae_scale_factor
+            else:
+                # In case of text2img, sample size of vae decoder is determined by unet.
+                unet_sample_size = pipe.unet.config.sample_size
+                if isinstance(unet_sample_size, int):
+                    sample_size = unet_sample_size * pipe.vae_scale_factor
+                else:
+                    sample_size = (
+                        unet_sample_size[0] * pipe.vae_scale_factor,
+                        unet_sample_size[1] * pipe.vae_scale_factor,
+                    )
         else:
-
+            sample_size = (image_size[0], image_size[1])

-
-
-
-
-
-        return
+        return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        rbln_config.update({"sample_size": cls.get_vae_sample_size(pipe, rbln_config)})
+        return rbln_config

     @classmethod
     def _get_rbln_config(
@@ -132,34 +122,43 @@
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-
-
-        rbln_use_encode = rbln_kwargs.get("use_encode", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size")
+        sample_size = rbln_kwargs.get("sample_size")
+        is_img2img = rbln_kwargs.get("img2img_pipeline")
+        is_inpaint = rbln_kwargs.get("inpaint_pipeline")

         if rbln_batch_size is None:
             rbln_batch_size = 1

-
+        if sample_size is None:
+            sample_size = model_config.sample_size
+
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        rbln_kwargs["sample_size"] = sample_size
+
+        if hasattr(model_config, "block_out_channels"):
+            vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+        else:
+            # vae image processor default value 8 (int)
+            vae_scale_factor = 8

-
-
-        model_cfg["img_height"] = rbln_img_height
+        dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
+        enc_shape = (sample_size[0], sample_size[1])

+        if is_img2img or is_inpaint:
             vae_enc_input_info = [
-            (
+                (
+                    "x",
+                    [rbln_batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
+                    "float32",
+                )
             ]
             vae_dec_input_info = [
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_img_height // rbln_vae_scale_factor,
-                        rbln_img_width // rbln_vae_scale_factor,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
             ]
@@ -173,33 +172,22 @@
             compile_cfgs=compile_cfgs,
             rbln_kwargs=rbln_kwargs,
         )
-        rbln_config.model_cfg.update(model_cfg)
         return rbln_config

-        if rbln_unet_sample_size is None:
-            rbln_unet_sample_size = 64
-
-        model_cfg["unet_sample_size"] = rbln_unet_sample_size
         vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_unet_sample_size,
-                        rbln_unet_sample_size,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
-            ]
+            ]
         )
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
             compile_cfgs=[vae_config],
             rbln_kwargs=rbln_kwargs,
         )
-        rbln_config.model_cfg.update(model_cfg)
         return rbln_config

     @classmethod
@@ -222,36 +210,3 @@

     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         return self.decoder.decode(z)
-
-
-class _VAEDecoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def forward(self, z):
-        vae_out = self.vae.decode(z, return_dict=False)
-        return vae_out
-
-
-class _VAEEncoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
-        if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-            h = torch.cat(encoded_slices)
-        else:
-            h = self.encoder(x)
-
-        moments = self.quant_conv(h)
-        return moments
-
-    def forward(self, x):
-        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
-        return vae_out
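A worked example of the static-shape arithmetic introduced above (the numbers are typical SD 1.5 defaults, used purely for illustration):

```python
# Mirrors the vae_scale_factor / enc_shape / dec_shape arithmetic in _get_rbln_config.
block_out_channels = (128, 256, 512, 512)             # typical AutoencoderKL config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8

sample_size = (512, 512)                               # (img_height, img_width)
enc_shape = (sample_size[0], sample_size[1])           # pixel-space "x" input
dec_shape = (sample_size[0] // vae_scale_factor,       # latent-space "z" input
             sample_size[1] // vae_scale_factor)

assert enc_shape == (512, 512)
assert dec_shape == (64, 64)
```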
optimum/rbln/diffusers/models/autoencoders/vae.py
ADDED
@@ -0,0 +1,84 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch  # noqa: I001
+from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+
+from ....utils.runtime_utils import RBLNPytorchRuntime
+
+
+if TYPE_CHECKING:
+    import torch
+
+logger = logging.getLogger(__name__)
+
+
+class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
+    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        moments = self.forward(x.contiguous())
+        posterior = DiagonalGaussianDistribution(moments)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+
+class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
+    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        return (self.forward(z),)
+
+
+class _VAEDecoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, z):
+        vae_out = self.vae.decode(z, return_dict=False)
+        return vae_out
+
+
+class _VAEEncoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+        if self.quant_conv is not None:
+            h = self.quant_conv(h)
+        return h
+
+    def forward(self, x):
+        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
+        return vae_out
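A hypothetical usage sketch of the relocated runtime wrappers (it assumes `vae` is an RBLNAutoencoderKL compiled with the encoder enabled, i.e. for an img2img or inpaint pipeline, and that the input shapes match those fixed at compile time):

```python
import torch

# vae = ...  # hypothetical: an already-loaded, compiled RBLNAutoencoderKL instance

x = torch.randn(1, 3, 512, 512)   # pixel-space input, the compiled "x" shape
out = vae.encode(x)               # AutoencoderKLOutput, built by RBLNRuntimeVAEEncoder.encode
z = out.latent_dist.sample()      # DiagonalGaussianDistribution -> latent tensor
decoded = vae.decode(z)           # tuple, mirroring RBLNRuntimeVAEDecoder.decode
```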