optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (79)
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -40,7 +40,7 @@ _import_structure = {
         "RBLNRobertaForMaskedLM",
         "RBLNViTForImageClassification",
     ],
-    "modeling_base": [
+    "modeling": [
         "RBLNBaseModel",
         "RBLNModel",
         "RBLNModelForQuestionAnswering",
@@ -50,7 +50,6 @@ _import_structure = {
         "RBLNModelForMaskedLM",
     ],
     "transformers": [
-        "BatchTextIteratorStreamer",
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
         "RBLNAutoModelForCausalLM",
@@ -92,12 +91,18 @@ _import_structure = {
         "RBLNUNet2DConditionModel",
         "RBLNControlNetModel",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNSD3Transformer2DModel",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+        "RBLNStableDiffusion3Pipeline",
     ],
     "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
     "modeling_diffusers": ["RBLNDiffusionMixin"],
@@ -108,16 +113,31 @@ if TYPE_CHECKING:
         RBLNAutoencoderKL,
         RBLNControlNetModel,
         RBLNMultiControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
         RBLNUNet2DConditionModel,
     )
+    from .modeling import (
+        RBLNBaseModel,
+        RBLNModel,
+        RBLNModelForAudioClassification,
+        RBLNModelForImageClassification,
+        RBLNModelForMaskedLM,
+        RBLNModelForQuestionAnswering,
+        RBLNModelForSequenceClassification,
+    )
     from .modeling_alias import (
         RBLNASTForAudioClassification,
         RBLNBertForQuestionAnswering,
@@ -128,19 +148,9 @@ if TYPE_CHECKING:
         RBLNViTForImageClassification,
         RBLNXLMRobertaForSequenceClassification,
     )
-    from .modeling_base import (
-        RBLNBaseModel,
-        RBLNModel,
-        RBLNModelForAudioClassification,
-        RBLNModelForImageClassification,
-        RBLNModelForMaskedLM,
-        RBLNModelForQuestionAnswering,
-        RBLNModelForSequenceClassification,
-    )
     from .modeling_config import RBLNCompileConfig, RBLNConfig
     from .modeling_diffusers import RBLNDiffusionMixin
     from .transformers import (
-        BatchTextIteratorStreamer,
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
         RBLNAutoModelForCausalLM,
optimum/rbln/__version__.py CHANGED
@@ -1 +1,16 @@
-__version__ = '0.1.13'
+# file generated by setuptools_scm
+# don't change, don't track in version control
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple, Union
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.1.15'
+__version_tuple__ = version_tuple = (0, 1, 15)
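
With 0.1.15 the version module is generated by setuptools_scm, which exposes a numeric version_tuple alongside the string. A small consumer-side sketch (the feature gate is illustrative; dev builds may place strings such as "dev0" in the tuple):

    from optimum.rbln.__version__ import __version__, version_tuple

    # Tuple comparison avoids parsing the version string by hand.
    if version_tuple >= (0, 1, 15):
        print(f"optimum-rbln {__version__}: SD3 pipelines are available")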
optimum/rbln/diffusers/__init__.py CHANGED
@@ -36,27 +36,47 @@ _import_structure = {
         "RBLNStableDiffusionPipeline",
         "RBLNStableDiffusionXLPipeline",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNStableDiffusion3Pipeline",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+    ],
+    "models": [
+        "RBLNAutoencoderKL",
+        "RBLNUNet2DConditionModel",
+        "RBLNControlNetModel",
+        "RBLNSD3Transformer2DModel",
     ],
-    "models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
 }

 if TYPE_CHECKING:
-    from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+    from .models import (
+        RBLNAutoencoderKL,
+        RBLNControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNUNet2DConditionModel,
+    )
     from .pipelines import (
         RBLNMultiControlNetModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
     )
 else:
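
The new RBLNStableDiffusion3* exports add Stable Diffusion 3 support. A usage sketch, assuming the same from_pretrained(..., export=True) compile-on-load convention as the existing RBLN pipelines (the model id is a placeholder):

    from optimum.rbln import RBLNStableDiffusion3Pipeline

    # Compile for the NPU on first load, per the optimum-style convention.
    pipe = RBLNStableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder model id
        export=True,
    )
    image = pipe("a photo of an astronaut riding a horse").images[0]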
optimum/rbln/diffusers/models/__init__.py CHANGED
@@ -20,7 +20,38 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING

-from .autoencoder_kl import RBLNAutoencoderKL
-from .controlnet import RBLNControlNetModel
-from .unet_2d_condition import RBLNUNet2DConditionModel
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "autoencoders": [
+        "RBLNAutoencoderKL",
+    ],
+    "unets": [
+        "RBLNUNet2DConditionModel",
+    ],
+    "controlnet": ["RBLNControlNetModel"],
+    "transformers": ["RBLNSD3Transformer2DModel"],
+}
+if TYPE_CHECKING:
+    from .autoencoders import (
+        RBLNAutoencoderKL,
+    )
+    from .controlnet import RBLNControlNetModel
+    from .transformers import (
+        RBLNSD3Transformer2DModel,
+    )
+    from .unets import (
+        RBLNUNet2DConditionModel,
+    )
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
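
models/__init__.py now routes imports through transformers.utils._LazyModule, so importing optimum.rbln.diffusers.models no longer eagerly imports every submodule. The effect is roughly that of a PEP 562 module-level __getattr__ (illustrative sketch, not the package code):

    # Deferred import: the submodule is loaded the first time the name is used.
    def __getattr__(name):
        if name == "RBLNAutoencoderKL":
            from .autoencoders import RBLNAutoencoderKL
            return RBLNAutoencoderKL
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")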
optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py RENAMED
@@ -21,5 +21,4 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from .streamers import BatchTextIteratorStreamer
-from .utils import RBLNGenerationMixin
+from .autoencoder_kl import RBLNAutoencoderKL
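
Note that this rename also retires the generation helpers: BatchTextIteratorStreamer and RBLNGenerationMixin disappear from the public surface in 0.1.15 (files 68-69 in the list above remove their sources). Downstream code can guard the import; the fallback below is a plain transformers streamer and is not batch-aware, so it is only a stopgap:

    try:
        from optimum.rbln import BatchTextIteratorStreamer  # removed in 0.1.15
    except ImportError:
        # Single-sequence stand-in; adapt batched generation code accordingly.
        from transformers import TextIteratorStreamer as BatchTextIteratorStreamer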
optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} RENAMED
@@ -22,19 +22,18 @@
 # from Rebellions Inc.

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

 import rebel
 import torch  # noqa: I001
 from diffusers import AutoencoderKL
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from transformers import PretrainedConfig

-from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
-from ...utils.context import override_auto_classes
-from ...utils.runtime_utils import RBLNPytorchRuntime
+from ....modeling import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
+from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder


 if TYPE_CHECKING:
@@ -44,30 +43,22 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        moments = self.forward(x.contiguous())
-        posterior = DiagonalGaussianDistribution(moments)
-        return AutoencoderKLOutput(latent_dist=posterior)
-
-
-class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
-    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        return (self.forward(z),)
-
-
 class RBLNAutoencoderKL(RBLNModel):
+    auto_model_class = AutoencoderKL
     config_name = "config.json"
+    hf_library_name = "diffusers"

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

-        if self.rbln_config.model_cfg.get("img2img_pipeline"):
+        if self.rbln_config.model_cfg.get("img2img_pipeline") or self.rbln_config.model_cfg.get("inpaint_pipeline"):
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")

+        self.image_size = self.rbln_config.model_cfg["sample_size"]
+
     @classmethod
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
         def compile_img2img():
@@ -89,16 +80,40 @@ class RBLNAutoencoderKL(RBLNModel):

         return dec_compiled_model

-        if rbln_config.model_cfg.get("img2img_pipeline"):
+        if rbln_config.model_cfg.get("img2img_pipeline") or rbln_config.model_cfg.get("inpaint_pipeline"):
             return compile_img2img()
         else:
             return compile_text2img()

     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        with override_auto_classes(config_func=AutoencoderKL.load_config, model_func=AutoencoderKL.from_pretrained):
-            rt = super().from_pretrained(*args, **kwargs)
-        return rt
+    def get_vae_sample_size(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
+        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        if (image_size[0] is None) != (image_size[1] is None):
+            raise ValueError("Both image height and image width must be given or not given")
+        elif image_size[0] is None and image_size[1] is None:
+            if rbln_config["img2img_pipeline"]:
+                sample_size = pipe.vae.config.sample_size
+            elif rbln_config["inpaint_pipeline"]:
+                sample_size = pipe.unet.config.sample_size * pipe.vae_scale_factor
+            else:
+                # In case of text2img, sample size of vae decoder is determined by unet.
+                unet_sample_size = pipe.unet.config.sample_size
+                if isinstance(unet_sample_size, int):
+                    sample_size = unet_sample_size * pipe.vae_scale_factor
+                else:
+                    sample_size = (
+                        unet_sample_size[0] * pipe.vae_scale_factor,
+                        unet_sample_size[1] * pipe.vae_scale_factor,
+                    )
+        else:
+            sample_size = (image_size[0], image_size[1])
+
+        return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        rbln_config.update({"sample_size": cls.get_vae_sample_size(pipe, rbln_config)})
+        return rbln_config

     @classmethod
     def _get_rbln_config(
@@ -109,6 +124,8 @@ class RBLNAutoencoderKL(RBLNModel):
     ) -> RBLNConfig:
         rbln_batch_size = rbln_kwargs.get("batch_size")
         sample_size = rbln_kwargs.get("sample_size")
+        is_img2img = rbln_kwargs.get("img2img_pipeline")
+        is_inpaint = rbln_kwargs.get("inpaint_pipeline")

         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -119,6 +136,8 @@ class RBLNAutoencoderKL(RBLNModel):
         if isinstance(sample_size, int):
             sample_size = (sample_size, sample_size)

+        rbln_kwargs["sample_size"] = sample_size
+
         if hasattr(model_config, "block_out_channels"):
             vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
         else:
@@ -128,7 +147,7 @@ class RBLNAutoencoderKL(RBLNModel):
         dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
         enc_shape = (sample_size[0], sample_size[1])

-        if rbln_kwargs["img2img_pipeline"]:
+        if is_img2img or is_inpaint:
             vae_enc_input_info = [
                 (
                     "x",
@@ -191,36 +210,3 @@ class RBLNAutoencoderKL(RBLNModel):

     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         return self.decoder.decode(z)
-
-
-class _VAEDecoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def forward(self, z):
-        vae_out = self.vae.decode(z, return_dict=False)
-        return vae_out
-
-
-class _VAEEncoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
-        if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-            h = torch.cat(encoded_slices)
-        else:
-            h = self.encoder(x)
-
-        moments = self.quant_conv(h)
-        return moments
-
-    def forward(self, x):
-        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
-        return vae_out
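
The sample_size plumbing above is easiest to follow with concrete numbers. A worked example using typical SD 1.5 shapes (values assumed for illustration; the formulas mirror get_vae_sample_size and _get_rbln_config):

    block_out_channels = [128, 256, 512, 512]              # vae.config.block_out_channels
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 8
    unet_sample_size = 64                                  # unet.config.sample_size (latent)
    sample_size = unet_sample_size * vae_scale_factor      # 512 pixels (text2img case)
    dec_shape = (sample_size // vae_scale_factor,) * 2     # (64, 64) decoder latent grid
    enc_shape = (sample_size,) * 2                         # (512, 512) encoder pixels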
optimum/rbln/diffusers/models/autoencoders/vae.py ADDED
@@ -0,0 +1,84 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch  # noqa: I001
+from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+
+from ....utils.runtime_utils import RBLNPytorchRuntime
+
+
+if TYPE_CHECKING:
+    import torch
+
+logger = logging.getLogger(__name__)
+
+
+class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
+    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        moments = self.forward(x.contiguous())
+        posterior = DiagonalGaussianDistribution(moments)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+
+class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
+    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        return (self.forward(z),)
+
+
+class _VAEDecoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, z):
+        vae_out = self.vae.decode(z, return_dict=False)
+        return vae_out
+
+
+class _VAEEncoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+        if self.quant_conv is not None:
+            h = self.quant_conv(h)
+        return h
+
+    def forward(self, x):
+        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
+        return vae_out
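
Two details of the new vae.py deserve a flag. First, _VAEEncoder.forward calls encode unbound, passing the wrapped AutoencoderKL as self, so use_tiling, use_slicing, tile_sample_min_size and quant_conv all resolve on the real diffusers module rather than the wrapper. Second, the quant_conv is not None guard lets the same wrapper serve VAEs built without a quantization conv (SD3's, for instance). The unbound-call trick in miniature:

    class Real:                      # stands in for diffusers' AutoencoderKL
        scale = 3

    class Wrapper:                   # stands in for _VAEEncoder
        def probe(self, x):          # `self` is actually a Real instance here
            return x * self.scale

    print(Wrapper.probe(Real(), 2))  # 6 -- attributes resolve on Real, not Wrapper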
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -21,6 +21,7 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+import importlib
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union

@@ -28,9 +29,9 @@ import torch
 from diffusers import ControlNetModel
 from transformers import PretrainedConfig

-from ...modeling_base import RBLNModel
+from ...modeling import RBLNModel
 from ...modeling_config import RBLNCompileConfig, RBLNConfig
-from ...utils.context import override_auto_classes
+from ...modeling_diffusers import RBLNDiffusionMixin


 if TYPE_CHECKING:
@@ -104,21 +105,15 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):


 class RBLNControlNetModel(RBLNModel):
+    hf_library_name = "diffusers"
+    auto_model_class = ControlNetModel
+
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
             item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )

-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        with override_auto_classes(
-            config_func=ControlNetModel.load_config,
-            model_func=ControlNetModel.from_pretrained,
-        ):
-            rt = super().from_pretrained(*args, **kwargs)
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
@@ -131,6 +126,38 @@ class RBLNControlNetModel(RBLNModel):
         else:
             return _ControlNetModel(model).eval()

+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
+        rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
+        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+
+        batch_size = rbln_config.get("batch_size")
+        if not batch_size:
+            do_classifier_free_guidance = (
+                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
+            )
+            batch_size = 2 if do_classifier_free_guidance else 1
+        else:
+            if rbln_config.get("guidance_scale"):
+                logger.warning(
+                    "guidance_scale is ignored because batch size is explicitly specified. "
+                    "To ensure consistent behavior, consider removing the guidance scale or "
+                    "adjusting the batch size configuration as needed."
+                )
+
+        rbln_config.update(
+            {
+                "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
+                "text_model_hidden_size": text_model_hidden_size,
+                "vae_sample_size": rbln_vae_cls.get_vae_sample_size(pipe, rbln_config),
+                "unet_sample_size": rbln_unet_cls.get_unet_sample_size(pipe, rbln_config),
+                "batch_size": batch_size,
+            }
+        )
+
+        return rbln_config
+
     @classmethod
     def _get_rbln_config(
         cls,
@@ -207,6 +234,10 @@

         return rbln_config

+    @property
+    def compiled_batch_size(self):
+        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -217,9 +248,18 @@
         added_cond_kwargs: Dict[str, torch.Tensor] = {},
         **kwargs,
     ):
-        """
-        The [`ControlNetModel`] forward method.
-        """
+        sample_batch_size = sample.size()[0]
+        compiled_batch_size = self.compiled_batch_size
+        if sample_batch_size != compiled_batch_size and (
+            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
+        ):
+            raise ValueError(
+                f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
+                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
+            )
+
         added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
         if self.use_encoder_hidden_states:
             output = super().forward(
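
The new forward-time check guards exactly the classifier-free-guidance doubling that update_rbln_config_using_pipe anticipates at compile time. The arithmetic, spelled out with illustrative values:

    user_batch_size = 1
    guidance_scale = 7.5                    # > 1.0, as in the default heuristic above
    do_cfg = guidance_scale > 1.0           # and unet.config.time_cond_proj_dim is None
    runtime_batch = user_batch_size * (2 if do_cfg else 1)  # cond + uncond = 2
    compiled_batch = 2 if do_cfg else 1     # what the compile-time heuristic picks
    assert runtime_batch == compiled_batch  # otherwise forward() raises ValueError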
optimum/rbln/diffusers/models/transformers/__init__.py ADDED
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .transformer_sd3 import RBLNSD3Transformer2DModel