optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
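Beyond the diffusion-pipeline rework shown in the diffs below, 0.1.8 adds a shared decoder-only architecture (`decoderonly/`), Gemma and DPT model directories, XLM-RoBERTa models, and a KV-cache helper (`cache_utils.py`). As a rough sketch of what the new causal-LM surface might look like — `RBLNGemmaForCausalLM` and the `rbln_*` options here are assumptions inferred from the file names and from kwargs visible elsewhere in this diff, not a documented API:

from optimum.rbln import RBLNGemmaForCausalLM  # assumed class name (modeling_gemma.py)

model = RBLNGemmaForCausalLM.from_pretrained(
    "google/gemma-2b",      # placeholder HF checkpoint
    export=True,            # compile for the RBLN NPU
    rbln_batch_size=1,      # rbln_* kwargs mirror those used in the diffs below
    rbln_max_seq_len=2048,  # assumed option, by analogy with the other models
)
model.save_pretrained("gemma-2b-rbln")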
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

@@ -22,17 +22,17 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionXLControlNetPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
@@ -63,103 +63,152 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
                 - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnets = kwargs.pop("controlnet", None)
         vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        text_encoder_2 = kwargs.pop("text_encoder_2", None)
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
+
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "vae": vae,
-            "controlnet": controlnets,
-            "text_encoder": text_encoder,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         do_classifier_free_guidance = (
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
-        vae = RBLNAutoencoderKL.from_pretrained(
-            model_id=model_id,
-            subfolder="vae",
-            export=True,
-            rbln_unet_sample_size=model.unet.config.sample_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder = RBLNCLIPTextModel.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder_2",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
+                export=True,
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=False,
+                rbln_vae_scale_factor=model.vae_scale_factor,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        unet = RBLNUNet2DConditionModel.from_pretrained(
-            model_id=model_id,
-            subfolder="unet",
-            export=True,
-            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
-            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-            rbln_batch_size=unet_batch_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
+                export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        if isinstance(controlnets, (list, tuple)):
-            controlnet = RBLNMultiControlNetModel.from_pretrained(
-                model_id=str(save_dir_path / "controlnet"),
+        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
+            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder_2",
                 export=True,
-                rbln_batch_size=unet_batch_size,
-                rbln_vae_scale_factor=model.vae_scale_factor,
+                model_save_dir=model_save_dir,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-            controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-        else:
-            controlnet = RBLNControlNetModel.from_pretrained(
-                model_id=save_dir_path / "controlnet",
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
                 export=True,
-                rbln_batch_size=unet_batch_size,
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
                 rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                rbln_batch_size=unet_batch_size,
+                rbln_use_encode=False,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
            )
-            controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+
+        # replace modules
         model.vae = vae
         model.text_encoder = text_encoder
         model.unet = unet
         model.text_encoder_2 = text_encoder_2
         model.controlnet = controlnet
 
+        # update config to be able to load from file
         update_dict = {
             "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
             "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
@@ -169,13 +218,23 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipelin
         }
         model.register_to_config(**update_dict)
 
-        model.models = [
-            vae.model[0],
-            unet.model[0],
-            text_encoder.model[0],
-            text_encoder_2.model[0],
-            controlnet.model[0],
-        ]
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        # use for CI to access each compiled model
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+                unet.compiled_models[0],
+            ]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

@@ -22,17 +22,17 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
@@ -63,103 +63,152 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
                 - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnets = kwargs.pop("controlnet", None)
         vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        text_encoder_2 = kwargs.pop("text_encoder_2", None)
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
+
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "vae": vae,
-            "controlnet": controlnets,
-            "text_encoder": text_encoder,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         do_classifier_free_guidance = (
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
-        vae = RBLNAutoencoderKL.from_pretrained(
-            model_id=model_id,
-            subfolder="vae",
-            export=True,
-            rbln_unet_sample_size=model.unet.config.sample_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder = RBLNCLIPTextModel.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder_2",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
+                export=True,
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=True,
+                rbln_vae_scale_factor=model.vae_scale_factor,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        unet = RBLNUNet2DConditionModel.from_pretrained(
-            model_id=model_id,
-            subfolder="unet",
-            export=True,
-            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
-            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-            rbln_batch_size=unet_batch_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
+                export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        if isinstance(controlnets, (list, tuple)):
-            controlnet = RBLNMultiControlNetModel.from_pretrained(
-                model_id=str(save_dir_path / "controlnet"),
+        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
+            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder_2",
                 export=True,
-                rbln_batch_size=unet_batch_size,
-                rbln_vae_scale_factor=model.vae_scale_factor,
+                model_save_dir=model_save_dir,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-            controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-        else:
-            controlnet = RBLNControlNetModel.from_pretrained(
-                model_id=save_dir_path / "controlnet",
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
                 export=True,
-                rbln_batch_size=unet_batch_size,
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
                 rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                rbln_batch_size=unet_batch_size,
+                rbln_use_encode=True,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-            controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+
+        # replace modules
         model.vae = vae
         model.text_encoder = text_encoder
         model.unet = unet
         model.text_encoder_2 = text_encoder_2
         model.controlnet = controlnet
 
+        # update config to be able to load from file
         update_dict = {
             "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
             "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
@@ -169,14 +218,24 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
         }
         model.register_to_config(**update_dict)
 
-        model.models = [
-            vae.model[0],
-            vae.model[1],
-            unet.model[0],
-            text_encoder.model[0],
-            text_encoder_2.model[0],
-            controlnet.model[0],
-        ]
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        # use for CI to access each compiled model
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                vae.compiled_models[1],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+                unet.compiled_models[0],
+            ]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
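Because every submodule is now guarded by an `isinstance` check, an already-compiled module can be passed straight in and is attached without recompilation. A sketch of that reuse path, assuming the compiled VAE was saved by an earlier export under the placeholder path below, and assuming the pipeline classes are re-exported from the package root. Note that the img2img pipelines compile the VAE with `rbln_use_encode=True`, so a reused VAE must include the encoder:

from diffusers import ControlNetModel
from optimum.rbln import (
    RBLNAutoencoderKL,
    RBLNStableDiffusionXLControlNetImg2ImgPipeline,
)

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")

# Reload a previously compiled VAE (placeholder path). Since it is already an
# RBLNAutoencoderKL, the `not isinstance(...)` guard above skips recompiling
# it and the instance is attached to the pipeline as-is.
vae = RBLNAutoencoderKL.from_pretrained(model_id="sdxl-cnet-rbln/vae")

pipe = RBLNStableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    export=True,
)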
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -122,4 +122,7 @@ class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
 
         model.models = [vae.model[0], text_encoder.model[0], unet.model[0]]
 
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
+
         return model
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -126,4 +126,12 @@ class RBLNStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         # vae encoder, vae decoder, text_encoder, unet
         model.models = [vae.model[0], vae.model[1], text_encoder.model[0], unet.model[0]]
 
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                vae.compiled_models[1],
+                text_encoder.compiled_models[0],
+                unet.compiled_models[0],
+            ]
+
         return model
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

@@ -122,4 +122,12 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
 
         model.models = [vae.model[0], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]
 
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                unet.compiled_models[0],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+            ]
+
         return model
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

@@ -124,4 +124,13 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
 
         model.models = [vae.model[0], vae.model[1], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]
 
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                vae.compiled_models[1],
+                unet.compiled_models[0],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+            ]
+
         return model
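All four single-file pipelines gain the same guarded block: when the constructor kwarg `rbln_optimize_host_memory` is explicitly `False`, the compiled binaries stay reachable as `model.compiled_models` (the ControlNet pipelines' comment says this exists for CI). A usage sketch, with a placeholder model id:

from optimum.rbln import RBLNStableDiffusionPipeline

pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    export=True,
    rbln_optimize_host_memory=False,   # keep compiled binaries reachable
)

# vae, text_encoder, unet, in the order assembled by from_pretrained
for compiled in pipe.compiled_models:
    print(type(compiled))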
optimum/rbln/modeling_alias.py

@@ -24,7 +24,9 @@
 from .modeling_base import (
     RBLNModelForAudioClassification,
     RBLNModelForImageClassification,
+    RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
+    RBLNModelForSequenceClassification,
 )
 from .modeling_seq2seq import RBLNModelForSeq2SeqLM
 
@@ -47,3 +49,15 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     pass
+
+
+class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
+    pass
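The new aliases are thin subclasses, so the generic `RBLNModelFor*` export path should apply to them unchanged. A sketch, assuming the aliases are re-exported from the package root like the surrounding ones (the checkpoint id is a placeholder):

from optimum.rbln import RBLNXLMRobertaForSequenceClassification

model = RBLNXLMRobertaForSequenceClassification.from_pretrained(
    "cardiffnlp/twitter-xlm-roberta-base-sentiment",  # placeholder checkpoint
    export=True,  # compile the classifier for the RBLN NPU
)
model.save_pretrained("xlmr-sentiment-rbln")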