optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
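As a rough orientation for the reorganized 0.2.0 layout (diffusers integration moved under optimum/rbln/diffusers/, new Stable Diffusion 3 and inpaint pipelines), the sketch below shows how a pipeline might be compiled and reloaded. The pipeline class name and the rbln_-prefixed keyword arguments follow the RBLN<ClassName> convention and the rbln_config keys visible in this diff ("img_height", "img_width", "guidance_scale", "batch_size"), but the exact call signature is an assumption, not something stated in this diff.

# Hypothetical usage sketch; class name and kwargs are assumptions.
from optimum.rbln import RBLNStableDiffusionPipeline

pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    export=True,               # compile the Hugging Face checkpoint for RBLN NPUs
    rbln_img_height=512,       # maps to the "img_height" key used by get_vae_sample_size below
    rbln_img_width=512,
    rbln_guidance_scale=7.5,   # decides whether submodules are compiled with batch size 2 (CFG)
)
pipe.save_pretrained("./sd15-rbln")

image = pipe("a watercolor painting of a lighthouse").images[0]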
optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py}
@@ -22,19 +22,18 @@
 # from Rebellions Inc.

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import rebel
 import torch # noqa: I001
 from diffusers import AutoencoderKL
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from transformers import PretrainedConfig

-from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
-from ...utils.context import override_auto_classes
-from ...utils.runtime_utils import RBLNPytorchRuntime
+from ....modeling import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ...modeling_diffusers import RBLNDiffusionMixin
+from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder


 if TYPE_CHECKING:
@@ -44,30 +43,22 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        moments = self.forward(x.contiguous())
-        posterior = DiagonalGaussianDistribution(moments)
-        return AutoencoderKLOutput(latent_dist=posterior)
-
-
-class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
-    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        return (self.forward(z),)
-
-
 class RBLNAutoencoderKL(RBLNModel):
+    auto_model_class = AutoencoderKL
     config_name = "config.json"
+    hf_library_name = "diffusers"

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

-        if self.rbln_config.model_cfg.get("img2img_pipeline"):
+        if self.rbln_config.model_cfg.get("img2img_pipeline") or self.rbln_config.model_cfg.get("inpaint_pipeline"):
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")

+        self.image_size = self.rbln_config.model_cfg["sample_size"]
+
     @classmethod
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
         def compile_img2img():
@@ -89,16 +80,53 @@ class RBLNAutoencoderKL(RBLNModel):

             return dec_compiled_model

-        if rbln_config.model_cfg.get("img2img_pipeline"):
+        if rbln_config.model_cfg.get("img2img_pipeline") or rbln_config.model_cfg.get("inpaint_pipeline"):
             return compile_img2img()
         else:
             return compile_text2img()

     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        with override_auto_classes(config_func=AutoencoderKL.load_config, model_func=AutoencoderKL.from_pretrained):
-            rt = super().from_pretrained(*args, **kwargs)
-        return rt
+    def get_vae_sample_size(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
+        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        noise_module = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
+        vae_scale_factor = (
+            pipe.vae_scale_factor
+            if hasattr(pipe, "vae_scale_factor")
+            else 2 ** (len(pipe.vae.config.block_out_channels) - 1)
+        )
+
+        if noise_module is None:
+            raise AttributeError(
+                "Cannot find noise processing or predicting module attributes. ex. U-Net, Transformer, ..."
+            )
+
+        if (image_size[0] is None) != (image_size[1] is None):
+            raise ValueError("Both image height and image width must be given or not given")
+
+        elif image_size[0] is None and image_size[1] is None:
+            if rbln_config["img2img_pipeline"]:
+                sample_size = noise_module.config.sample_size
+            elif rbln_config["inpaint_pipeline"]:
+                sample_size = noise_module.config.sample_size * vae_scale_factor
+            else:
+                # In case of text2img, sample size of vae decoder is determined by unet.
+                noise_module_sample_size = noise_module.config.sample_size
+                if isinstance(noise_module_sample_size, int):
+                    sample_size = noise_module_sample_size * vae_scale_factor
+                else:
+                    sample_size = (
+                        noise_module_sample_size[0] * vae_scale_factor,
+                        noise_module_sample_size[1] * vae_scale_factor,
+                    )
+        else:
+            sample_size = (image_size[0], image_size[1])
+
+        return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        rbln_config.update({"sample_size": cls.get_vae_sample_size(pipe, rbln_config)})
+        return rbln_config

     @classmethod
     def _get_rbln_config(
@@ -109,6 +137,8 @@ class RBLNAutoencoderKL(RBLNModel):
     ) -> RBLNConfig:
         rbln_batch_size = rbln_kwargs.get("batch_size")
         sample_size = rbln_kwargs.get("sample_size")
+        is_img2img = rbln_kwargs.get("img2img_pipeline")
+        is_inpaint = rbln_kwargs.get("inpaint_pipeline")

         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -119,6 +149,8 @@ class RBLNAutoencoderKL(RBLNModel):
         if isinstance(sample_size, int):
             sample_size = (sample_size, sample_size)

+        rbln_kwargs["sample_size"] = sample_size
+
         if hasattr(model_config, "block_out_channels"):
             vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
         else:
@@ -128,7 +160,7 @@ class RBLNAutoencoderKL(RBLNModel):
         dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
         enc_shape = (sample_size[0], sample_size[1])

-        if rbln_kwargs["img2img_pipeline"]:
+        if is_img2img or is_inpaint:
             vae_enc_input_info = [
                 (
                     "x",
@@ -173,15 +205,28 @@ class RBLNAutoencoderKL(RBLNModel):

     @classmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
     ) -> List[rebel.Runtime]:
         if len(compiled_models) == 1:
+            if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
+                cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
+
             device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-            return [compiled_models[0].create_runtime(tensor_type="pt", device=device_val)]
+            return [
+                compiled_models[0].create_runtime(
+                    tensor_type="pt", device=device_val, activate_profiler=activate_profiler
+                )
+            ]
+
+        if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+            cls._raise_missing_compiled_file_error(["encoder", "decoder"])

         device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
         return [
-            compiled_model.create_runtime(tensor_type="pt", device=device_val)
+            compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
             for compiled_model, device_val in zip(compiled_models, device_vals)
         ]

@@ -191,36 +236,3 @@ class RBLNAutoencoderKL(RBLNModel):

     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         return self.decoder.decode(z)
-
-
-class _VAEDecoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def forward(self, z):
-        vae_out = self.vae.decode(z, return_dict=False)
-        return vae_out
-
-
-class _VAEEncoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
-        super().__init__()
-        self.vae = vae
-
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
-        if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-            h = torch.cat(encoded_slices)
-        else:
-            h = self.encoder(x)
-
-        moments = self.quant_conv(h)
-        return moments
-
-    def forward(self, x):
-        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
-        return vae_out
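For orientation, the text2img branch of get_vae_sample_size above reduces to multiplying the UNet's latent sample size by the VAE scale factor. A small worked example with typical Stable Diffusion 1.5 values (assumed for illustration, not taken from this diff):

# Typical SD 1.5 values, assumed for illustration only.
unet_sample_size = 64                                          # unet.config.sample_size (latent resolution)
vae_scale_factor = 2 ** (4 - 1)                                # len(vae.config.block_out_channels) == 4 -> 8
text2img_sample_size = unet_sample_size * vae_scale_factor    # 512: VAE decoder output resolution
inpaint_sample_size = unet_sample_size * vae_scale_factor     # inpaint branch uses the same multiplication
print(text2img_sample_size)                                    # 512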
optimum/rbln/diffusers/models/autoencoders/vae.py (new file)
@@ -0,0 +1,83 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch # noqa: I001
+from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+
+from ....utils.runtime_utils import RBLNPytorchRuntime
+
+
+if TYPE_CHECKING:
+    import torch
+
+logger = logging.getLogger(__name__)
+
+
+class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
+    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        moments = self.forward(x.contiguous())
+        posterior = DiagonalGaussianDistribution(moments)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+
+class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
+    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        return (self.forward(z),)
+
+
+class _VAEDecoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, z):
+        vae_out = self.vae.decode(z, return_dict=False)
+        return vae_out
+
+
+class _VAEEncoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKL"):
+        super().__init__()
+        self.vae = vae
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+        if self.quant_conv is not None:
+            h = self.quant_conv(h)
+        return h
+
+    def forward(self, x):
+        vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
+        return vae_out
optimum/rbln/diffusers/models/controlnet.py
@@ -21,6 +21,7 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+import importlib
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union

@@ -28,9 +29,9 @@ import torch
 from diffusers import ControlNetModel
 from transformers import PretrainedConfig

-from ...modeling_base import RBLNModel
+from ...modeling import RBLNModel
 from ...modeling_config import RBLNCompileConfig, RBLNConfig
-from ...utils.context import override_auto_classes
+from ..modeling_diffusers import RBLNDiffusionMixin


 if TYPE_CHECKING:
@@ -104,21 +105,15 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):


 class RBLNControlNetModel(RBLNModel):
+    hf_library_name = "diffusers"
+    auto_model_class = ControlNetModel
+
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
             item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )

-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        with override_auto_classes(
-            config_func=ControlNetModel.load_config,
-            model_func=ControlNetModel.from_pretrained,
-        ):
-            rt = super().from_pretrained(*args, **kwargs)
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
@@ -131,6 +126,38 @@ class RBLNControlNetModel(RBLNModel):
         else:
             return _ControlNetModel(model).eval()

+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
+        rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
+        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+
+        batch_size = rbln_config.get("batch_size")
+        if not batch_size:
+            do_classifier_free_guidance = (
+                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
+            )
+            batch_size = 2 if do_classifier_free_guidance else 1
+        else:
+            if rbln_config.get("guidance_scale"):
+                logger.warning(
+                    "guidance_scale is ignored because batch size is explicitly specified. "
+                    "To ensure consistent behavior, consider removing the guidance scale or "
+                    "adjusting the batch size configuration as needed."
+                )
+
+        rbln_config.update(
+            {
+                "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
+                "text_model_hidden_size": text_model_hidden_size,
+                "vae_sample_size": rbln_vae_cls.get_vae_sample_size(pipe, rbln_config),
+                "unet_sample_size": rbln_unet_cls.get_unet_sample_size(pipe, rbln_config),
+                "batch_size": batch_size,
+            }
+        )
+
+        return rbln_config
+
     @classmethod
     def _get_rbln_config(
         cls,
@@ -207,6 +234,10 @@

         return rbln_config

+    @property
+    def compiled_batch_size(self):
+        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -217,9 +248,18 @@
         added_cond_kwargs: Dict[str, torch.Tensor] = {},
         **kwargs,
     ):
-        """
-        The [`ControlNetModel`] forward method.
-        """
+        sample_batch_size = sample.size()[0]
+        compiled_batch_size = self.compiled_batch_size
+        if sample_batch_size != compiled_batch_size and (
+            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
+        ):
+            raise ValueError(
+                f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
+                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
+            )
+
         added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
         if self.use_encoder_hidden_states:
             output = super().forward(
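The batch-size logic added in update_rbln_config_using_pipe and enforced in forward above follows classifier-free guidance: unless a batch size is given explicitly, the ControlNet is compiled for batch 2 whenever guidance_scale > 1.0, because CFG doubles the runtime batch. A condensed sketch of that rule (the time_cond_proj_dim check is omitted; values are illustrative):

from typing import Optional

def default_compile_batch(guidance_scale: float = 5.0, batch_size: Optional[int] = None) -> int:
    # Mirrors the rule above: an explicit batch size wins and guidance_scale is ignored;
    # otherwise classifier-free guidance (guidance_scale > 1.0) compiles for batch 2.
    if batch_size:
        return batch_size
    do_classifier_free_guidance = guidance_scale > 1.0
    return 2 if do_classifier_free_guidance else 1

assert default_compile_batch() == 2                      # CFG assumed by default
assert default_compile_batch(guidance_scale=1.0) == 1    # no CFG -> batch 1
assert default_compile_batch(batch_size=4) == 4          # explicit batch size wins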
optimum/rbln/diffusers/models/transformers/__init__.py (new file)
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .transformer_sd3 import RBLNSD3Transformer2DModel
optimum/rbln/diffusers/models/transformers/transformer_sd3.py (new file)
@@ -0,0 +1,203 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import torch
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from transformers import PretrainedConfig
+
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class SD3Transformer2DModelWrapper(torch.nn.Module):
+    def __init__(self, model: "SD3Transformer2DModel") -> None:
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        # need controlnet support?
+        block_controlnet_hidden_states: List = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ):
+        return self.model(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            pooled_projections=pooled_projections,
+            timestep=timestep,
+            return_dict=False,
+        )
+
+
+class RBLNSD3Transformer2DModel(RBLNModel):
+    hf_library_name = "diffusers"
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+        return SD3Transformer2DModelWrapper(model).eval()
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        sample_size = rbln_config.get("sample_size", pipe.default_sample_size)
+        img_width = rbln_config.get("img_width")
+        img_height = rbln_config.get("img_height")
+
+        if (img_width is None) ^ (img_height is None):
+            raise RuntimeError
+
+        elif img_width and img_height:
+            sample_size = img_height // pipe.vae_scale_factor, img_width // pipe.vae_scale_factor
+
+        prompt_max_length = rbln_config.get("max_sequence_length", 256)
+        prompt_embed_length = pipe.tokenizer_max_length + prompt_max_length
+
+        batch_size = rbln_config.get("batch_size")
+        if not batch_size:
+            do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
+            batch_size = 2 if do_classifier_free_guidance else 1
+        else:
+            if rbln_config.get("guidance_scale"):
+                logger.warning(
+                    "guidance_scale is ignored because batch size is explicitly specified. "
+                    "To ensure consistent behavior, consider removing the guidance scale or "
+                    "adjusting the batch size configuration as needed."
+                )
+
+        rbln_config.update(
+            {
+                "batch_size": batch_size,
+                "prompt_embed_length": prompt_embed_length,
+                "sample_size": sample_size,
+            }
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
+        sample_size = rbln_kwargs.get("sample_size", model_config.sample_size)
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        rbln_prompt_embed_length = rbln_kwargs.get("prompt_embed_length")
+        if rbln_prompt_embed_length is None:
+            raise ValueError("rbln_prompt_embed_length should be specified.")
+
+        input_info = [
+            (
+                "hidden_states",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    sample_size[0],
+                    sample_size[1],
+                ],
+                "float32",
+            ),
+            (
+                "encoder_hidden_states",
+                [
+                    rbln_batch_size,
+                    rbln_prompt_embed_length,
+                    model_config.joint_attention_dim,
+                ],
+                "float32",
+            ),
+            (
+                "pooled_projections",
+                [
+                    rbln_batch_size,
+                    model_config.pooled_projection_dim,
+                ],
+                "float32",
+            ),
+            ("timestep", [rbln_batch_size], "float32"),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"batch_size": rbln_batch_size})
+
+        return rbln_config
+
+    @property
+    def compiled_batch_size(self):
+        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        block_controlnet_hidden_states: List = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        **kwargs,
+    ):
+        sample_batch_size = hidden_states.size()[0]
+        compiled_batch_size = self.compiled_batch_size
+        if sample_batch_size != compiled_batch_size and (
+            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
+        ):
+            raise ValueError(
+                f"Mismatch between Transformers' runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
+                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
+            )
+
+        sample = super().forward(hidden_states, encoder_hidden_states, pooled_projections, timestep)
+        return Transformer2DModelOutput(sample=sample)
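For reference, the input_info above compiles the SD3 transformer with static shapes, where the encoder_hidden_states length is tokenizer_max_length + max_sequence_length. Illustrative shapes with typical SD3 (medium) configuration values, which are assumptions rather than values stated in this diff:

# Assumed, typical SD3 configuration values for illustration only.
batch_size = 2                                                           # default when classifier-free guidance is on
prompt_embed_length = 77 + 256                                           # tokenizer_max_length + max_sequence_length = 333
hidden_states_shape = (batch_size, 16, 128, 128)                         # (B, in_channels, sample_size[0], sample_size[1])
encoder_hidden_states_shape = (batch_size, prompt_embed_length, 4096)    # joint_attention_dim
pooled_projections_shape = (batch_size, 2048)                            # pooled_projection_dim
timestep_shape = (batch_size,)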
optimum/rbln/diffusers/models/unets/__init__.py (new file)
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .unet_2d_condition import RBLNUNet2DConditionModel