optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/unet_2d_condition.py

@@ -23,16 +23,15 @@
 
 import logging
 from dataclasses import dataclass
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig
 
 from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNCompileConfig, RBLNConfig
+from ...utils.context import override_auto_classes
 
 
 if TYPE_CHECKING:
@@ -126,9 +125,6 @@ class _UNet_SDXL(torch.nn.Module):
 
 
 class RBLNUNet2DConditionModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel  # feature extraction
-
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.in_features = self.rbln_config.model_cfg.get("in_features", None)
@@ -146,29 +142,11 @@ class RBLNUNet2DConditionModel(RBLNModel):
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
+        with override_auto_classes(
+            config_func=UNet2DConditionModel.load_config,
+            model_func=UNet2DConditionModel.from_pretrained,
         ):
-            return UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-        if kwargs.get("export", None):
-            # This is an ad-hoc to workaround save null values of the config.
-            # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
-            # and diffusers model do not support loading by AutoConfig.
-            AutoConfig.from_pretrained = lambda *args, **kwargs: None
-        else:
-            AutoConfig.from_pretrained = UNet2DConditionModel.load_config
-        AutoModel.from_pretrained = UNet2DConditionModel.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
+            rt = super().from_pretrained(*args, **kwargs)
        return rt
 
     @classmethod
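Note: `override_auto_classes` comes from the new `optimum/rbln/utils/context.py` (+58 lines) and replaces the hand-rolled monkey-patching that `from_pretrained` used to do inline. A minimal sketch of such a context manager, reconstructed from the removed code; the `try/finally` restore and the dropping of the export-only `AutoConfig` branch are assumptions about the helper, not confirmed by the diff:

    from contextlib import contextmanager

    from optimum.exporters import TasksManager
    from transformers import AutoConfig, AutoModel


    @contextmanager
    def override_auto_classes(config_func, model_func):
        """Temporarily route the HF auto classes to diffusers-style loaders."""
        saved = (TasksManager.get_model_from_task, AutoConfig.from_pretrained, AutoModel.from_pretrained)
        # Route optimum's task-based model lookup to the loader passed in.
        TasksManager.get_model_from_task = lambda task, model_name_or_path, **kwargs: model_func(
            pretrained_model_name_or_path=model_name_or_path, **kwargs
        )
        AutoConfig.from_pretrained = config_func  # diffusers configs cannot be loaded via AutoConfig
        AutoModel.from_pretrained = model_func
        try:
            yield
        finally:
            # Restore the originals even if loading raises, which the removed
            # inline version did not guarantee.
            (
                TasksManager.get_model_from_task,
                AutoConfig.from_pretrained,
                AutoModel.from_pretrained,
            ) = saved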
@@ -185,137 +163,68 @@ class RBLNUNet2DConditionModel(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_in_features = rbln_kwargs.get("in_features", None)
-        rbln_use_encode = rbln_kwargs.get("use_encode", None)
-        rbln_img_width = rbln_kwargs.get("img_width", None)
-        rbln_img_height = rbln_kwargs.get("img_height", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
-        rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = 77
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        if rbln_use_encode:
-            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
-                raise ValueError(
-                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
-                )
-            input_width = rbln_img_width // rbln_vae_scale_factor
-            input_height = rbln_img_height // rbln_vae_scale_factor
-        else:
-            input_width, input_height = model_config.sample_size, model_config.sample_size
+        batch_size = rbln_kwargs.get("batch_size")
+        max_seq_len = rbln_kwargs.get("max_seq_len")
+        sample_size = rbln_kwargs.get("sample_size")
+        is_controlnet = rbln_kwargs.get("is_controlnet")
+        rbln_in_features = None
+
+        if batch_size is None:
+            batch_size = 1
+
+        if sample_size is None:
+            sample_size = model_config.sample_size
+
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        if max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
 
         input_info = [
-            (
-                "sample",
-                [
-                    rbln_batch_size,
-                    model_config.in_channels,
-                    input_height,
-                    input_width,
-                ],
-                "float32",
-            ),
+            ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
             ("timestep", [], "float32"),
-            (
-                "encoder_hidden_states",
-                [
-                    rbln_batch_size,
-                    rbln_max_seq_len,
-                    model_config.cross_attention_dim,
-                ],
-                "float32",
-            ),
+            ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
         ]
 
-        if rbln_is_controlnet:
-            if len(model_config.block_out_channels) > 0:
-                input_info.extend(
-                    [
-                        (
-                            f"down_block_additional_residuals_{i}",
-                            [rbln_batch_size, model_config.block_out_channels[0], input_height, input_width],
-                            "float32",
-                        )
-                        for i in range(3)
-                    ]
-                )
-            if len(model_config.block_out_channels) > 1:
-                input_info.append(
-                    (
-                        "down_block_additional_residuals_3",
-                        [rbln_batch_size, model_config.block_out_channels[0], input_height // 2, input_width // 2],
-                        "float32",
-                    )
-                )
-                input_info.extend(
-                    [
-                        (
-                            f"down_block_additional_residuals_{i}",
-                            [rbln_batch_size, model_config.block_out_channels[1], input_height // 2, input_width // 2],
-                            "float32",
-                        )
-                        for i in range(4, 6)
-                    ]
-                )
-            if len(model_config.block_out_channels) > 2:
-                input_info.append(
-                    (
-                        f"down_block_additional_residuals_{6}",
-                        [rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
-                        "float32",
-                    )
-                )
-                input_info.extend(
-                    [
-                        (
-                            f"down_block_additional_residuals_{i}",
-                            [rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
-                            "float32",
-                        )
-                        for i in range(7, 9)
-                    ]
-                )
-            if len(model_config.block_out_channels) > 3:
-                input_info.extend(
-                    [
-                        (
-                            f"down_block_additional_residuals_{i}",
-                            [rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
-                            "float32",
-                        )
-                        for i in range(9, 12)
-                    ]
-                )
-            input_info.append(
-                (
-                    "mid_block_additional_residual",
-                    [
-                        rbln_batch_size,
-                        model_config.block_out_channels[-1],
-                        input_height // 2 ** (len(model_config.block_out_channels) - 1),
-                        input_width // 2 ** (len(model_config.block_out_channels) - 1),
-                    ],
-                    "float32",
-                )
-            )
+        if is_controlnet:
+            # down block addtional residuals
+            first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
+            height, width = sample_size[0], sample_size[1]
+            input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
+            name_idx = 1
+            for idx, _ in enumerate(model_config.down_block_types):
+                shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                for _ in range(model_config.layers_per_block):
+                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
+                    name_idx += 1
+                if idx != len(model_config.down_block_types) - 1:
+                    height = height // 2
+                    width = width // 2
+                    shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
+                    name_idx += 1
+
+            # mid block addtional residual
+            num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
+            out_channels = model_config.block_out_channels[-1]
+            shape = [
+                batch_size,
+                out_channels,
+                sample_size[0] // 2**num_cross_attn_blocks,
+                sample_size[1] // 2**num_cross_attn_blocks,
+            ]
+            input_info.append(("mid_block_additional_residual", shape, "float32"))
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            if rbln_text_model_hidden_size is None:
-                rbln_text_model_hidden_size = 768
-            if rbln_in_features is None:
-                rbln_in_features = model_config.projection_class_embeddings_input_dim
+            rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+            rbln_in_features = model_config.projection_class_embeddings_input_dim
             rbln_compile_config.input_info.append(
-                ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
+                ("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32")
             )
-            rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            rbln_compile_config.input_info.append(("time_ids", [batch_size, 6], "float32"))
 
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
@@ -323,14 +232,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
             rbln_kwargs=rbln_kwargs,
         )
 
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "use_encode": rbln_use_encode,
-            }
-        )
-
         if rbln_in_features is not None:
             rbln_config.model_cfg["in_features"] = rbln_in_features
 
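Note: the new `_get_rbln_config` derives every ControlNet residual shape from the UNet config instead of the hard-coded `block_out_channels` indices removed above, so it generalizes beyond the standard four-block layout. A worked example of the loop with typical Stable Diffusion v1.5 values (all numbers below are illustrative, not taken from the diff):

    # Assumed SD v1.5 UNet config: 4 down blocks, 2 layers each, 64x64 latents.
    batch_size = 2  # e.g. batch 1 doubled by classifier-free guidance
    height = width = 64
    block_out_channels = [320, 640, 1280, 1280]
    down_block_types = ["CrossAttnDownBlock2D"] * 3 + ["DownBlock2D"]
    layers_per_block = 2

    shapes = [[batch_size, block_out_channels[0], height, width]]  # residual 0
    for idx, _ in enumerate(down_block_types):
        for _ in range(layers_per_block):  # one residual per resnet layer
            shapes.append([batch_size, block_out_channels[idx], height, width])
        if idx != len(down_block_types) - 1:  # every block but the last downsamples
            height, width = height // 2, width // 2
            shapes.append([batch_size, block_out_channels[idx], height, width])

    # 12 down-block residuals, spatial sizes 64 -> 32 -> 16 -> 8
    assert len(shapes) == 12 and shapes[-1] == [2, 1280, 8, 8]

    # Mid-block residual: three CrossAttnDownBlock2D entries => 64 // 2**3 = 8
    mid_shape = [batch_size, block_out_channels[-1], 8, 8]

This reproduces the twelve `down_block_additional_residuals_*` inputs plus the one `mid_block_additional_residual` that a standard SD ControlNet emits.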

optimum/rbln/diffusers/pipelines/__init__.py

@@ -20,16 +20,44 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING
 
-from .controlnet import (
-    RBLNMultiControlNetModel,
-    RBLNStableDiffusionControlNetImg2ImgPipeline,
-    RBLNStableDiffusionControlNetPipeline,
-    RBLNStableDiffusionXLControlNetImg2ImgPipeline,
-    RBLNStableDiffusionXLControlNetPipeline,
-)
-from .stable_diffusion import (
-    RBLNStableDiffusionImg2ImgPipeline,
-    RBLNStableDiffusionPipeline,
-)
-from .stable_diffusion_xl import RBLNStableDiffusionXLImg2ImgPipeline, RBLNStableDiffusionXLPipeline
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "controlnet": [
+        "RBLNMultiControlNetModel",
+        "RBLNStableDiffusionControlNetImg2ImgPipeline",
+        "RBLNStableDiffusionControlNetPipeline",
+        "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNStableDiffusionXLControlNetPipeline",
+    ],
+    "stable_diffusion": [
+        "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionPipeline",
+    ],
+    "stable_diffusion_xl": ["RBLNStableDiffusionXLImg2ImgPipeline", "RBLNStableDiffusionXLPipeline"],
+}
+if TYPE_CHECKING:
+    from .controlnet import (
+        RBLNMultiControlNetModel,
+        RBLNStableDiffusionControlNetImg2ImgPipeline,
+        RBLNStableDiffusionControlNetPipeline,
+        RBLNStableDiffusionXLControlNetImg2ImgPipeline,
+        RBLNStableDiffusionXLControlNetPipeline,
+    )
+    from .stable_diffusion import (
+        RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionPipeline,
+    )
+    from .stable_diffusion_xl import RBLNStableDiffusionXLImg2ImgPipeline, RBLNStableDiffusionXLPipeline
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
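Note: the eager imports are replaced with transformers' `_LazyModule`, so importing `optimum.rbln.diffusers.pipelines` no longer loads every pipeline (and its heavy diffusers dependencies) up front; each submodule is imported on first attribute access, while the `TYPE_CHECKING` branch keeps the real imports visible to static type checkers, which never execute the lazy path. A small usage sketch (behavior as documented for `_LazyModule`, not shown in the diff):

    # Cheap: no pipeline submodule has been imported yet.
    from optimum.rbln.diffusers import pipelines

    # First attribute access triggers the import of the backing submodule,
    # here optimum/rbln/diffusers/pipelines/stable_diffusion.
    pipe_cls = pipelines.RBLNStableDiffusionPipeline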

optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py

@@ -52,6 +52,13 @@ class RBLNMultiControlNetModel(RBLNModel):
         self.nets = models
         self.dtype = torch.float32
 
+    @property
+    def compiled_models(self):
+        cm = []
+        for net in self.nets:
+            cm.extend(net.compiled_models)
+        return cm
+
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         def get_model_from_task(
@@ -102,6 +109,10 @@ class RBLNMultiControlNetModel(RBLNModel):
         real_save_path = save_directory + suffix
         model.save_pretrained(real_save_path)
 
+    @classmethod
+    def _get_rbln_config(cls, **rbln_config_kwargs):
+        pass
+
     def forward(
         self,
         sample: torch.FloatTensor,

optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py

@@ -26,205 +26,25 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
+from diffusers import StableDiffusionControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
-from transformers import CLIPTextModel
 
-from ....modeling_base import RBLNBaseModel
-from ....transformers import RBLNCLIPTextModel
-from ....utils.runtime_utils import ContextRblnConfig
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
-    @classmethod
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for text-to-image generation using Stable Diffusion with ControlNet.
-
-        This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion Controlnet pipeline into a RBLNStableDiffusionControlNet pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_kwargs, _ = RBLNBaseModel.resolve_rbln_config(rbln_config, kwargs)
-
-        device = rbln_kwargs.get("device", None)
-        device_map = rbln_kwargs.get("device_map", None)
-        create_runtimes = rbln_kwargs.get("create_runtimes", None)
-        optimize_host_memory = rbln_kwargs.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_kwargs.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        # compile model, create runtime
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_kwargs},
-            )
-
-        batch_size = rbln_kwargs.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_pretrained(
-                            model_id=cid.config._name_or_path,
-                            subfolder=subfolder_name,
-                            export=True,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_kwargs},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_pretrained(
-                    model_id=controlnet.config._name_or_path,
-                    subfolder="controlnet",
-                    export=True,
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_kwargs},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.controlnet = controlnet
-
-        # update config to be able to load from file.
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        # use for CI to access each compiled model
-        if optimize_host_memory is False:
-            model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
+    original_class = StableDiffusionControlNetPipeline
+    _submodules = ["text_encoder", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
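Note: the ~180-line hand-rolled `from_pretrained` is gone; compilation is now driven by `RBLNDiffusionMixin` (the new `optimum/rbln/modeling_diffusers.py`, +400 lines), which walks the `_submodules` list and compiles each entry. A plausible usage sketch; the checkpoint ids and the exact shape of `rbln_config` are assumptions based on the removed code, not confirmed by the diff:

    from diffusers import ControlNetModel

    from optimum.rbln import RBLNStableDiffusionControlNetPipeline

    # Hypothetical checkpoint ids, for illustration only.
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
    pipe = RBLNStableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        export=True,  # compile text_encoder, unet, vae, and controlnet
        rbln_config={"batch_size": 1},  # shared compile-time options (assumed shape)
    )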
@@ -390,6 +210,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
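Note: `remove_compile_time_kwargs` comes from the new `optimum/rbln/utils/decorator_utils.py` (+55 lines). The name suggests it strips call-time arguments that are already baked into the compiled RBLN graph, whose tensor shapes are static. A speculative sketch; the kwarg list and warning behavior are entirely assumed:

    import functools
    import logging

    logger = logging.getLogger(__name__)


    def remove_compile_time_kwargs(func):
        """Drop kwargs that only make sense at compile time (assumed behavior)."""

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Hypothetical list: spatial sizes are fixed when the graph is compiled.
            for key in ("height", "width"):
                if kwargs.pop(key, None) is not None:
                    logger.warning(f"`{key}` is fixed at compile time; the runtime value is ignored.")
            return func(self, *args, **kwargs)

        return wrapper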
@@ -599,6 +420,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,