optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
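
The headline change in 0.1.13 is the new `RBLNDiffusionMixin` (`optimum/rbln/modeling_diffusers.py`, +400 lines), which centralizes the submodule compilation that every diffusers pipeline previously reimplemented in its own `from_pretrained` override; the hunks below show the resulting deletions. A minimal usage sketch, under the assumption that the mixin keeps the `export`, `model_save_dir`, and `rbln_config` keyword conventions visible in the removed code (model ids and config keys are illustrative):

```python
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionControlNetImg2ImgPipeline

# Illustrative checkpoint ids; any SD 1.5 + ControlNet pair works the same way.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

pipe = RBLNStableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    export=True,  # compile text_encoder/unet/vae/controlnet to RBLN graphs
    # Assumed keys, mirroring the rbln_kwargs popped by the deleted code
    # ("batch_size", "guidance_scale"); the 0.1.13 key names may differ.
    rbln_config={"batch_size": 1, "guidance_scale": 7.5},
    model_save_dir="./rbln_sd15_controlnet_i2i",
)
```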
--- a/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -26,209 +26,23 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
+from diffusers import StableDiffusionControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
-from transformers import CLIPTextModel
 
-from ....modeling_base import RBLNBaseModel
-from ....transformers import RBLNCLIPTextModel
-from ....utils.runtime_utils import ContextRblnConfig
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
-    @classmethod
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for image-to-image generation using Stable Diffusion with ControlNet.
-
-        This model inherits from [`StableDiffusionControlNetImg2ImgPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion Controlnet pipeline into a RBLNStableDiffusionControlNetImg2Img pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_kwargs, _ = RBLNBaseModel.resolve_rbln_config(rbln_config, kwargs)
-
-        device = rbln_kwargs.get("device", None)
-        device_map = rbln_kwargs.get("device_map", None)
-        create_runtimes = rbln_kwargs.get("create_runtimes", None)
-        optimize_host_memory = rbln_kwargs.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_kwargs.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        # compile model, create runtime
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=True,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_kwargs},
-            )
-
-        batch_size = rbln_kwargs.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=True,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_pretrained(
-                            model_id=cid.config._name_or_path,
-                            subfolder=subfolder_name,
-                            export=True,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_kwargs},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_pretrained(
-                    model_id=controlnet.config._name_or_path,
-                    subfolder="controlnet",
-                    export=True,
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_kwargs},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.controlnet = controlnet
-
-        # update config to be able to load from file.
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        # use for CI to access each compiled model
-        if optimize_host_memory is False:
-            model.compiled_models = [
-                vae.compiled_models[0],
-                vae.compiled_models[1],
-                text_encoder.compiled_models[0],
-                unet.compiled_models[0],
-            ]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
+    _submodules = ["text_encoder", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
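
The hunk above collapses roughly 190 lines of hand-rolled submodule handling into a two-line class declaration: `RBLNDiffusionMixin` reads `_submodules` and performs the load/compile/swap steps generically, which is what lets each pipeline file in this release shed its near-identical `from_pretrained` body. The mixin itself lives in `optimum/rbln/modeling_diffusers.py` and is not shown in this diff; the sketch below only illustrates the shape of the pattern, and everything in it except the `_submodules` attribute and the `from_pretrained` entry point is invented for exposition:

```python
from typing import Any, Dict, List, Optional


class RBLNDiffusionMixinSketch:
    """Illustrative stand-in for RBLNDiffusionMixin; not the real implementation.

    Meant to sit before a diffusers pipeline class in the MRO, e.g.
    class MyPipeline(RBLNDiffusionMixinSketch, StableDiffusionPipeline).
    """

    _submodules: List[str] = []

    @classmethod
    def from_pretrained(cls, model_id: str, *, export: bool = False,
                        rbln_config: Optional[Dict[str, Any]] = None, **kwargs):
        # Plain diffusers load first (mirrors the deleted code's super().from_pretrained).
        pipe = super().from_pretrained(model_id, **kwargs)
        if not export:
            return pipe
        for name in cls._submodules:
            # Each declared submodule is compiled once and swapped back in;
            # _compile_submodule is a placeholder for the real RBLN export step.
            setattr(pipe, name, cls._compile_submodule(name, getattr(pipe, name), rbln_config or {}))
        return pipe

    @staticmethod
    def _compile_submodule(name: str, module, rbln_config: Dict[str, Any]):
        # Placeholder: the real mixin would trace and compile `module` for the
        # RBLN NPU here, using per-submodule entries from rbln_config.
        return module
```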
@@ -388,6 +202,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
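
Beyond the class-body rewrite, the pipeline's `__call__` gains the `remove_compile_time_kwargs` decorator from the new `optimum/rbln/utils/decorator_utils.py` (+55 lines). Its body is not part of this diff; a plausible minimal reconstruction, in which the warning text and the exact set of intercepted kwargs are assumptions rather than the real source:

```python
import functools
import warnings

# Hypothetical set of kwargs fixed at compile time; the real decorator's
# list lives in optimum/rbln/utils/decorator_utils.py.
COMPILE_TIME_KWARGS = ("guidance_scale", "height", "width")


def remove_compile_time_kwargs(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        for key in COMPILE_TIME_KWARGS:
            if key in kwargs:
                warnings.warn(
                    f"`{key}` is fixed when the model is compiled; "
                    "the runtime value is ignored."
                )
                kwargs.pop(key)
        return func(self, *args, **kwargs)

    return wrapper
```

This matches the deleted logic above, where `guidance_scale` determined the UNet batch size at compile time, so a runtime override could not take effect anyway.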
@@ -596,6 +411,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
--- a/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -26,223 +26,23 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
+from diffusers import StableDiffusionXLControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
-from transformers import CLIPTextModel
 
-from ....modeling_base import RBLNBaseModel
-from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
-from ....utils.runtime_utils import ContextRblnConfig
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
-    @classmethod
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet.
-
-        This model inherits from [`StableDiffusionXLControlNetPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion XL Controlnet pipeline into a RBLNStableDiffusionXLControlNet pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        text_encoder_2 = kwargs.pop("text_encoder_2", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_kwargs, _ = RBLNBaseModel.resolve_rbln_config(rbln_config, kwargs)
-
-        device = rbln_kwargs.get("device", None)
-        device_map = rbln_kwargs.get("device_map", None)
-        create_runtimes = rbln_kwargs.get("create_runtimes", None)
-        optimize_host_memory = rbln_kwargs.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_kwargs.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
-            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder_2",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_kwargs},
-            )
-
-        batch_size = rbln_kwargs.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
-                rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_pretrained(
-                            model_id=cid.config._name_or_path,
-                            subfolder=subfolder_name,
-                            export=True,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_kwargs},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_pretrained(
-                    model_id=controlnet.config._name_or_path,
-                    subfolder="controlnet",
-                    export=True,
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_kwargs},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.text_encoder_2 = text_encoder_2
-        model.controlnet = controlnet
-
-        # update config to be able to load from file
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        # use for CI to access each compiled model
-        if optimize_host_memory is False:
-            model.compiled_models = [
-                vae.compiled_models[0],
-                text_encoder.compiled_models[0],
-                text_encoder_2.compiled_models[0],
-                unet.compiled_models[0],
-            ]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
+    _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
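
The SDXL variant follows the same pattern, with `text_encoder_2` added to `_submodules`. The deleted multi-ControlNet branch (one `RBLNControlNetModel` per `controlnet`/`controlnet_{i}` subfolder, wrapped in `RBLNMultiControlNetModel`) now also sits behind the mixin. A hedged usage sketch with two ControlNets, model ids illustrative:

```python
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionXLControlNetPipeline

# Two ControlNets would map to the "controlnet" and "controlnet_1" subfolders,
# matching the deleted `"controlnet" if i == 0 else f"controlnet_{i}"` logic.
controlnets = [
    ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0"),  # illustrative
    ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0"),  # illustrative
]

pipe = RBLNStableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnets,
    export=True,
    rbln_config={"batch_size": 1},  # assumed key, mirroring the removed rbln_kwargs
)
```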
@@ -434,6 +234,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
    def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -697,6 +498,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         (
             prompt_embeds,
             negative_prompt_embeds,