optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py

@@ -26,203 +26,25 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
+from diffusers import StableDiffusionControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
-from transformers import CLIPTextModel
 
-from ....modeling_config import use_rbln_config
-from ....transformers import RBLNCLIPTextModel
-from ....utils.runtime_utils import ContextRblnConfig
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
-    @classmethod
-    @use_rbln_config
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for text-to-image generation using Stable Diffusion with ControlNet.
-
-        This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion Controlnet pipeline into a RBLNStableDiffusionControlNet pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_config = {} if rbln_config is None else rbln_config
-
-        device = rbln_config.get("device", None)
-        device_map = rbln_config.get("device_map", None)
-        create_runtimes = rbln_config.get("create_runtimes", None)
-        optimize_host_memory = rbln_config.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        # compile model, create runtime
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        batch_size = rbln_config.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_model(
-                            model=cid,
-                            subfolder=subfolder_name,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_config},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_model(
-                    model=controlnet,
-                    subfolder="controlnet",
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_config},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.controlnet = controlnet
-
-        # update config to be able to load from file.
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        if optimize_host_memory is False:
-            model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
+    original_class = StableDiffusionControlNetPipeline
+    _submodules = ["text_encoder", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
@@ -388,6 +210,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -597,6 +420,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
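In 0.1.15 the hand-written `from_pretrained` override removed above is replaced by `RBLNDiffusionMixin`, which compiles the submodules named in `_submodules`. Below is a minimal usage sketch, not the library's documented API: it assumes the 0.1.15 mixin keeps the `export` and `rbln_config` keywords that the removed 0.1.12 code accepted, and the model ids are only examples.

```python
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionControlNetPipeline

# The controlnet is passed as in plain diffusers; the mixin is expected to
# compile each entry of _submodules (text_encoder, unet, vae, controlnet).
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = RBLNStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    export=True,  # compile rather than load an already-compiled artifact
    # batch_size / guidance_scale were rbln_config keys in the 0.1.12 code;
    # we assume the mixin still honors them at compile time.
    rbln_config={"batch_size": 1, "guidance_scale": 7.5},
)
pipe.save_pretrained("rbln-sd15-controlnet")
```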
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

@@ -26,207 +26,24 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
+from diffusers import StableDiffusionControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
-from transformers import CLIPTextModel
 
-from ....modeling_config import use_rbln_config
-from ....transformers import RBLNCLIPTextModel
-from ....utils.runtime_utils import ContextRblnConfig
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
-    @classmethod
-    @use_rbln_config
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for image-to-image generation using Stable Diffusion with ControlNet.
-
-        This model inherits from [`StableDiffusionControlNetImg2ImgPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion Controlnet pipeline into a RBLNStableDiffusionControlNetImg2Img pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_config = {} if rbln_config is None else rbln_config
-
-        device = rbln_config.get("device", None)
-        device_map = rbln_config.get("device_map", None)
-        create_runtimes = rbln_config.get("create_runtimes", None)
-        optimize_host_memory = rbln_config.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        # compile model, create runtime
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=True,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        batch_size = rbln_config.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=True,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_model(
-                            model=cid,
-                            subfolder=subfolder_name,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_config},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_model(
-                    model=controlnet,
-                    subfolder="controlnet",
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_config},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.controlnet = controlnet
-
-        # update config to be able to load from file.
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        if optimize_host_memory is False:
-            model.compiled_models = [
-                vae.compiled_models[0],
-                vae.compiled_models[1],
-                text_encoder.compiled_models[0],
-                unet.compiled_models[0],
-            ]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
+    original_class = StableDiffusionControlNetImg2ImgPipeline
+    _submodules = ["text_encoder", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
@@ -386,6 +203,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -594,6 +412,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
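Both pipelines now gate `__call__` with the new `@remove_compile_time_kwargs` decorator from `optimum/rbln/utils/decorator_utils.py` (+59 lines), whose body is not part of this diff. The name suggests it drops call-time kwargs that were already fixed when the model was compiled. A hypothetical sketch of that idea follows; the key set, warning text, and structure are our own invention, not the shipped implementation:

```python
import functools
import warnings

# Hypothetical: kwargs assumed to be baked into the compiled RBLN graph.
COMPILE_TIME_KWARGS = ("height", "width", "guidance_scale")


def remove_compile_time_kwargs(func):
    """Strip kwargs that can no longer change after compilation (illustrative)."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        for key in COMPILE_TIME_KWARGS:
            if kwargs.pop(key, None) is not None:
                warnings.warn(
                    f"`{key}` was fixed when the model was compiled and is "
                    "ignored at runtime; re-export the model to change it."
                )
        return func(self, *args, **kwargs)

    return wrapper
```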
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

@@ -26,208 +26,24 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
+from diffusers import StableDiffusionXLControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
-from transformers import CLIPTextModel
 
-from ....modeling_config import use_rbln_config
-from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
-from ....utils.runtime_utils import ContextRblnConfig
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
-    @classmethod
-    @use_rbln_config
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet.
-
-        This model inherits from [`StableDiffusionXLControlNetPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion XL Controlnet pipeline into a RBLNStableDiffusionXLControlNet pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        text_encoder_2 = kwargs.pop("text_encoder_2", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_config = {} if rbln_config is None else rbln_config
-
-        device = rbln_config.get("device", None)
-        device_map = rbln_config.get("device_map", None)
-        create_runtimes = rbln_config.get("create_runtimes", None)
-        optimize_host_memory = rbln_config.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
-            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder_2",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        batch_size = rbln_config.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
-                rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_model(
-                            model=cid,
-                            subfolder=subfolder_name,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_config},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_model(
-                    model=controlnet,
-                    subfolder="controlnet",
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_config},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.text_encoder_2 = text_encoder_2
-        model.controlnet = controlnet
-
-        # update config to be able to load from file
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        return model
+class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
+    original_class = StableDiffusionXLControlNetPipeline
+    _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
@@ -419,6 +235,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -682,6 +499,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         (
             prompt_embeds,
             negative_prompt_embeds,
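The SDXL variant declares five submodules, adding `text_encoder_2` to the list. A sketch of multi-ControlNet export under the same assumptions as the earlier example: the removed 0.1.12 code compiled each net into subfolders `controlnet`, `controlnet_1`, ... and wrapped them in `RBLNMultiControlNetModel`, and we assume the 0.1.15 mixin behaves equivalently when given a list.

```python
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionXLControlNetPipeline

# Two ControlNets; a list triggers the RBLNMultiControlNetModel path.
controlnets = [
    ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0"),
    ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0"),
]
pipe = RBLNStableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnets,
    export=True,
    rbln_config={"batch_size": 1},  # assumed compile-time key, as above
)
pipe.save_pretrained("rbln-sdxl-multicontrolnet")
```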