optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
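
Besides the ControlNet pipeline rework shown in the hunks below, the file list points to new model families in 0.1.8: a shared decoder-only architecture plus Gemma, DPT, and XLM-RoBERTa support. A minimal usage sketch, assuming the new classes follow the existing RBLN* naming convention (e.g. RBLNGemmaForCausalLM from gemma/modeling_gemma.py) and that the checkpoint id and rbln_* values are illustrative rather than taken from this diff:

# Hedged sketch: RBLNGemmaForCausalLM is inferred from the new gemma/modeling_gemma.py
# module above; verify the exported name in optimum/rbln/__init__.py of 0.1.8 before relying on it.
from transformers import AutoTokenizer
from optimum.rbln import RBLNGemmaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")  # illustrative checkpoint id
model = RBLNGemmaForCausalLM.from_pretrained(
    "google/gemma-2b",
    export=True,            # compile the HuggingFace checkpoint for RBLN NPUs on first load
    rbln_batch_size=1,      # rbln_* kwargs are split off by RBLNBaseModel.pop_rbln_kwargs_from_kwargs
    rbln_max_seq_len=2048,
)
model.save_pretrained("gemma-2b-rbln")  # keep the compiled artifacts for later reuse
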
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py

@@ -22,18 +22,18 @@
  # from Rebellions Inc.
  """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""

- from pathlib import Path
- from tempfile import TemporaryDirectory
  from typing import Any, Callable, Dict, List, Optional, Union

  import torch
  import torch.nn.functional as F
- from diffusers import StableDiffusionControlNetPipeline
+ from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
  from diffusers.image_processor import PipelineImageInput
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
  from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
  from diffusers.utils import deprecate, logging
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
+ from transformers import CLIPTextModel

  from ....modeling_base import RBLNBaseModel
  from ....transformers import RBLNCLIPTextModel
@@ -64,18 +64,40 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
  - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
  """
  export = kwargs.pop("export", None)
+ vae = kwargs.pop("vae", None)
+ unet = kwargs.pop("unet", None)
  text_encoder = kwargs.pop("text_encoder", None)
- controlnets = kwargs.pop("controlnet", None)
+ controlnet = kwargs.pop("controlnet", None)
+ model_save_dir = kwargs.pop("model_save_dir", None)

  rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

  kwargs_dict = {
  "pretrained_model_name_or_path": model_id,
- "text_encoder": text_encoder,
- "controlnet": controlnets,
  **kwargs,
  }

+ kwargs_dict.update(
+ {
+ **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+ **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+ **(
+ {"text_encoder": text_encoder}
+ if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+ else {}
+ ),
+ **(
+ {"controlnet": controlnet}
+ if controlnet is not None
+ and (
+ isinstance(controlnet, ControlNetModel)
+ or all(isinstance(c, ControlNetModel) for c in controlnet)
+ )
+ else {}
+ ),
+ }
+ )
+
  model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})

  if export is None or export is False:
@@ -85,64 +107,87 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
  rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
  )

- save_dir = TemporaryDirectory()
- save_dir_path = Path(save_dir.name)
-
- model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
  # compile model, create runtime
- vae = RBLNAutoencoderKL.from_pretrained(
- model_id=save_dir_path / "vae",
- export=True,
- rbln_unet_sample_size=model.unet.config.sample_size,
- rbln_use_encode=False,
- rbln_vae_scale_factor=model.vae_scale_factor,
- **rbln_config_kwargs,
- **rbln_constructor_kwargs,
- )
-
- text_encoder = RBLNCLIPTextModel.from_pretrained(
- model_id=save_dir_path / "text_encoder",
- export=True,
- **rbln_config_kwargs,
- **rbln_constructor_kwargs,
- )
-
- batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
- unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
- unet = RBLNUNet2DConditionModel.from_pretrained(
- model_id=save_dir_path / "unet",
- export=True,
- rbln_max_seq_len=text_encoder.config.max_position_embeddings,
- rbln_batch_size=unet_batch_size,
- rbln_use_encode=False,
- rbln_vae_scale_factor=model.vae_scale_factor,
- rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
- **rbln_config_kwargs,
- **rbln_constructor_kwargs,
- )
-
- if isinstance(controlnets, (list, tuple)):
- controlnet = RBLNMultiControlNetModel.from_pretrained(
- model_id=str(save_dir_path / "controlnet"),
+ if not isinstance(vae, RBLNAutoencoderKL):
+ vae = RBLNAutoencoderKL.from_pretrained(
+ model_id=model_id,
+ subfolder="vae",
  export=True,
- rbln_batch_size=unet_batch_size,
+ model_save_dir=model_save_dir,
+ rbln_unet_sample_size=model.unet.config.sample_size,
+ rbln_use_encode=False,
  rbln_vae_scale_factor=model.vae_scale_factor,
  **rbln_config_kwargs,
  **rbln_constructor_kwargs,
  )
- controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
- else:
- controlnet = RBLNControlNetModel.from_pretrained(
- model_id=save_dir_path / "controlnet",
+
+ if not isinstance(text_encoder, RBLNCLIPTextModel):
+ text_encoder = RBLNCLIPTextModel.from_pretrained(
+ model_id=model_id,
+ subfolder="text_encoder",
  export=True,
+ model_save_dir=model_save_dir,
+ **rbln_config_kwargs,
+ **rbln_constructor_kwargs,
+ )
+
+ batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+ unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+ if not isinstance(unet, RBLNUNet2DConditionModel):
+ unet = RBLNUNet2DConditionModel.from_pretrained(
+ model_id=model_id,
+ subfolder="unet",
+ export=True,
+ model_save_dir=model_save_dir,
+ rbln_max_seq_len=text_encoder.config.max_position_embeddings,
  rbln_batch_size=unet_batch_size,
+ rbln_use_encode=False,
  rbln_vae_scale_factor=model.vae_scale_factor,
+ rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
  **rbln_config_kwargs,
  **rbln_constructor_kwargs,
  )
- controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+ if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+ if isinstance(controlnet, (list, tuple)):
+ multicontrolnet = []
+ for i, cid in enumerate(controlnet):
+ subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+ multicontrolnet.append(
+ RBLNControlNetModel.from_pretrained(
+ model_id=cid.config._name_or_path,
+ subfolder=subfolder_name,
+ export=True,
+ model_save_dir=model_save_dir,
+ rbln_batch_size=unet_batch_size,
+ rbln_vae_scale_factor=model.vae_scale_factor,
+ **rbln_config_kwargs,
+ **rbln_constructor_kwargs,
+ )
+ )
+ controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+ controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+ else:
+ controlnet = RBLNControlNetModel.from_pretrained(
+ model_id=controlnet.config._name_or_path,
+ subfolder="controlnet",
+ export=True,
+ model_save_dir=model_save_dir,
+ rbln_batch_size=unet_batch_size,
+ rbln_vae_scale_factor=model.vae_scale_factor,
+ **rbln_config_kwargs,
+ **rbln_constructor_kwargs,
+ )
+ controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+ if model_save_dir is not None:
+ # To skip saving original pytorch modules
+ del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+ # Direct calling of `save_pretrained` causes config.unet = (None, None).
+ # So config must be saved again, later.
+ model.save_pretrained(model_save_dir)

  # replace modules
  model.vae = vae
@@ -159,7 +204,18 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
  }
  model.register_to_config(**update_dict)

- model.models = [vae.model[0], text_encoder.model[0], unet.model[0], controlnet.model[0]]
+ if model_save_dir is not None:
+ # overwrite to replace incorrect config
+ model.save_config(model_save_dir)
+
+ # use for CI to access each compiled model
+ if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+ model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
+ if isinstance(controlnet, RBLNMultiControlNetModel):
+ for c_model in controlnet.nets:
+ model.compiled_models.append(c_model.compiled_models[0])
+ else:
+ model.compiled_models.append(controlnet.compiled_models[0])

  return model
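
Taken together, these hunks change how the ControlNet pipeline is exported: submodules are compiled straight from the checkpoint subfolders instead of a TemporaryDirectory round-trip, already-compiled RBLN modules passed in as kwargs are reused, and model_save_dir controls where the compiled artifacts land. A hedged usage sketch of the new path; the checkpoint ids and rbln_* values are illustrative, not taken from this diff:

# Hedged sketch of exporting with the reworked from_pretrained shown above.
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")  # illustrative
pipe = RBLNStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",       # illustrative base checkpoint
    controlnet=controlnet,                  # a ControlNetModel, or a list for multi-controlnet
    export=True,
    model_save_dir="sd15-controlnet-rbln",  # compiled modules are persisted here instead of a temp dir
    rbln_guidance_scale=7.5,                # >1.0 doubles the UNet/ControlNet compile batch for CFG
    rbln_batch_size=1,
)
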
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

@@ -22,17 +22,17 @@
  # from Rebellions Inc.
  """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""

- from pathlib import Path
- from tempfile import TemporaryDirectory
  from typing import Any, Callable, Dict, List, Optional, Union

  import torch
  import torch.nn.functional as F
- from diffusers import StableDiffusionControlNetImg2ImgPipeline
+ from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
  from diffusers.image_processor import PipelineImageInput
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
  from diffusers.utils import deprecate, logging
  from diffusers.utils.torch_utils import is_compiled_module
+ from transformers import CLIPTextModel

  from ....modeling_base import RBLNBaseModel
  from ....transformers import RBLNCLIPTextModel
@@ -63,18 +63,40 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
  - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
  """
  export = kwargs.pop("export", None)
+ vae = kwargs.pop("vae", None)
+ unet = kwargs.pop("unet", None)
  text_encoder = kwargs.pop("text_encoder", None)
- controlnets = kwargs.pop("controlnet", None)
+ controlnet = kwargs.pop("controlnet", None)
+ model_save_dir = kwargs.pop("model_save_dir", None)

  rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)

  kwargs_dict = {
  "pretrained_model_name_or_path": model_id,
- "text_encoder": text_encoder,
- "controlnet": controlnets,
  **kwargs,
  }

+ kwargs_dict.update(
+ {
+ **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+ **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+ **(
+ {"text_encoder": text_encoder}
+ if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+ else {}
+ ),
+ **(
+ {"controlnet": controlnet}
+ if controlnet is not None
+ and (
+ isinstance(controlnet, ControlNetModel)
+ or all(isinstance(c, ControlNetModel) for c in controlnet)
+ )
+ else {}
+ ),
+ }
+ )
+
  model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})

  if export is None or export is False:
@@ -84,64 +106,87 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
  rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
  )

- save_dir = TemporaryDirectory()
- save_dir_path = Path(save_dir.name)
-
- model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
  # compile model, create runtime
- vae = RBLNAutoencoderKL.from_pretrained(
- model_id=save_dir_path / "vae",
- export=True,
- rbln_unet_sample_size=model.unet.config.sample_size,
- rbln_use_encode=True,
- rbln_vae_scale_factor=model.vae_scale_factor,
- **rbln_config_kwargs,
- **rbln_constructor_kwargs,
- )
-
- text_encoder = RBLNCLIPTextModel.from_pretrained(
- model_id=save_dir_path / "text_encoder",
- export=True,
- **rbln_config_kwargs,
- **rbln_constructor_kwargs,
- )
-
- batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
- unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
- unet = RBLNUNet2DConditionModel.from_pretrained(
- model_id=save_dir_path / "unet",
- export=True,
- rbln_max_seq_len=text_encoder.config.max_position_embeddings,
- rbln_batch_size=unet_batch_size,
- rbln_use_encode=True,
- rbln_vae_scale_factor=model.vae_scale_factor,
- rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
- **rbln_config_kwargs,
- **rbln_constructor_kwargs,
- )
-
- if isinstance(controlnets, (list, tuple)):
- controlnet = RBLNMultiControlNetModel.from_pretrained(
- model_id=str(save_dir_path / "controlnet"),
+ if not isinstance(vae, RBLNAutoencoderKL):
+ vae = RBLNAutoencoderKL.from_pretrained(
+ model_id=model_id,
+ subfolder="vae",
  export=True,
- rbln_batch_size=unet_batch_size,
+ model_save_dir=model_save_dir,
+ rbln_unet_sample_size=model.unet.config.sample_size,
+ rbln_use_encode=True,
  rbln_vae_scale_factor=model.vae_scale_factor,
  **rbln_config_kwargs,
  **rbln_constructor_kwargs,
  )
- controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
- else:
- controlnet = RBLNControlNetModel.from_pretrained(
- model_id=save_dir_path / "controlnet",
+
+ if not isinstance(text_encoder, RBLNCLIPTextModel):
+ text_encoder = RBLNCLIPTextModel.from_pretrained(
+ model_id=model_id,
+ subfolder="text_encoder",
+ export=True,
+ model_save_dir=model_save_dir,
+ **rbln_config_kwargs,
+ **rbln_constructor_kwargs,
+ )
+
+ batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+ unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+ if not isinstance(unet, RBLNUNet2DConditionModel):
+ unet = RBLNUNet2DConditionModel.from_pretrained(
+ model_id=model_id,
+ subfolder="unet",
  export=True,
+ model_save_dir=model_save_dir,
+ rbln_max_seq_len=text_encoder.config.max_position_embeddings,
  rbln_batch_size=unet_batch_size,
+ rbln_use_encode=True,
  rbln_vae_scale_factor=model.vae_scale_factor,
+ rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
  **rbln_config_kwargs,
  **rbln_constructor_kwargs,
  )
- controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+ if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+ if isinstance(controlnet, (list, tuple)):
+ multicontrolnet = []
+ for i, cid in enumerate(controlnet):
+ subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+ multicontrolnet.append(
+ RBLNControlNetModel.from_pretrained(
+ model_id=cid.config._name_or_path,
+ subfolder=subfolder_name,
+ export=True,
+ model_save_dir=model_save_dir,
+ rbln_batch_size=unet_batch_size,
+ rbln_vae_scale_factor=model.vae_scale_factor,
+ **rbln_config_kwargs,
+ **rbln_constructor_kwargs,
+ )
+ )
+ controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+ controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+ else:
+ controlnet = RBLNControlNetModel.from_pretrained(
+ model_id=controlnet.config._name_or_path,
+ subfolder="controlnet",
+ export=True,
+ model_save_dir=model_save_dir,
+ rbln_batch_size=unet_batch_size,
+ rbln_vae_scale_factor=model.vae_scale_factor,
+ **rbln_config_kwargs,
+ **rbln_constructor_kwargs,
+ )
+ controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+ if model_save_dir is not None:
+ # To skip saving original pytorch modules
+ del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+ # Direct calling of `save_pretrained` causes config.unet = (None, None).
+ # So config must be saved again, later.
+ model.save_pretrained(model_save_dir)

  # replace modules
  model.vae = vae
@@ -158,7 +203,23 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
  }
  model.register_to_config(**update_dict)

- model.models = [vae.model[0], vae.model[1], text_encoder.model[0], unet.model[0], controlnet.model[0]]
+ if model_save_dir is not None:
+ # overwrite to replace incorrect config
+ model.save_config(model_save_dir)
+
+ # use for CI to access each compiled model
+ if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+ model.compiled_models = [
+ vae.compiled_models[0],
+ vae.compiled_models[1],
+ text_encoder.compiled_models[0],
+ unet.compiled_models[0],
+ ]
+ if isinstance(controlnet, RBLNMultiControlNetModel):
+ for c_model in controlnet.nets:
+ model.compiled_models.append(c_model.compiled_models[0])
+ else:
+ model.compiled_models.append(controlnet.compiled_models[0])

  return model
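
The img2img variant follows the same pattern, with the VAE additionally compiled with an encoder (rbln_use_encode=True), which is why its compiled model list starts with two VAE artifacts. A hedged sketch of the CI-oriented path gated by rbln_optimize_host_memory=False above; checkpoint ids and paths are illustrative, not taken from this diff:

# Hedged sketch: keep the individual compiled artifacts reachable for inspection.
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionControlNetImg2ImgPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")  # illustrative
pipe = RBLNStableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    export=True,
    model_save_dir="sd15-controlnet-i2i-rbln",
    rbln_optimize_host_memory=False,  # per the hunk above, exposes pipe.compiled_models
)

# Order per the hunk above: the two VAE artifacts, text encoder, UNet, then ControlNet(s).
print(len(pipe.compiled_models))
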