optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -38,9 +38,9 @@ _import_structure = {
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNRobertaForSequenceClassification",
         "RBLNRobertaForMaskedLM",
-        "RBLNViTForImageClassification"
+        "RBLNViTForImageClassification",
     ],
-    "modeling_base": [
+    "modeling": [
         "RBLNBaseModel",
         "RBLNModel",
         "RBLNModelForQuestionAnswering",
@@ -50,7 +50,6 @@ _import_structure = {
         "RBLNModelForMaskedLM",
     ],
     "transformers": [
-        "BatchTextIteratorStreamer",
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
         "RBLNAutoModelForCausalLM",
@@ -76,6 +75,7 @@ _import_structure = {
         "RBLNQwen2ForCausalLM",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNT5EncoderModel",
         "RBLNT5ForConditionalGeneration",
         "RBLNPhiForCausalLM",
         "RBLNLlavaNextForConditionalGeneration",
@@ -91,14 +91,21 @@ _import_structure = {
         "RBLNUNet2DConditionModel",
         "RBLNControlNetModel",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNSD3Transformer2DModel",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+        "RBLNStableDiffusion3Pipeline",
     ],
     "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
+    "modeling_diffusers": ["RBLNDiffusionMixin"],
 }
 
 if TYPE_CHECKING:
@@ -106,16 +113,31 @@ if TYPE_CHECKING:
         RBLNAutoencoderKL,
         RBLNControlNetModel,
         RBLNMultiControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
         RBLNUNet2DConditionModel,
     )
+    from .modeling import (
+        RBLNBaseModel,
+        RBLNModel,
+        RBLNModelForAudioClassification,
+        RBLNModelForImageClassification,
+        RBLNModelForMaskedLM,
+        RBLNModelForQuestionAnswering,
+        RBLNModelForSequenceClassification,
+    )
     from .modeling_alias import (
         RBLNASTForAudioClassification,
         RBLNBertForQuestionAnswering,
@@ -126,18 +148,9 @@ if TYPE_CHECKING:
         RBLNViTForImageClassification,
         RBLNXLMRobertaForSequenceClassification,
     )
-    from .modeling_base import (
-        RBLNBaseModel,
-        RBLNModel,
-        RBLNModelForAudioClassification,
-        RBLNModelForImageClassification,
-        RBLNModelForMaskedLM,
-        RBLNModelForQuestionAnswering,
-        RBLNModelForSequenceClassification,
-    )
     from .modeling_config import RBLNCompileConfig, RBLNConfig
+    from .modeling_diffusers import RBLNDiffusionMixin
     from .transformers import (
-        BatchTextIteratorStreamer,
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
         RBLNAutoModelForCausalLM,
@@ -166,6 +179,7 @@
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
         RBLNQwen2ForCausalLM,
+        RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
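For orientation, the classes that moved from `modeling_base` to `modeling`, together with the new SD3 and inpaint pipelines, are all re-exported from the package root. A rough usage sketch against the 0.1.15 surface follows; the checkpoint id, the `export` flag, and the `rbln_*` sizing kwargs are illustrative assumptions, not taken from this diff:

```python
# Hypothetical sketch of the 0.1.15 import surface; arguments are placeholders.
from optimum.rbln import (
    RBLNModel,                           # now re-exported from optimum.rbln.modeling
    RBLNStableDiffusion3Pipeline,        # new SD3 pipeline export
    RBLNStableDiffusionInpaintPipeline,  # new inpaint pipeline export
    RBLNT5EncoderModel,                  # new T5 encoder export
)

pipe = RBLNStableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder checkpoint
    export=True,          # assumed flag: compile the checkpoint for RBLN NPUs
    rbln_img_width=1024,  # assumed sizing kwargs consumed by the RBLN config
    rbln_img_height=1024,
)
image = pipe("a watercolor of a lighthouse").images[0]
```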
optimum/rbln/__version__.py CHANGED
@@ -1 +1,16 @@
-__version__ = '0.1.12'
+# file generated by setuptools_scm
+# don't change, don't track in version control
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple, Union
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.1.15'
+__version_tuple__ = version_tuple = (0, 1, 15)
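The hand-maintained version string is replaced by a setuptools_scm-generated module. A small sketch of how the generated metadata can be read, assuming an installed wheel:

```python
# Names below match the generated module shown above.
from optimum.rbln.__version__ import __version__, __version_tuple__

print(__version__)        # '0.1.15'
print(__version_tuple__)  # (0, 1, 15)

# Equivalent lookup via installed-distribution metadata.
from importlib.metadata import version
print(version("optimum-rbln"))
```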
optimum/rbln/diffusers/__init__.py CHANGED
@@ -36,27 +36,47 @@ _import_structure = {
         "RBLNStableDiffusionPipeline",
         "RBLNStableDiffusionXLPipeline",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNStableDiffusion3Pipeline",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+    ],
+    "models": [
+        "RBLNAutoencoderKL",
+        "RBLNUNet2DConditionModel",
+        "RBLNControlNetModel",
+        "RBLNSD3Transformer2DModel",
     ],
-    "models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
 }
 
 if TYPE_CHECKING:
-    from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+    from .models import (
+        RBLNAutoencoderKL,
+        RBLNControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNUNet2DConditionModel,
+    )
     from .pipelines import (
         RBLNMultiControlNetModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
     )
 else:
optimum/rbln/diffusers/models/__init__.py CHANGED
@@ -20,7 +20,38 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING
 
-from .autoencoder_kl import RBLNAutoencoderKL
-from .controlnet import RBLNControlNetModel
-from .unet_2d_condition import RBLNUNet2DConditionModel
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "autoencoders": [
+        "RBLNAutoencoderKL",
+    ],
+    "unets": [
+        "RBLNUNet2DConditionModel",
+    ],
+    "controlnet": ["RBLNControlNetModel"],
+    "transformers": ["RBLNSD3Transformer2DModel"],
+}
+if TYPE_CHECKING:
+    from .autoencoders import (
+        RBLNAutoencoderKL,
+    )
+    from .controlnet import RBLNControlNetModel
+    from .transformers import (
+        RBLNSD3Transformer2DModel,
+    )
+    from .unets import (
+        RBLNUNet2DConditionModel,
+    )
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
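The models package now mirrors diffusers' subpackage layout (autoencoders/, unets/, transformers/) and routes imports through transformers' `_LazyModule`, so the heavy model modules load only on first attribute access. A minimal sketch of the observable effect, assuming the package is importable:

```python
# Sketch of the effect of the _LazyModule indirection above (not part of the diff).
import optimum.rbln.diffusers.models as models  # the model modules are not imported yet

vae_cls = models.RBLNAutoencoderKL  # first attribute access imports the autoencoders subpackage
print(vae_cls.__module__)  # -> optimum.rbln.diffusers.models.autoencoders.autoencoder_kl
```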
optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py RENAMED
@@ -21,5 +21,4 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .streamers import BatchTextIteratorStreamer
-from .utils import RBLNGenerationMixin
+from .autoencoder_kl import RBLNAutoencoderKL
optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} RENAMED
@@ -22,20 +22,18 @@
 # from Rebellions Inc.
 
 import logging
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
 
 import rebel
 import torch  # noqa: I001
 from diffusers import AutoencoderKL
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig
 
-from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
-from ...utils.runtime_utils import RBLNPytorchRuntime
+from ....modeling import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
+from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
 
 
 if TYPE_CHECKING:
@@ -45,31 +43,22 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        moments = self.forward(x.contiguous())
-        posterior = DiagonalGaussianDistribution(moments)
-        return AutoencoderKLOutput(latent_dist=posterior)
-
-
-class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
-    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        return (self.forward(z),)
-
-
 class RBLNAutoencoderKL(RBLNModel):
+    auto_model_class = AutoencoderKL
     config_name = "config.json"
+    hf_library_name = "diffusers"
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
-        if self.rbln_use_encode:
+        if self.rbln_config.model_cfg.get("img2img_pipeline") or self.rbln_config.model_cfg.get("inpaint_pipeline"):
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")
 
+        self.image_size = self.rbln_config.model_cfg["sample_size"]
+
     @classmethod
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
         def compile_img2img():
@@ -91,39 +80,40 @@ class RBLNAutoencoderKL(RBLNModel):
 
             return dec_compiled_model
 
-        if rbln_config.model_cfg.get("use_encode", False):
+        if rbln_config.model_cfg.get("img2img_pipeline") or rbln_config.model_cfg.get("inpaint_pipeline"):
             return compile_img2img()
         else:
             return compile_text2img()
 
     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
-        ):
-            return AutoencoderKL.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-
-        if kwargs.get("export", None):
-            # This is an ad-hoc to workaround save null values of the config.
-            # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
-            # and diffusers model do not support loading by AutoConfig.
-            AutoConfig.from_pretrained = lambda *args, **kwargs: None
+    def get_vae_sample_size(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
+        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        if (image_size[0] is None) != (image_size[1] is None):
+            raise ValueError("Both image height and image width must be given or not given")
+        elif image_size[0] is None and image_size[1] is None:
+            if rbln_config["img2img_pipeline"]:
+                sample_size = pipe.vae.config.sample_size
+            elif rbln_config["inpaint_pipeline"]:
+                sample_size = pipe.unet.config.sample_size * pipe.vae_scale_factor
+            else:
+                # In case of text2img, sample size of vae decoder is determined by unet.
+                unet_sample_size = pipe.unet.config.sample_size
+                if isinstance(unet_sample_size, int):
+                    sample_size = unet_sample_size * pipe.vae_scale_factor
+                else:
+                    sample_size = (
+                        unet_sample_size[0] * pipe.vae_scale_factor,
+                        unet_sample_size[1] * pipe.vae_scale_factor,
+                    )
         else:
-            AutoConfig.from_pretrained = AutoencoderKL.load_config
+            sample_size = (image_size[0], image_size[1])
 
-        AutoModel.from_pretrained = AutoencoderKL.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
-        return rt
+        return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        rbln_config.update({"sample_size": cls.get_vae_sample_size(pipe, rbln_config)})
+        return rbln_config
 
     @classmethod
     def _get_rbln_config(
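The new get_vae_sample_size/update_rbln_config_using_pipe hooks resolve the VAE's compile-time resolution from the pipeline when img_height/img_width are not given. A worked sketch of the branch logic with typical SD 1.5 config values (illustrative numbers, not taken from this diff):

```python
# Rehearsal of the sample-size selection above, using typical SD 1.5 values.
unet_sample_size = 64   # pipe.unet.config.sample_size (latent-space resolution)
vae_sample_size = 512   # pipe.vae.config.sample_size  (pixel-space resolution)
vae_scale_factor = 8    # pipe.vae_scale_factor

# text2img: the decoder resolution follows the UNet, scaled back to pixel space.
assert unet_sample_size * vae_scale_factor == 512
# inpaint: same derivation (unet sample size * vae_scale_factor).
# img2img: the VAE's own configured sample size is used as-is.
assert vae_sample_size == 512
# explicit override: passing img_height=768, img_width=512 yields sample_size == (768, 512).
```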
@@ -132,34 +122,43 @@ class RBLNAutoencoderKL(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
-        rbln_img_width = rbln_kwargs.get("img_width", None)
-        rbln_img_height = rbln_kwargs.get("img_height", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_use_encode = rbln_kwargs.get("use_encode", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size")
+        sample_size = rbln_kwargs.get("sample_size")
+        is_img2img = rbln_kwargs.get("img2img_pipeline")
+        is_inpaint = rbln_kwargs.get("inpaint_pipeline")
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        model_cfg = {}
+        if sample_size is None:
+            sample_size = model_config.sample_size
+
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        rbln_kwargs["sample_size"] = sample_size
+
+        if hasattr(model_config, "block_out_channels"):
+            vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+        else:
+            # vae image processor default value 8 (int)
+            vae_scale_factor = 8
 
-        if rbln_use_encode:
-            model_cfg["img_width"] = rbln_img_width
-            model_cfg["img_height"] = rbln_img_height
+        dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
+        enc_shape = (sample_size[0], sample_size[1])
 
+        if is_img2img or is_inpaint:
             vae_enc_input_info = [
-                ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
+                (
+                    "x",
+                    [rbln_batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
+                    "float32",
+                )
             ]
             vae_dec_input_info = [
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_img_height // rbln_vae_scale_factor,
-                        rbln_img_width // rbln_vae_scale_factor,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
             ]
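_get_rbln_config now derives the scale factor from the VAE config instead of requiring it as a kwarg, and shapes both runtimes from a single sample_size. A quick arithmetic check of that derivation; the values are the SD 1.5 VAE defaults, used only for illustration:

```python
# block_out_channels for the SD 1.5 VAE has 4 entries, so the scale factor is 2**3 = 8.
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8

sample_size = (512, 512)  # pixel-space encoder input
dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
assert dec_shape == (64, 64)  # latent-space decoder input ("z")

batch_size, in_channels, latent_channels = 1, 3, 4
enc_input = ("x", [batch_size, in_channels, *sample_size], "float32")
dec_input = ("z", [batch_size, latent_channels, *dec_shape], "float32")
```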
@@ -173,33 +172,22 @@ class RBLNAutoencoderKL(RBLNModel):
                 compile_cfgs=compile_cfgs,
                 rbln_kwargs=rbln_kwargs,
             )
-            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
-        if rbln_unet_sample_size is None:
-            rbln_unet_sample_size = 64
-
-        model_cfg["unet_sample_size"] = rbln_unet_sample_size
         vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_unet_sample_size,
-                        rbln_unet_sample_size,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
-            ],
+            ]
         )
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
             compile_cfgs=[vae_config],
             rbln_kwargs=rbln_kwargs,
         )
-        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
@@ -222,36 +210,3 @@ class RBLNAutoencoderKL(RBLNModel):
 
     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         return self.decoder.decode(z)
-
-
-class _VAEDecoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def forward(self, z):
-        vae_out = self.vae.decode(z, return_dict=False)
-        return vae_out
-
-
-class _VAEEncoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
-        if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-            h = torch.cat(encoded_slices)
-        else:
-            h = self.encoder(x)
-
-        moments = self.quant_conv(h)
-        return moments
-
-    def forward(self, x):
-        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
-        return vae_out
optimum/rbln/diffusers/models/autoencoders/vae.py ADDED
@@ -0,0 +1,84 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch  # noqa: I001
+from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+
+from ....utils.runtime_utils import RBLNPytorchRuntime
+
+
+if TYPE_CHECKING:
+    import torch
+
+logger = logging.getLogger(__name__)
+
+
+class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
+    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        moments = self.forward(x.contiguous())
+        posterior = DiagonalGaussianDistribution(moments)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+
+class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
+    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        return (self.forward(z),)
+
+
+class _VAEDecoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, z):
+        vae_out = self.vae.decode(z, return_dict=False)
+        return vae_out
+
+
+class _VAEEncoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+        if self.quant_conv is not None:
+            h = self.quant_conv(h)
+        return h
+
+    def forward(self, x):
+        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
+        return vae_out
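These wrappers give the compiler plain nn.Modules whose forward returns raw tensors (the encoder's moments, the decoder's image tuple), while the RBLNRuntimeVAE* classes re-wrap compiled outputs into the diffusers return types. A rough shape-check of the wrappers against a stock diffusers VAE, with no RBLN compilation involved; the checkpoint id is a placeholder:

```python
import torch
from diffusers import AutoencoderKL

from optimum.rbln.diffusers.models.autoencoders.vae import _VAEDecoder, _VAEEncoder

# Placeholder checkpoint; any AutoencoderKL works for this shape check.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

encoder = _VAEEncoder(vae)
decoder = _VAEDecoder(vae)

x = torch.randn(1, 3, 512, 512)
moments = encoder(x)   # concatenated mean/logvar, shape (1, 8, 64, 64)

z = torch.randn(1, 4, 64, 64)
(image,) = decoder(z)  # decoded image, shape (1, 3, 512, 512)
```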