optimum-rbln 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (29)
  1. optimum/rbln/__init__.py +6 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +7 -0
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
  5. optimum/rbln/diffusers/models/controlnet.py +93 -23
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
  7. optimum/rbln/diffusers/pipelines/__init__.py +7 -2
  8. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
  10. optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
  15. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
  18. optimum/rbln/modeling_base.py +36 -3
  19. optimum/rbln/modeling_seq2seq.py +19 -4
  20. optimum/rbln/transformers/generation/__init__.py +1 -0
  21. optimum/rbln/transformers/generation/streamers.py +17 -0
  22. optimum/rbln/transformers/generation/utils.py +399 -0
  23. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
  24. optimum/rbln/transformers/models/llama/modeling_llama.py +63 -45
  25. optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
  26. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/METADATA +1 -1
  27. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/RECORD +29 -25
  28. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/WHEEL +0 -0
  29. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -63,6 +63,9 @@ _import_structure = {
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionControlNetPipeline",
+        "RBLNStableDiffusionXLControlNetPipeline",
+        "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
     "modeling_config": ["RBLNRuntimeConfig", "RBLNConfig"],
 }
@@ -73,8 +76,11 @@ if TYPE_CHECKING:
         RBLNControlNetModel,
         RBLNMultiControlNetModel,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
+        RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
         RBLNStableDiffusionPipeline,
+        RBLNStableDiffusionXLControlNetImg2ImgPipeline,
+        RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
         RBLNStableDiffusionXLPipeline,
         RBLNUNet2DConditionModel,
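
Note: the two lists serve different consumers, so every new pipeline must be registered in both — `_import_structure` feeds the runtime lazy-module loader, while the `TYPE_CHECKING` block exposes the same names to static type checkers. A minimal sketch of that pattern, assuming the `_LazyModule` helper from transformers that optimum-style packages commonly use (the wiring shown here is illustrative, not copied from this file):

    import sys
    from typing import TYPE_CHECKING

    from transformers.utils import _LazyModule

    _import_structure = {"pipelines": ["RBLNStableDiffusionControlNetPipeline"]}

    if TYPE_CHECKING:
        # Static analyzers resolve the real symbol eagerly...
        from .pipelines import RBLNStableDiffusionControlNetPipeline
    else:
        # ...while at runtime the module becomes a lazy proxy that imports
        # .pipelines only when the attribute is first accessed.
        sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)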
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = '0.1.0'
+__version__ = '0.1.1'
optimum/rbln/diffusers/__init__.py CHANGED
@@ -39,17 +39,24 @@ _import_structure = {
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionControlNetPipeline",
+        "RBLNStableDiffusionXLControlNetPipeline",
+        "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
     "models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
 }
 
 if TYPE_CHECKING:
+
     from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
     from .pipelines import (
         RBLNMultiControlNetModel,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
+        RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
         RBLNStableDiffusionPipeline,
+        RBLNStableDiffusionXLControlNetImg2ImgPipeline,
+        RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
         RBLNStableDiffusionXLPipeline,
     )
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -88,14 +88,23 @@ class RBLNAutoencoderKL(RBLNModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNAutoencoderKL":
         task = kwargs.pop("task", None)
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
 
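Note: the new `model_save_dir` parameter lets a caller — typically a pipeline compiling several submodels — direct all artifacts into one shared directory instead of a per-model `TemporaryDirectory`. A standalone restatement of the branching above (the helper name is hypothetical):

    from pathlib import Path
    from tempfile import TemporaryDirectory
    from typing import Optional, Tuple, Union

    def resolve_save_dir(
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]],
    ) -> Tuple[Union[str, Path, TemporaryDirectory], Path]:
        if model_save_dir is None:
            # No caller-provided location: own a throwaway directory.
            save_dir = TemporaryDirectory()
            return save_dir, Path(save_dir.name)
        if isinstance(model_save_dir, TemporaryDirectory):
            # A shared TemporaryDirectory handed down from a pipeline.
            return model_save_dir, Path(model_save_dir.name)
        # A plain str/Path: make sure it exists before writing artifacts.
        save_dir_path = Path(model_save_dir)
        save_dir_path.mkdir(exist_ok=True)
        return model_save_dir, save_dir_path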
@@ -119,7 +128,7 @@ class RBLNAutoencoderKL(RBLNModel):
         if not isinstance(config, PretrainedConfig):  # diffusers config
             config = PretrainedConfig(**config)
 
-        config.save_pretrained(save_dir_path)
+        config.save_pretrained(save_dir_path / subfolder)
         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
 
         # Get compilation arguments
@@ -137,8 +146,12 @@ class RBLNAutoencoderKL(RBLNModel):
             enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
             dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])
 
-            enc_compiled_model.save(save_dir_path / f"{rbln_config['encoder'][0].compiled_model_name}.rbln")
-            dec_compiled_model.save(save_dir_path / f"{rbln_config['decoder'][0].compiled_model_name}.rbln")
+            enc_compiled_model.save(
+                save_dir_path / subfolder / f"{rbln_config['encoder'][0].compiled_model_name}.rbln"
+            )
+            dec_compiled_model.save(
+                save_dir_path / subfolder / f"{rbln_config['decoder'][0].compiled_model_name}.rbln"
+            )
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
@@ -146,19 +159,27 @@ class RBLNAutoencoderKL(RBLNModel):
 
             dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])
 
-            dec_compiled_model.save(save_dir_path / f"{rbln_config['compiled_model'][0].compiled_model_name}.rbln")
+            dec_compiled_model.save(
+                save_dir_path / subfolder / f"{rbln_config['compiled_model'][0].compiled_model_name}.rbln"
+            )
 
         if rbln_config_kwargs.get("rbln_use_encode"):
             compile_img2img()
         else:
             compile_text2img()
 
-        rbln_config.save(save_dir_path)
+        rbln_config.save(save_dir_path / subfolder)
 
         return cls._from_pretrained(
             model_id=save_dir_path,
             config=config,
             model_save_dir=save_dir,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            subfolder=subfolder,
+            local_files_only=local_files_only,
             **rbln_constructor_kwargs,
             **kwargs,
         )
@@ -216,7 +237,7 @@ class RBLNAutoencoderKL(RBLNModel):
         meta["rbln_img_height"] = rbln_img_height
 
         vae_enc_input_info = [
-            ("x", [rbln_batch_size, model_config.in_channels, rbln_img_width, rbln_img_height], "float32")
+            ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
         ]
         vae_dec_input_info = [
             (
@@ -224,8 +245,8 @@
                 [
                     rbln_batch_size,
                     model_config.latent_channels,
-                    rbln_img_width // rbln_vae_scale_factor,
                     rbln_img_height // rbln_vae_scale_factor,
+                    rbln_img_width // rbln_vae_scale_factor,
                 ],
                 "float32",
             )
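
Note: both shape fixes in this file are the same NCHW correction — PyTorch image tensors are laid out [batch, channels, height, width], so height must precede width; the old order produced transposed shapes for non-square images. A worked example (the 3 input channels, 4 latent channels, and scale factor 8 are the usual Stable Diffusion VAE values, assumed here for illustration):

    batch, width, height, scale = 1, 512, 768, 8
    enc_shape = [batch, 3, height, width]                    # [1, 3, 768, 512], not [1, 3, 512, 768]
    dec_shape = [batch, 4, height // scale, width // scale]  # [1, 4, 96, 64]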
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -23,9 +23,8 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Dict, Optional, Union
 
-import rebel
 import torch
 from diffusers import ControlNetModel
 from optimum.exporters import TasksManager
@@ -46,6 +45,37 @@ class _ControlNetModel(torch.nn.Module):
         super().__init__()
         self.controlnet = controlnet
 
+    def forward(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale,
+        text_embeds: Optional[torch.Tensor] = None,
+        time_ids: Optional[torch.Tensor] = None,
+    ):
+        if text_embeds is not None and time_ids is not None:
+            added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
+        else:
+            added_cond_kwargs = {}
+
+        down_block_res_samples, mid_block_res_sample = self.controlnet(
+            sample=sample,
+            timestep=timestep,
+            encoder_hidden_states=None,
+            controlnet_cond=controlnet_cond,
+            conditioning_scale=conditioning_scale,
+            added_cond_kwargs=added_cond_kwargs,
+            return_dict=False,
+        )
+        return down_block_res_samples, mid_block_res_sample
+
+
+class _ControlNetModel_Cross_Attention(torch.nn.Module):
+    def __init__(self, controlnet: "ControlNetModel"):
+        super().__init__()
+        self.controlnet = controlnet
+
     def forward(
         self,
         sample: torch.Tensor,
@@ -53,13 +83,21 @@ class _ControlNetModel(torch.nn.Module):
         encoder_hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
         conditioning_scale,
+        text_embeds: Optional[torch.Tensor] = None,
+        time_ids: Optional[torch.Tensor] = None,
     ):
+        if text_embeds is not None and time_ids is not None:
+            added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
+        else:
+            added_cond_kwargs = {}
+
         down_block_res_samples, mid_block_res_sample = self.controlnet(
             sample=sample,
             timestep=timestep,
             encoder_hidden_states=encoder_hidden_states,
             controlnet_cond=controlnet_cond,
             conditioning_scale=conditioning_scale,
+            added_cond_kwargs=added_cond_kwargs,
             return_dict=False,
         )
         return down_block_res_samples, mid_block_res_sample
@@ -71,6 +109,9 @@ class RBLNControlNetModel(RBLNModel):
 
     def __post_init__(self, **kwargs):
         self.dtype = torch.float32
+        self.use_encoder_hidden_states = any(
+            item[0] == "encoder_hidden_states" for item in self.rbln_config["compiled_model"][0].input_info
+        )
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
@@ -94,14 +135,16 @@ class RBLNControlNetModel(RBLNModel):
         return rt
 
     @classmethod
-    def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
-        compiled_model = rebel.compile_from_torch(
-            _ControlNetModel(model),
-            input_info=rbln_runtime_config.input_info,
-            batch_size=rbln_runtime_config.batch_size,
-            fusion=rbln_runtime_config.fusion,
-        )
-        return compiled_model
+    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+        use_encoder_hidden_states = False
+        for down_block in model.down_blocks:
+            if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
+                break
+
+        if use_encoder_hidden_states:
+            return _ControlNetModel_Cross_Attention(model).eval()
+        else:
+            return _ControlNetModel(model).eval()
 
     @classmethod
     def _get_rbln_config(
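
Note: the class-specific `compile` override is replaced by a `wrap_model_if_needed` hook (compilation itself presumably moves to the shared base class), and the wrapper is chosen by probing the down blocks: any block with `has_cross_attention` means the ControlNet consumes `encoder_hidden_states`, so its traced signature must include that tensor. The same probe as a standalone predicate:

    import torch

    def needs_encoder_hidden_states(controlnet: torch.nn.Module) -> bool:
        # Cross-attention down blocks attend over text embeddings; pure
        # DownBlock2D stacks do not, so the traced input signature differs.
        return any(getattr(block, "has_cross_attention", False) for block in controlnet.down_blocks)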
@@ -109,6 +152,7 @@ class RBLNControlNetModel(RBLNModel):
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
         rbln_max_seq_len: Optional[int] = None,
+        rbln_text_model_hidden_size: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
         rbln_img_width: Optional[int] = None,
         rbln_img_height: Optional[int] = None,
@@ -132,12 +176,18 @@ class RBLNControlNetModel(RBLNModel):
                     [
                         rbln_batch_size,
                         model_config.in_channels,
-                        input_width,
                         input_height,
+                        input_width,
                     ],
                     "float32",
                 ),
                 ("timestep", [], "float32"),
+            ],
+            batch_size=rbln_batch_size,
+        )
+        use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
+        if use_encoder_hidden_states:
+            rbln_runtime_config.input_info.append(
                 (
                     "encoder_hidden_states",
                     [
@@ -146,12 +196,20 @@ class RBLNControlNetModel(RBLNModel):
                         model_config.cross_attention_dim,
                     ],
                     "float32",
-                ),
-                ("controlnet_cond", [rbln_batch_size, 3, rbln_img_width, rbln_img_height], "float32"),
-                ("conditioning_scale", [], "float32"),
-            ],
-            batch_size=rbln_batch_size,
+                )
+            )
+        rbln_runtime_config.input_info.append(
+            ("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32")
         )
+        rbln_runtime_config.input_info.append(("conditioning_scale", [], "float32"))
+        if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
+            if rbln_text_model_hidden_size is None:
+                rbln_text_model_hidden_size = 768
+            rbln_runtime_config.input_info.append(
+                ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
+            )
+            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
         rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
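Note: for an SDXL-style ControlNet (`addition_embed_type == "text_time"`) the compiled graph therefore gains two trailing inputs. An illustrative tail of the resulting `input_info` for batch size 1 — 768 is only the fallback used above; in practice `rbln_text_model_hidden_size` should come from the text encoder's hidden size:

    input_info_tail = [
        ("text_embeds", [1, 768], "float32"),  # pooled text embedding
        ("time_ids", [1, 6], "float32"),       # SDXL micro-conditioning: original, crop, target sizes
    ]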
@@ -162,18 +220,30 @@ class RBLNControlNetModel(RBLNModel):
         encoder_hidden_states: torch.Tensor,
         controlnet_cond: torch.FloatTensor,
         conditioning_scale: torch.Tensor = 1.0,
+        added_cond_kwargs: Dict[str, torch.Tensor] = {},
         **kwargs,
     ):
         """
         The [`ControlNetModel`] forward method.
         """
-        output = super().forward(
-            sample.contiguous(),
-            timestep.float(),
-            encoder_hidden_states,
-            controlnet_cond,
-            torch.tensor(conditioning_scale),
-        )
+        added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
+        if self.use_encoder_hidden_states:
+            output = super().forward(
+                sample.contiguous(),
+                timestep.float(),
+                encoder_hidden_states,
+                controlnet_cond,
+                torch.tensor(conditioning_scale),
+                **added_cond_kwargs,
+            )
+        else:
+            output = super().forward(
+                sample.contiguous(),
+                timestep.float(),
+                controlnet_cond,
+                torch.tensor(conditioning_scale),
+                **added_cond_kwargs,
+            )
         down_block_res_samples = output[:-1]
         mid_block_res_sample = output[-1]
 
optimum/rbln/diffusers/models/unet_2d_condition.py CHANGED
@@ -27,7 +27,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
@@ -90,22 +90,28 @@ class _UNet_SDXL(torch.nn.Module):
         sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        text_embeds: Optional[torch.Tensor] = None,
-        time_ids: Optional[torch.Tensor] = None,
         *down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
     ) -> torch.Tensor:
-        if text_embeds is not None and time_ids is not None:
-            added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
-        else:
-            added_cond_kwargs = {}
-
-        if len(down_and_mid_block_additional_residuals) != 0:
+        if len(down_and_mid_block_additional_residuals) == 2:
+            added_cond_kwargs = {
+                "text_embeds": down_and_mid_block_additional_residuals[0],
+                "time_ids": down_and_mid_block_additional_residuals[1],
+            }
+            down_block_additional_residuals = None
+            mid_block_additional_residual = None
+        elif len(down_and_mid_block_additional_residuals) > 2:
+            added_cond_kwargs = {
+                "text_embeds": down_and_mid_block_additional_residuals[-2],
+                "time_ids": down_and_mid_block_additional_residuals[-1],
+            }
             down_block_additional_residuals, mid_block_additional_residual = (
-                down_and_mid_block_additional_residuals[:-1],
-                down_and_mid_block_additional_residuals[-1],
+                down_and_mid_block_additional_residuals[:-3],
+                down_and_mid_block_additional_residuals[-3],
             )
         else:
-            down_block_additional_residuals, mid_block_additional_residual = None, None
+            added_cond_kwargs = {}
+            down_block_additional_residuals = None
+            mid_block_additional_residual = None
 
         unet_out = self.unet(
             sample=sample,
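
Note: moving `text_embeds`/`time_ids` from keyword arguments onto the end of the variadic tail keeps the traced graph's signature purely positional, and the dispatch above distinguishes the cases by argument count alone. A compact mirror of that logic (the helper name is hypothetical):

    def split_tail(*tail):
        # Exactly 2 tensors: SDXL extras only (text_embeds, time_ids).
        if len(tail) == 2:
            return None, None, {"text_embeds": tail[0], "time_ids": tail[1]}
        # More than 2: ControlNet down residuals, then the mid residual,
        # then the two SDXL extras appended last.
        if len(tail) > 2:
            extras = {"text_embeds": tail[-2], "time_ids": tail[-1]}
            return tail[:-3], tail[-3], extras
        # Empty tail: plain SD, no extras and no residuals.
        return None, None, {}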
@@ -197,9 +203,11 @@ class RBLNUNet2DConditionModel(RBLNModel):
         meta["rbln_use_encode"] = rbln_use_encode
 
         if rbln_use_encode:
+            # FIXME :: robust img shape getter
             input_width = rbln_img_width // rbln_vae_scale_factor
             input_height = rbln_img_height // rbln_vae_scale_factor
         else:
+            # FIXME :: model_config.sample_size can be tuple or list
             input_width, input_height = model_config.sample_size, model_config.sample_size
 
         input_info = [
@@ -208,8 +216,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 [
                     rbln_batch_size,
                     model_config.in_channels,
-                    input_width,
                     input_height,
+                    input_width,
                 ],
                 "float32",
             ),
@@ -225,64 +233,73 @@ class RBLNUNet2DConditionModel(RBLNModel):
             ),
         ]
         if rbln_is_controlnet:
-            input_info.extend(
-                [
-                    (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[0], input_width, input_height],
-                        "float32",
-                    )
-                    for i in range(3)
-                ]
-            )
-            input_info.append(
-                (
-                    f"down_block_additional_residuals_{3}",
-                    [rbln_batch_size, model_config.block_out_channels[0], input_width // 2, input_height // 2],
-                    "float32",
+            if len(model_config.block_out_channels) > 0:
+                input_info.extend(
+                    [
+                        (
+                            f"down_block_additional_residuals_{i}",
+                            [rbln_batch_size, model_config.block_out_channels[0], input_height, input_width],
+                            "float32",
+                        )
+                        for i in range(3)
+                    ]
                 )
-            )
-            input_info.extend(
-                [
+                input_info.append(
                     (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[1], input_width // 2, input_height // 2],
+                        "down_block_additional_residuals_3",
+                        [rbln_batch_size, model_config.block_out_channels[0], input_height // 2, input_width // 2],
                         "float32",
                     )
-                    for i in range(4, 6)
-                ]
-            )
-            input_info.append(
-                (
-                    f"down_block_additional_residuals_{6}",
-                    [rbln_batch_size, model_config.block_out_channels[1], input_width // 4, input_height // 4],
-                    "float32",
                 )
-            )
-            input_info.extend(
-                [
-                    (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[2], input_width // 4, input_height // 4],
-                        "float32",
-                    )
-                    for i in range(7, 9)
-                ]
-            )
-            input_info.extend(
-                [
+            if len(model_config.block_out_channels) > 1:
+                input_info.extend(
+                    [
+                        (
+                            f"down_block_additional_residuals_{i}",
+                            [rbln_batch_size, model_config.block_out_channels[1], input_height // 2, input_width // 2],
+                            "float32",
+                        )
+                        for i in range(4, 6)
+                    ]
+                )
+                input_info.append(
                     (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[3], input_width // 8, input_height // 8],
+                        f"down_block_additional_residuals_{6}",
+                        [rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
                         "float32",
                     )
-                    for i in range(9, 12)
-                ]
-            )
+                )
+            if len(model_config.block_out_channels) > 2:
+                input_info.extend(
+                    [
+                        (
+                            f"down_block_additional_residuals_{i}",
+                            [rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
+                            "float32",
+                        )
+                        for i in range(7, 9)
+                    ]
+                )
+            if len(model_config.block_out_channels) > 3:
+                input_info.extend(
+                    [
+                        (
+                            f"down_block_additional_residuals_{i}",
+                            [rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
+                            "float32",
+                        )
+                        for i in range(9, 12)
+                    ]
+                )
             input_info.append(
                 (
                     "mid_block_additional_residual",
-                    [rbln_batch_size, model_config.block_out_channels[3], input_width // 8, input_height // 8],
+                    [
+                        rbln_batch_size,
+                        model_config.block_out_channels[-1],
+                        input_height // 2 ** (len(model_config.block_out_channels) - 1),
+                        input_width // 2 ** (len(model_config.block_out_channels) - 1),
+                    ],
                     "float32",
                 )
             )
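
Note: the old code hard-coded a four-entry `block_out_channels` (hence the `[3]` index and `// 8` divisor); the guarded version only declares residual inputs for blocks that exist and derives the mid-block shape from the config, since each down block after the first halves the spatial size. For the standard SD UNet configuration:

    block_out_channels = [320, 640, 1280, 1280]   # typical SD 1.x UNet config
    factor = 2 ** (len(block_out_channels) - 1)   # three 2x downsamples -> 8
    # mid residual shape: [batch, block_out_channels[-1], height // factor, width // factor]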
@@ -344,7 +361,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
 
         return (
             super().forward(
-                sample,
+                sample.contiguous(),
                 timestep.float(),
                 encoder_hidden_states,
                 **added_cond_kwargs,
optimum/rbln/diffusers/pipelines/__init__.py CHANGED
@@ -21,9 +21,14 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .controlnet import RBLNMultiControlNetModel
-from .stable_diffusion import (
+from .controlnet import (
+    RBLNMultiControlNetModel,
     RBLNStableDiffusionControlNetImg2ImgPipeline,
+    RBLNStableDiffusionControlNetPipeline,
+    RBLNStableDiffusionXLControlNetImg2ImgPipeline,
+    RBLNStableDiffusionXLControlNetPipeline,
+)
+from .stable_diffusion import (
     RBLNStableDiffusionImg2ImgPipeline,
     RBLNStableDiffusionPipeline,
 )
optimum/rbln/diffusers/pipelines/controlnet/__init__.py CHANGED
@@ -22,3 +22,7 @@
 # from Rebellions Inc.
 
 from .multicontrolnet import RBLNMultiControlNetModel
+from .pipeline_controlnet import RBLNStableDiffusionControlNetPipeline
+from .pipeline_controlnet_img2img import RBLNStableDiffusionControlNetImg2ImgPipeline
+from .pipeline_controlnet_sd_xl import RBLNStableDiffusionXLControlNetPipeline
+from .pipeline_controlnet_sd_xl_img2img import RBLNStableDiffusionXLControlNetImg2ImgPipeline
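
Note: with these re-exports in place, the new pipelines load like their diffusers counterparts. A hedged end-to-end sketch — the model IDs are placeholders, and the `controlnet=`, `export=True`, and `rbln_*` kwargs are assumed to follow the conventions of the existing RBLN pipelines rather than anything shown in this diff:

    from diffusers import ControlNetModel
    from diffusers.utils import load_image
    from optimum.rbln import RBLNStableDiffusionControlNetPipeline

    # Assumed to mirror the diffusers calling convention for ControlNet pipelines.
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
    pipe = RBLNStableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        export=True,            # compile for the RBLN NPU on first load
        rbln_img_width=512,
        rbln_img_height=512,
    )
    cond = load_image("path/to/canny_edges.png")  # placeholder conditioning image
    image = pipe("a bird on a branch", image=cond).images[0]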