optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. optimum/rbln/__init__.py +5 -1
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
  4. optimum/rbln/diffusers/models/controlnet.py +36 -56
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
  16. optimum/rbln/modeling_base.py +12 -5
  17. optimum/rbln/modeling_diffusers.py +400 -0
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/cache_utils.py +5 -9
  20. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  21. optimum/rbln/transformers/models/__init__.py +80 -31
  22. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
  23. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  25. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
  26. optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
  27. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  29. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  30. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  31. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  32. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
  33. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
  35. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  36. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  37. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  38. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  39. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  40. optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  42. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  43. optimum/rbln/utils/context.py +58 -0
  44. optimum/rbln/utils/decorator_utils.py +55 -0
  45. optimum/rbln/utils/import_utils.py +7 -0
  46. optimum/rbln/utils/runtime_utils.py +4 -4
  47. optimum/rbln/utils/timer_utils.py +2 -2
  48. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
  49. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
  50. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
@@ -26,208 +26,23 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
26
26
 
27
27
  import torch
28
28
  import torch.nn.functional as F
29
- from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
29
+ from diffusers import StableDiffusionXLControlNetPipeline
30
30
  from diffusers.image_processor import PipelineImageInput
31
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
32
31
  from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
33
32
  from diffusers.utils import deprecate, logging
34
33
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
35
- from transformers import CLIPTextModel
36
34
 
37
- from ....modeling_config import use_rbln_config
38
- from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
39
- from ....utils.runtime_utils import ContextRblnConfig
40
- from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
35
+ from ....modeling_diffusers import RBLNDiffusionMixin
36
+ from ....utils.decorator_utils import remove_compile_time_kwargs
37
+ from ...models import RBLNControlNetModel
41
38
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
42
39
 
43
40
 
44
41
  logger = logging.get_logger(__name__)
45
42
 
46
43
 
47
- class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
48
- @classmethod
49
- @use_rbln_config
50
- def from_pretrained(cls, model_id, **kwargs):
51
- """
52
- Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet.
53
-
54
- This model inherits from [`StableDiffusionXLControlNetPipeline`]. Check the superclass documentation for the generic methods
55
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
56
-
57
- It implements the methods to convert a pre-trained Stable Diffusion XL Controlnet pipeline into a RBLNStableDiffusionXLControlNet pipeline by:
58
- - transferring the checkpoint weights of the original into an optimized RBLN graph,
59
- - compiling the resulting graph using the RBLN compiler.
60
-
61
- Args:
62
- model_id (`Union[str, Path]`):
63
- Can be either:
64
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
65
- - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
66
- """
67
- export = kwargs.pop("export", None)
68
- vae = kwargs.pop("vae", None)
69
- unet = kwargs.pop("unet", None)
70
- text_encoder = kwargs.pop("text_encoder", None)
71
- text_encoder_2 = kwargs.pop("text_encoder_2", None)
72
- controlnet = kwargs.pop("controlnet", None)
73
- model_save_dir = kwargs.pop("model_save_dir", None)
74
- rbln_config = kwargs.pop("rbln_config", None)
75
- rbln_config = {} if rbln_config is None else rbln_config
76
-
77
- device = rbln_config.get("device", None)
78
- device_map = rbln_config.get("device_map", None)
79
- create_runtimes = rbln_config.get("create_runtimes", None)
80
- optimize_host_memory = rbln_config.get("optimize_host_memory", None)
81
-
82
- kwargs_dict = {
83
- "pretrained_model_name_or_path": model_id,
84
- **kwargs,
85
- }
86
-
87
- kwargs_dict.update(
88
- {
89
- **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
90
- **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
91
- **(
92
- {"text_encoder": text_encoder}
93
- if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
94
- else {}
95
- ),
96
- **(
97
- {"controlnet": controlnet}
98
- if controlnet is not None
99
- and (
100
- isinstance(controlnet, ControlNetModel)
101
- or all(isinstance(c, ControlNetModel) for c in controlnet)
102
- )
103
- else {}
104
- ),
105
- }
106
- )
107
-
108
- with ContextRblnConfig(
109
- device=device,
110
- device_map=device_map,
111
- create_runtimes=create_runtimes,
112
- optimze_host_mem=optimize_host_memory,
113
- ):
114
- model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
115
-
116
- if export is None or export is False:
117
- return model
118
-
119
- do_classifier_free_guidance = (
120
- rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
121
- )
122
-
123
- if not isinstance(vae, RBLNAutoencoderKL):
124
- vae = RBLNAutoencoderKL.from_pretrained(
125
- model_id=model_id,
126
- subfolder="vae",
127
- export=True,
128
- model_save_dir=model_save_dir,
129
- rbln_unet_sample_size=model.unet.config.sample_size,
130
- rbln_use_encode=False,
131
- rbln_vae_scale_factor=model.vae_scale_factor,
132
- rbln_config={**rbln_config},
133
- )
134
-
135
- if not isinstance(text_encoder, RBLNCLIPTextModel):
136
- text_encoder = RBLNCLIPTextModel.from_pretrained(
137
- model_id=model_id,
138
- subfolder="text_encoder",
139
- export=True,
140
- model_save_dir=model_save_dir,
141
- rbln_config={**rbln_config},
142
- )
143
-
144
- if not isinstance(text_encoder_2, RBLNCLIPTextModel):
145
- text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
146
- model_id=model_id,
147
- subfolder="text_encoder_2",
148
- export=True,
149
- model_save_dir=model_save_dir,
150
- rbln_config={**rbln_config},
151
- )
152
-
153
- batch_size = rbln_config.pop("batch_size", 1)
154
- unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
155
-
156
- if not isinstance(unet, RBLNUNet2DConditionModel):
157
- unet = RBLNUNet2DConditionModel.from_pretrained(
158
- model_id=model_id,
159
- subfolder="unet",
160
- export=True,
161
- model_save_dir=model_save_dir,
162
- rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
163
- rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
164
- rbln_batch_size=unet_batch_size,
165
- rbln_use_encode=False,
166
- rbln_vae_scale_factor=model.vae_scale_factor,
167
- rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
168
- rbln_config={**rbln_config},
169
- )
170
-
171
- if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
172
- if isinstance(controlnet, (list, tuple)):
173
- multicontrolnet = []
174
- for i, cid in enumerate(controlnet):
175
- subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
176
- multicontrolnet.append(
177
- RBLNControlNetModel.from_model(
178
- model=cid,
179
- subfolder=subfolder_name,
180
- model_save_dir=model_save_dir,
181
- rbln_batch_size=unet_batch_size,
182
- rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
183
- rbln_vae_scale_factor=model.vae_scale_factor,
184
- rbln_config={**rbln_config},
185
- )
186
- )
187
- controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
188
- controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
189
- else:
190
- controlnet = RBLNControlNetModel.from_model(
191
- model=controlnet,
192
- subfolder="controlnet",
193
- model_save_dir=model_save_dir,
194
- rbln_batch_size=unet_batch_size,
195
- rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
196
- rbln_vae_scale_factor=model.vae_scale_factor,
197
- rbln_config={**rbln_config},
198
- )
199
- controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
200
-
201
- if model_save_dir is not None:
202
- # To skip saving original pytorch modules
203
- del (model.vae, model.text_encoder, model.unet, model.controlnet)
204
-
205
- # Direct calling of `save_pretrained` causes config.unet = (None, None).
206
- # So config must be saved again, later.
207
- model.save_pretrained(model_save_dir)
208
-
209
- # replace modules
210
- model.vae = vae
211
- model.text_encoder = text_encoder
212
- model.unet = unet
213
- model.text_encoder_2 = text_encoder_2
214
- model.controlnet = controlnet
215
-
216
- # update config to be able to load from file
217
- update_dict = {
218
- "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
219
- "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
220
- "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
221
- "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModel"),
222
- "controlnet": controlnet_dict,
223
- }
224
- model.register_to_config(**update_dict)
225
-
226
- if model_save_dir is not None:
227
- # overwrite to replace incorrect config
228
- model.save_config(model_save_dir)
229
-
230
- return model
44
+ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
45
+ _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
231
46
 
232
47
  def check_inputs(
233
48
  self,
@@ -419,6 +234,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
419
234
  )
420
235
 
421
236
  @torch.no_grad()
237
+ @remove_compile_time_kwargs
422
238
  def __call__(
423
239
  self,
424
240
  prompt: Union[str, List[str]] = None,
@@ -682,6 +498,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
682
498
  text_encoder_lora_scale = (
683
499
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
684
500
  )
501
+
685
502
  (
686
503
  prompt_embeds,
687
504
  negative_prompt_embeds,
@@ -26,209 +26,23 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
26
26
 
27
27
  import torch
28
28
  import torch.nn.functional as F
29
- from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
29
+ from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
30
30
  from diffusers.image_processor import PipelineImageInput
31
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
32
31
  from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
33
32
  from diffusers.utils import deprecate, logging
34
33
  from diffusers.utils.torch_utils import is_compiled_module
35
- from transformers import CLIPTextModel
36
34
 
37
- from ....modeling_config import use_rbln_config
38
- from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
39
- from ....utils.runtime_utils import ContextRblnConfig
40
- from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
35
+ from ....modeling_diffusers import RBLNDiffusionMixin
36
+ from ....utils.decorator_utils import remove_compile_time_kwargs
37
+ from ...models import RBLNControlNetModel
41
38
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
42
39
 
43
40
 
44
41
  logger = logging.get_logger(__name__)
45
42
 
46
43
 
47
- class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNetImg2ImgPipeline):
48
- @classmethod
49
- @use_rbln_config
50
- def from_pretrained(cls, model_id, **kwargs):
51
- """
52
- Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet.
53
-
54
- This model inherits from [`StableDiffusionXLControlNetImg2ImgPipeline`]. Check the superclass documentation for the generic methods
55
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
56
-
57
- It implements the methods to convert a pre-trained Stable Diffusion XL Controlnet pipeline into a RBLNStableDiffusionXLControlNetImg2Img pipeline by:
58
- - transferring the checkpoint weights of the original into an optimized RBLN graph,
59
- - compiling the resulting graph using the RBLN compiler.
60
-
61
- Args:
62
- model_id (`Union[str, Path]`):
63
- Can be either:
64
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
65
- - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
66
- """
67
- export = kwargs.pop("export", None)
68
- vae = kwargs.pop("vae", None)
69
- unet = kwargs.pop("unet", None)
70
- text_encoder = kwargs.pop("text_encoder", None)
71
- text_encoder_2 = kwargs.pop("text_encoder_2", None)
72
- controlnet = kwargs.pop("controlnet", None)
73
- model_save_dir = kwargs.pop("model_save_dir", None)
74
-
75
- rbln_config = kwargs.pop("rbln_config", None)
76
- rbln_config = {} if rbln_config is None else rbln_config
77
-
78
- device = rbln_config.get("device", None)
79
- device_map = rbln_config.get("device_map", None)
80
- create_runtimes = rbln_config.get("create_runtimes", None)
81
- optimize_host_memory = rbln_config.get("optimize_host_memory", None)
82
-
83
- kwargs_dict = {
84
- "pretrained_model_name_or_path": model_id,
85
- **kwargs,
86
- }
87
-
88
- kwargs_dict.update(
89
- {
90
- **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
91
- **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
92
- **(
93
- {"text_encoder": text_encoder}
94
- if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
95
- else {}
96
- ),
97
- **(
98
- {"controlnet": controlnet}
99
- if controlnet is not None
100
- and (
101
- isinstance(controlnet, ControlNetModel)
102
- or all(isinstance(c, ControlNetModel) for c in controlnet)
103
- )
104
- else {}
105
- ),
106
- }
107
- )
108
-
109
- with ContextRblnConfig(
110
- device=device,
111
- device_map=device_map,
112
- create_runtimes=create_runtimes,
113
- optimze_host_mem=optimize_host_memory,
114
- ):
115
- model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
116
-
117
- if export is None or export is False:
118
- return model
119
-
120
- do_classifier_free_guidance = (
121
- rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
122
- )
123
-
124
- if not isinstance(vae, RBLNAutoencoderKL):
125
- vae = RBLNAutoencoderKL.from_pretrained(
126
- model_id=model_id,
127
- subfolder="vae",
128
- export=True,
129
- model_save_dir=model_save_dir,
130
- rbln_unet_sample_size=model.unet.config.sample_size,
131
- rbln_use_encode=True,
132
- rbln_vae_scale_factor=model.vae_scale_factor,
133
- rbln_config={**rbln_config},
134
- )
135
-
136
- if not isinstance(text_encoder, RBLNCLIPTextModel):
137
- text_encoder = RBLNCLIPTextModel.from_pretrained(
138
- model_id=model_id,
139
- subfolder="text_encoder",
140
- export=True,
141
- model_save_dir=model_save_dir,
142
- rbln_config={**rbln_config},
143
- )
144
-
145
- if not isinstance(text_encoder_2, RBLNCLIPTextModel):
146
- text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
147
- model_id=model_id,
148
- subfolder="text_encoder_2",
149
- export=True,
150
- model_save_dir=model_save_dir,
151
- rbln_config={**rbln_config},
152
- )
153
-
154
- batch_size = rbln_config.pop("batch_size", 1)
155
- unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
156
-
157
- if not isinstance(unet, RBLNUNet2DConditionModel):
158
- unet = RBLNUNet2DConditionModel.from_pretrained(
159
- model_id=model_id,
160
- subfolder="unet",
161
- export=True,
162
- model_save_dir=model_save_dir,
163
- rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
164
- rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
165
- rbln_batch_size=unet_batch_size,
166
- rbln_use_encode=True,
167
- rbln_vae_scale_factor=model.vae_scale_factor,
168
- rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
169
- rbln_config={**rbln_config},
170
- )
171
-
172
- if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
173
- if isinstance(controlnet, (list, tuple)):
174
- multicontrolnet = []
175
- for i, cid in enumerate(controlnet):
176
- subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
177
- multicontrolnet.append(
178
- RBLNControlNetModel.from_model(
179
- model=cid,
180
- subfolder=subfolder_name,
181
- model_save_dir=model_save_dir,
182
- rbln_batch_size=unet_batch_size,
183
- rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
184
- rbln_vae_scale_factor=model.vae_scale_factor,
185
- rbln_config={**rbln_config},
186
- )
187
- )
188
- controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
189
- controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
190
- else:
191
- controlnet = RBLNControlNetModel.from_model(
192
- model=controlnet,
193
- subfolder="controlnet",
194
- model_save_dir=model_save_dir,
195
- rbln_batch_size=unet_batch_size,
196
- rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
197
- rbln_vae_scale_factor=model.vae_scale_factor,
198
- rbln_config={**rbln_config},
199
- )
200
- controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
201
-
202
- if model_save_dir is not None:
203
- # To skip saving original pytorch modules
204
- del (model.vae, model.text_encoder, model.unet, model.controlnet)
205
-
206
- # Direct calling of `save_pretrained` causes config.unet = (None, None).
207
- # So config must be saved again, later.
208
- model.save_pretrained(model_save_dir)
209
-
210
- # replace modules
211
- model.vae = vae
212
- model.text_encoder = text_encoder
213
- model.unet = unet
214
- model.text_encoder_2 = text_encoder_2
215
- model.controlnet = controlnet
216
-
217
- # update config to be able to load from file
218
- update_dict = {
219
- "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
220
- "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
221
- "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
222
- "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModelWithProjection"),
223
- "controlnet": controlnet_dict,
224
- }
225
- model.register_to_config(**update_dict)
226
-
227
- if model_save_dir is not None:
228
- # overwrite to replace incorrect config
229
- model.save_config(model_save_dir)
230
-
231
- return model
44
+ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetImg2ImgPipeline):
45
+ _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
232
46
 
233
47
  def check_inputs(
234
48
  self,
@@ -432,6 +246,7 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
432
246
  )
433
247
 
434
248
  @torch.no_grad()
249
+ @remove_compile_time_kwargs
435
250
  def __call__(
436
251
  self,
437
252
  prompt: Union[str, List[str]] = None,
@@ -718,6 +533,7 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
718
533
  text_encoder_lora_scale = (
719
534
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
720
535
  )
536
+
721
537
  (
722
538
  prompt_embeds,
723
539
  negative_prompt_embeds,
@@ -24,115 +24,8 @@
24
24
 
25
25
  from diffusers import StableDiffusionPipeline
26
26
 
27
- from ....modeling_config import use_rbln_config
28
- from ....transformers import RBLNCLIPTextModel
29
- from ....utils.runtime_utils import ContextRblnConfig
30
- from ...models import RBLNAutoencoderKL, RBLNUNet2DConditionModel
27
+ from ....modeling_diffusers import RBLNDiffusionMixin
31
28
 
32
29
 
33
- class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
34
- @classmethod
35
- @use_rbln_config
36
- def from_pretrained(cls, model_id, **kwargs):
37
- """
38
- Pipeline for text-to-image generation using Stable Diffusion.
39
-
40
- This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods
41
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
42
-
43
- It implements the methods to convert a pre-trained Stable Diffusion pipeline into a RBLNStableDiffusion pipeline by:
44
- - transferring the checkpoint weights of the original into an optimized RBLN graph,
45
- - compiling the resulting graph using the RBLN compiler.
46
-
47
- Args:
48
- model_id (`Union[str, Path]`):
49
- Can be either:
50
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
51
- - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
52
- """
53
- export = kwargs.pop("export", None)
54
- model_save_dir = kwargs.pop("model_save_dir", None)
55
- rbln_config = kwargs.pop("rbln_config", None)
56
- rbln_config = {} if rbln_config is None else rbln_config
57
-
58
- device = rbln_config.get("device", None)
59
- device_map = rbln_config.get("device_map", None)
60
- create_runtimes = rbln_config.get("create_runtimes", None)
61
- optimize_host_memory = rbln_config.get("optimize_host_memory", None)
62
-
63
- with ContextRblnConfig(
64
- device=device,
65
- device_map=device_map,
66
- create_runtimes=create_runtimes,
67
- optimze_host_mem=optimize_host_memory,
68
- ):
69
- model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
70
-
71
- if export is None or export is False:
72
- return model
73
-
74
- do_classifier_free_guidance = (
75
- rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
76
- )
77
-
78
- vae = RBLNAutoencoderKL.from_pretrained(
79
- model_id=model_id,
80
- subfolder="vae",
81
- export=True,
82
- model_save_dir=model_save_dir,
83
- rbln_unet_sample_size=model.unet.config.sample_size,
84
- rbln_use_encode=False,
85
- rbln_config={**rbln_config},
86
- )
87
- text_encoder = RBLNCLIPTextModel.from_pretrained(
88
- model_id=model_id,
89
- subfolder="text_encoder",
90
- export=True,
91
- model_save_dir=model_save_dir,
92
- rbln_config={**rbln_config},
93
- )
94
-
95
- batch_size = rbln_config.pop("batch_size", 1)
96
- unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
97
-
98
- unet = RBLNUNet2DConditionModel.from_pretrained(
99
- model_id=model_id,
100
- subfolder="unet",
101
- export=True,
102
- model_save_dir=model_save_dir,
103
- rbln_max_seq_len=text_encoder.config.max_position_embeddings,
104
- rbln_batch_size=unet_batch_size,
105
- rbln_use_encode=False,
106
- rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
107
- rbln_config={**rbln_config},
108
- )
109
-
110
- if model_save_dir is not None:
111
- # To skip saving original pytorch modules
112
- del (model.vae, model.text_encoder, model.unet)
113
-
114
- # Direct calling of `save_pretrained` causes config.unet = (None, None).
115
- # So config must be saved again, later.
116
- model.save_pretrained(model_save_dir)
117
-
118
- # replace modules
119
- model.vae = vae
120
- model.text_encoder = text_encoder
121
- model.unet = unet
122
-
123
- # update config to be able to load from file.
124
- update_dict = {
125
- "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
126
- "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
127
- "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
128
- }
129
- model.register_to_config(**update_dict)
130
-
131
- if model_save_dir is not None:
132
- # overwrite to replace incorrect config
133
- model.save_config(model_save_dir)
134
-
135
- if optimize_host_memory is False:
136
- model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
137
-
138
- return model
30
+ class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
31
+ _submodules = ["text_encoder", "unet", "vae"]