optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. optimum/rbln/__init__.py +17 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
  5. optimum/rbln/diffusers/models/controlnet.py +7 -3
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/modeling_alias.py +19 -1
  13. optimum/rbln/modeling_base.py +162 -18
  14. optimum/rbln/transformers/__init__.py +8 -0
  15. optimum/rbln/transformers/cache_utils.py +111 -0
  16. optimum/rbln/transformers/generation/utils.py +0 -2
  17. optimum/rbln/transformers/models/__init__.py +3 -0
  18. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  19. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  20. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  21. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
  22. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
  23. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  24. optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
  25. optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
  26. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  27. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
  28. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  29. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  32. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  33. optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
  34. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  35. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  36. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  37. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  38. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  39. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  40. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
  41. optimum/rbln/transformers/utils/__init__.py +0 -0
  42. optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
  43. optimum/rbln/utils/import_utils.py +1 -4
  44. optimum/rbln/utils/runtime_utils.py +2 -1
  45. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
  46. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
  47. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  48. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
  49. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -32,9 +32,13 @@ _import_structure = {
32
32
  "modeling_alias": [
33
33
  "RBLNASTForAudioClassification",
34
34
  "RBLNBertForQuestionAnswering",
35
+ "RBLNDistilBertForQuestionAnswering",
35
36
  "RBLNResNetForImageClassification",
36
37
  "RBLNT5ForConditionalGeneration",
37
38
  "RBLNBartForConditionalGeneration",
39
+ "RBLNXLMRobertaForSequenceClassification",
40
+ "RBLNRobertaForSequenceClassification",
41
+ "RBLNRobertaForMaskedLM",
38
42
  ],
39
43
  "modeling_base": [
40
44
  "RBLNBaseModel",
@@ -42,6 +46,8 @@ _import_structure = {
42
46
  "RBLNModelForQuestionAnswering",
43
47
  "RBLNModelForAudioClassification",
44
48
  "RBLNModelForImageClassification",
49
+ "RBLNModelForSequenceClassification",
50
+ "RBLNModelForMaskedLM",
45
51
  ],
46
52
  "modeling_seq2seq": [
47
53
  "RBLNModelForSeq2SeqLM",
@@ -51,11 +57,14 @@ _import_structure = {
51
57
  "RBLNCLIPTextModel",
52
58
  "RBLNCLIPTextModelWithProjection",
53
59
  "RBLNDPTForDepthEstimation",
60
+ "RBLNGemmaForCausalLM",
54
61
  "RBLNGPT2LMHeadModel",
55
62
  "RBLNWav2Vec2ForCTC",
56
63
  "RBLNLlamaForCausalLM",
57
64
  "RBLNMidmLMHeadModel",
65
+ "RBLNMistralForCausalLM",
58
66
  "RBLNWhisperForConditionalGeneration",
67
+ "RBLNXLMRobertaModel",
59
68
  ],
60
69
  "diffusers": [
61
70
  "RBLNStableDiffusionPipeline",
@@ -94,14 +103,19 @@ if TYPE_CHECKING:
94
103
  RBLNBartForConditionalGeneration,
95
104
  RBLNBertForQuestionAnswering,
96
105
  RBLNResNetForImageClassification,
106
+ RBLNRobertaForMaskedLM,
107
+ RBLNRobertaForSequenceClassification,
97
108
  RBLNT5ForConditionalGeneration,
109
+ RBLNXLMRobertaForSequenceClassification,
98
110
  )
99
111
  from .modeling_base import (
100
112
  RBLNBaseModel,
101
113
  RBLNModel,
102
114
  RBLNModelForAudioClassification,
103
115
  RBLNModelForImageClassification,
116
+ RBLNModelForMaskedLM,
104
117
  RBLNModelForQuestionAnswering,
118
+ RBLNModelForSequenceClassification,
105
119
  )
106
120
  from .modeling_config import RBLNConfig, RBLNRuntimeConfig
107
121
  from .modeling_seq2seq import RBLNModelForSeq2SeqLM
@@ -110,11 +124,14 @@ if TYPE_CHECKING:
110
124
  RBLNCLIPTextModel,
111
125
  RBLNCLIPTextModelWithProjection,
112
126
  RBLNDPTForDepthEstimation,
127
+ RBLNGemmaForCausalLM,
113
128
  RBLNGPT2LMHeadModel,
114
129
  RBLNLlamaForCausalLM,
115
130
  RBLNMidmLMHeadModel,
131
+ RBLNMistralForCausalLM,
116
132
  RBLNWav2Vec2ForCTC,
117
133
  RBLNWhisperForConditionalGeneration,
134
+ RBLNXLMRobertaModel,
118
135
  )
119
136
  else:
120
137
  import sys
@@ -1 +1 @@
1
- __version__ = '0.1.7'
1
+ __version__ = '0.1.9'
@@ -47,7 +47,6 @@ _import_structure = {
47
47
  }
48
48
 
49
49
  if TYPE_CHECKING:
50
-
51
50
  from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
52
51
  from .pipelines import (
53
52
  RBLNMultiControlNetModel,
@@ -26,7 +26,7 @@ from pathlib import Path
26
26
  from typing import TYPE_CHECKING, Dict, List, Optional, Union
27
27
 
28
28
  import rebel
29
- import torch
29
+ import torch # noqa: I001
30
30
  from diffusers import AutoencoderKL
31
31
  from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
32
32
  from diffusers.models.modeling_outputs import AutoencoderKLOutput
@@ -38,12 +38,12 @@ from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRunt
38
38
  from ...utils.runtime_utils import RBLNPytorchRuntime
39
39
 
40
40
 
41
- logger = logging.getLogger(__name__)
42
-
43
41
  if TYPE_CHECKING:
44
42
  import torch
45
43
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
46
44
 
45
+ logger = logging.getLogger(__name__)
46
+
47
47
 
48
48
  class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
49
49
  def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
@@ -34,12 +34,13 @@ from ...modeling_base import RBLNModel
34
34
  from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
35
35
 
36
36
 
37
- logger = logging.getLogger(__name__)
38
-
39
37
  if TYPE_CHECKING:
40
38
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
41
39
 
42
40
 
41
+ logger = logging.getLogger(__name__)
42
+
43
+
43
44
  class _ControlNetModel(torch.nn.Module):
44
45
  def __init__(self, controlnet: "ControlNetModel"):
45
46
  super().__init__()
@@ -120,6 +121,9 @@ class RBLNControlNetModel(RBLNModel):
120
121
  model_name_or_path: Union[str, Path],
121
122
  **kwargs,
122
123
  ):
124
+ if "subfolder" in kwargs:
125
+ del kwargs["subfolder"]
126
+
123
127
  return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
124
128
 
125
129
  tasktmp = TasksManager.get_model_from_task
@@ -135,7 +139,7 @@ class RBLNControlNetModel(RBLNModel):
135
139
  return rt
136
140
 
137
141
  @classmethod
138
- def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
142
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
139
143
  use_encoder_hidden_states = False
140
144
  for down_block in model.down_blocks:
141
145
  if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -35,11 +35,11 @@ from ...modeling_base import RBLNModel
35
35
  from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
36
36
 
37
37
 
38
- logger = logging.getLogger(__name__)
39
-
40
38
  if TYPE_CHECKING:
41
39
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
42
40
 
41
+ logger = logging.getLogger(__name__)
42
+
43
43
 
44
44
  class _UNet_SD(torch.nn.Module):
45
45
  def __init__(self, unet: "UNet2DConditionModel"):
@@ -172,7 +172,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
172
172
  return rt
173
173
 
174
174
  @classmethod
175
- def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
175
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
176
176
  if model.config.addition_embed_type == "text_time":
177
177
  return _UNet_SDXL(model).eval()
178
178
  else:
@@ -244,6 +244,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
244
244
  for i in range(3)
245
245
  ]
246
246
  )
247
+ if len(model_config.block_out_channels) > 1:
247
248
  input_info.append(
248
249
  (
249
250
  "down_block_additional_residuals_3",
@@ -251,7 +252,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
251
252
  "float32",
252
253
  )
253
254
  )
254
- if len(model_config.block_out_channels) > 1:
255
255
  input_info.extend(
256
256
  [
257
257
  (
@@ -262,6 +262,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
262
262
  for i in range(4, 6)
263
263
  ]
264
264
  )
265
+ if len(model_config.block_out_channels) > 2:
265
266
  input_info.append(
266
267
  (
267
268
  f"down_block_additional_residuals_{6}",
@@ -269,7 +270,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
269
270
  "float32",
270
271
  )
271
272
  )
272
- if len(model_config.block_out_channels) > 2:
273
273
  input_info.extend(
274
274
  [
275
275
  (
@@ -24,56 +24,32 @@
24
24
  import logging
25
25
  import os
26
26
  from pathlib import Path
27
- from tempfile import TemporaryDirectory
28
27
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
29
28
 
30
- import rebel
31
29
  import torch
32
30
  from diffusers import ControlNetModel
33
31
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
34
32
  from optimum.exporters import TasksManager
35
- from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
33
+ from transformers import AutoConfig, AutoModel
36
34
 
37
- from ....modeling_base import RBLNBaseModel
38
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig
35
+ from ....modeling_base import RBLNModel
36
+ from ....modeling_config import RBLNConfig
39
37
  from ...models.controlnet import RBLNControlNetModel
40
38
 
41
39
 
42
- logger = logging.getLogger(__name__)
43
-
44
40
  if TYPE_CHECKING:
45
- from transformers import (
46
- PretrainedConfig,
47
- PreTrainedModel,
48
- )
41
+ pass
49
42
 
43
+ logger = logging.getLogger(__name__)
50
44
 
51
- class RBLNMultiControlNetModel(RBLNBaseModel):
52
- model_type = "rbln_model"
53
- auto_model_class = AutoModel
54
45
 
46
+ class RBLNMultiControlNetModel(RBLNModel):
55
47
  def __init__(
56
48
  self,
57
- models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
58
- config: PretrainedConfig = None,
59
- preprocessors: Optional[List] = None,
60
- rbln_config: Optional[RBLNConfig] = None,
49
+ models: List[RBLNControlNetModel],
61
50
  **kwargs,
62
51
  ):
63
- super().__init__(
64
- models,
65
- config,
66
- preprocessors,
67
- rbln_config,
68
- **kwargs,
69
- )
70
-
71
- if not isinstance(config, PretrainedConfig):
72
- config = PretrainedConfig(**config)
73
-
74
- for i in range(len(models)):
75
- self.runtimes[i].config = config
76
- self.nets = self.runtimes
52
+ self.nets = models
77
53
  self.dtype = torch.float32
78
54
 
79
55
  @classmethod
@@ -83,7 +59,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
83
59
  model_name_or_path: Union[str, Path],
84
60
  **kwargs,
85
61
  ):
86
- return MultiControlNetModel.from_pretrained(pretrained_model_path=model_name_or_path, **kwargs)
62
+ return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
87
63
 
88
64
  tasktmp = TasksManager.get_model_from_task
89
65
  configtmp = AutoConfig.from_pretrained
@@ -101,129 +77,30 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
101
77
  def _from_pretrained(
102
78
  cls,
103
79
  model_id: Union[str, Path],
104
- config: "PretrainedConfig",
105
- use_auth_token: Optional[Union[bool, str]] = None,
106
- revision: Optional[str] = None,
107
- force_download: bool = False,
108
- cache_dir: Optional[str] = None,
109
- file_name: Optional[str] = None,
110
- subfolder: str = "",
111
- local_files_only: bool = False,
112
80
  **kwargs,
113
- ) -> RBLNBaseModel:
114
- if isinstance(model_id, str):
115
- model_path = Path(model_id)
116
- else:
117
- model_path = model_id / "controlnet"
118
-
119
- rbln_files = []
120
- rbln_config_filenames = []
81
+ ) -> RBLNModel:
121
82
  idx = 0
122
- model_load_path = model_path
83
+ controlnets = []
84
+ model_path_to_load = model_id
123
85
 
124
- while model_load_path.is_dir():
125
- rbln_files.append(list(model_load_path.glob("**/*.rbln"))[0])
126
- rbln_config_filenames.append(model_load_path)
86
+ while os.path.isdir(model_path_to_load):
87
+ controlnet = RBLNControlNetModel.from_pretrained(model_path_to_load, export=False, **kwargs)
88
+ controlnets.append(controlnet)
89
+ rbln_config = RBLNConfig.load(model_path_to_load)
127
90
  idx += 1
128
- model_load_path = Path(str(model_path) + f"_{idx}")
129
-
130
- if len(rbln_files) == 0:
131
- raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")
132
-
133
- if len(rbln_config_filenames) == 0:
134
- raise FileNotFoundError(f"Could not find `rbln_config.json` file in {model_path}")
135
-
136
- models = []
137
- for rconf, rfiles in zip(rbln_config_filenames, rbln_files):
138
- rbln_config = RBLNConfig.load(str(rconf))
139
- models.append(rebel.RBLNCompiledModel(rfiles))
140
-
141
- preprocessors = []
91
+ model_path_to_load = model_id + f"_{idx}"
142
92
 
143
93
  return cls(
144
- models,
145
- config,
146
- preprocessors,
94
+ controlnets,
147
95
  rbln_config=rbln_config,
148
96
  **kwargs,
149
97
  )
150
98
 
151
- def _save_pretrained(self, save_directory: Union[str, Path]):
152
- # TODO(kblee) : 확인 부탁드립니다
153
- idx = 0
154
- real_save_dir_path = save_directory
155
- for compiled_model in self.compiled_models:
156
- dst_path = Path(real_save_dir_path) / "compiled_model.rbln"
157
- if not os.path.exists(real_save_dir_path):
158
- os.makedirs(real_save_dir_path)
159
- compiled_model.save(dst_path)
160
- self.rbln_config.save(real_save_dir_path)
161
- idx += 1
162
- real_save_dir_path = save_directory + f"_{idx}"
163
-
164
- @classmethod
165
- @torch.no_grad()
166
- def _export(
167
- cls,
168
- model_id: str,
169
- config: "PretrainedConfig",
170
- use_auth_token: Optional[Union[bool, str]] = None,
171
- revision: Optional[str] = None,
172
- force_download: bool = False,
173
- cache_dir: Optional[str] = None,
174
- subfolder: str = "",
175
- local_files_only: bool = False,
176
- trust_remote_code: bool = False,
177
- **kwargs,
178
- ) -> "RBLNMultiControlNetModel":
179
- task = kwargs.pop("task", None)
180
- if task is None:
181
- task = TasksManager.infer_task_from_model(cls.auto_model_class)
182
-
183
- save_dir = TemporaryDirectory()
184
- save_dir_path = Path(save_dir.name)
185
-
186
- rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
187
- img_width = rbln_config_kwargs.pop("rbln_img_width", None)
188
- img_height = rbln_config_kwargs.pop("rbln_img_height", None)
189
- vae_scale_factor = rbln_config_kwargs.pop("rbln_vae_scale_factor", None)
190
- batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
191
-
192
- model: MultiControlNetModel = TasksManager.get_model_from_task(
193
- task=task,
194
- model_name_or_path=model_id,
195
- )
196
-
197
- model_path_to_load = model_id
198
- real_save_dir_path = save_dir_path / "controlnet"
199
-
200
- for idx in range(len(model.nets)):
99
+ def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
100
+ for idx, model in enumerate(self.nets):
201
101
  suffix = "" if idx == 0 else f"_{idx}"
202
- controlnet = RBLNControlNetModel.from_pretrained(
203
- model_path_to_load + suffix,
204
- export=True,
205
- rbln_batch_size=batch_size,
206
- rbln_img_width=img_width,
207
- rbln_img_height=img_height,
208
- rbln_vae_scale_factor=vae_scale_factor,
209
- )
210
- controlnet.save_pretrained(real_save_dir_path)
211
- real_save_dir_path = save_dir_path / f"controlnet_{idx+1}"
212
-
213
- return cls._from_pretrained(
214
- model_id=save_dir_path,
215
- config=config,
216
- model_save_dir=save_dir,
217
- **rbln_constructor_kwargs,
218
- **kwargs,
219
- )
220
-
221
- @classmethod
222
- def _create_runtimes(
223
- cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
224
- ) -> List[rebel.Runtime]:
225
- device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
226
- return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
102
+ real_save_path = save_directory + suffix
103
+ model.save_pretrained(real_save_path)
227
104
 
228
105
  def forward(
229
106
  self,
@@ -241,7 +118,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
241
118
  return_dict: bool = True,
242
119
  ):
243
120
  for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
244
- output = controlnet(
121
+ output = controlnet.model[0](
245
122
  sample=sample.contiguous(),
246
123
  timestep=timestep.float(),
247
124
  encoder_hidden_states=encoder_hidden_states,
@@ -22,18 +22,18 @@
22
22
  # from Rebellions Inc.
23
23
  """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
24
24
 
25
- from pathlib import Path
26
- from tempfile import TemporaryDirectory
27
25
  from typing import Any, Callable, Dict, List, Optional, Union
28
26
 
29
27
  import torch
30
28
  import torch.nn.functional as F
31
- from diffusers import StableDiffusionControlNetPipeline
29
+ from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
32
30
  from diffusers.image_processor import PipelineImageInput
31
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
33
32
  from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
34
33
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
35
34
  from diffusers.utils import deprecate, logging
36
35
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
36
+ from transformers import CLIPTextModel
37
37
 
38
38
  from ....modeling_base import RBLNBaseModel
39
39
  from ....transformers import RBLNCLIPTextModel
@@ -64,18 +64,40 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
64
64
  - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
65
65
  """
66
66
  export = kwargs.pop("export", None)
67
+ vae = kwargs.pop("vae", None)
68
+ unet = kwargs.pop("unet", None)
67
69
  text_encoder = kwargs.pop("text_encoder", None)
68
- controlnets = kwargs.pop("controlnet", None)
70
+ controlnet = kwargs.pop("controlnet", None)
71
+ model_save_dir = kwargs.pop("model_save_dir", None)
69
72
 
70
73
  rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
71
74
 
72
75
  kwargs_dict = {
73
76
  "pretrained_model_name_or_path": model_id,
74
- "text_encoder": text_encoder,
75
- "controlnet": controlnets,
76
77
  **kwargs,
77
78
  }
78
79
 
80
+ kwargs_dict.update(
81
+ {
82
+ **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
83
+ **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
84
+ **(
85
+ {"text_encoder": text_encoder}
86
+ if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
87
+ else {}
88
+ ),
89
+ **(
90
+ {"controlnet": controlnet}
91
+ if controlnet is not None
92
+ and (
93
+ isinstance(controlnet, ControlNetModel)
94
+ or all(isinstance(c, ControlNetModel) for c in controlnet)
95
+ )
96
+ else {}
97
+ ),
98
+ }
99
+ )
100
+
79
101
  model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
80
102
 
81
103
  if export is None or export is False:
@@ -85,64 +107,87 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
85
107
  rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
86
108
  )
87
109
 
88
- save_dir = TemporaryDirectory()
89
- save_dir_path = Path(save_dir.name)
90
-
91
- model.save_pretrained(save_directory=save_dir_path, **kwargs)
92
-
93
110
  # compile model, create runtime
94
- vae = RBLNAutoencoderKL.from_pretrained(
95
- model_id=save_dir_path / "vae",
96
- export=True,
97
- rbln_unet_sample_size=model.unet.config.sample_size,
98
- rbln_use_encode=False,
99
- rbln_vae_scale_factor=model.vae_scale_factor,
100
- **rbln_config_kwargs,
101
- **rbln_constructor_kwargs,
102
- )
103
-
104
- text_encoder = RBLNCLIPTextModel.from_pretrained(
105
- model_id=save_dir_path / "text_encoder",
106
- export=True,
107
- **rbln_config_kwargs,
108
- **rbln_constructor_kwargs,
109
- )
110
-
111
- batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
112
- unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
113
-
114
- unet = RBLNUNet2DConditionModel.from_pretrained(
115
- model_id=save_dir_path / "unet",
116
- export=True,
117
- rbln_max_seq_len=text_encoder.config.max_position_embeddings,
118
- rbln_batch_size=unet_batch_size,
119
- rbln_use_encode=False,
120
- rbln_vae_scale_factor=model.vae_scale_factor,
121
- rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
122
- **rbln_config_kwargs,
123
- **rbln_constructor_kwargs,
124
- )
125
-
126
- if isinstance(controlnets, (list, tuple)):
127
- controlnet = RBLNMultiControlNetModel.from_pretrained(
128
- model_id=str(save_dir_path / "controlnet"),
111
+ if not isinstance(vae, RBLNAutoencoderKL):
112
+ vae = RBLNAutoencoderKL.from_pretrained(
113
+ model_id=model_id,
114
+ subfolder="vae",
129
115
  export=True,
130
- rbln_batch_size=unet_batch_size,
116
+ model_save_dir=model_save_dir,
117
+ rbln_unet_sample_size=model.unet.config.sample_size,
118
+ rbln_use_encode=False,
131
119
  rbln_vae_scale_factor=model.vae_scale_factor,
132
120
  **rbln_config_kwargs,
133
121
  **rbln_constructor_kwargs,
134
122
  )
135
- controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
136
- else:
137
- controlnet = RBLNControlNetModel.from_pretrained(
138
- model_id=save_dir_path / "controlnet",
123
+
124
+ if not isinstance(text_encoder, RBLNCLIPTextModel):
125
+ text_encoder = RBLNCLIPTextModel.from_pretrained(
126
+ model_id=model_id,
127
+ subfolder="text_encoder",
139
128
  export=True,
129
+ model_save_dir=model_save_dir,
130
+ **rbln_config_kwargs,
131
+ **rbln_constructor_kwargs,
132
+ )
133
+
134
+ batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
135
+ unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
136
+
137
+ if not isinstance(unet, RBLNUNet2DConditionModel):
138
+ unet = RBLNUNet2DConditionModel.from_pretrained(
139
+ model_id=model_id,
140
+ subfolder="unet",
141
+ export=True,
142
+ model_save_dir=model_save_dir,
143
+ rbln_max_seq_len=text_encoder.config.max_position_embeddings,
140
144
  rbln_batch_size=unet_batch_size,
145
+ rbln_use_encode=False,
141
146
  rbln_vae_scale_factor=model.vae_scale_factor,
147
+ rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
142
148
  **rbln_config_kwargs,
143
149
  **rbln_constructor_kwargs,
144
150
  )
145
- controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
151
+
152
+ if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
153
+ if isinstance(controlnet, (list, tuple)):
154
+ multicontrolnet = []
155
+ for i, cid in enumerate(controlnet):
156
+ subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
157
+ multicontrolnet.append(
158
+ RBLNControlNetModel.from_pretrained(
159
+ model_id=cid.config._name_or_path,
160
+ subfolder=subfolder_name,
161
+ export=True,
162
+ model_save_dir=model_save_dir,
163
+ rbln_batch_size=unet_batch_size,
164
+ rbln_vae_scale_factor=model.vae_scale_factor,
165
+ **rbln_config_kwargs,
166
+ **rbln_constructor_kwargs,
167
+ )
168
+ )
169
+ controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
170
+ controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
171
+ else:
172
+ controlnet = RBLNControlNetModel.from_pretrained(
173
+ model_id=controlnet.config._name_or_path,
174
+ subfolder="controlnet",
175
+ export=True,
176
+ model_save_dir=model_save_dir,
177
+ rbln_batch_size=unet_batch_size,
178
+ rbln_vae_scale_factor=model.vae_scale_factor,
179
+ **rbln_config_kwargs,
180
+ **rbln_constructor_kwargs,
181
+ )
182
+ controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
183
+
184
+ if model_save_dir is not None:
185
+ # To skip saving original pytorch modules
186
+ del (model.vae, model.text_encoder, model.unet, model.controlnet)
187
+
188
+ # Direct calling of `save_pretrained` causes config.unet = (None, None).
189
+ # So config must be saved again, later.
190
+ model.save_pretrained(model_save_dir)
146
191
 
147
192
  # replace modules
148
193
  model.vae = vae
@@ -159,15 +204,18 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
159
204
  }
160
205
  model.register_to_config(**update_dict)
161
206
 
162
- model.models = [vae.model[0], text_encoder.model[0], unet.model[0], controlnet.model[0]]
207
+ if model_save_dir is not None:
208
+ # overwrite to replace incorrect config
209
+ model.save_config(model_save_dir)
163
210
 
211
+ # use for CI to access each compiled model
164
212
  if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
165
- model.compiled_models = [
166
- vae.compiled_models[0],
167
- text_encoder.compiled_models[0],
168
- unet.compiled_models[0],
169
- controlnet.compiled_models[0],
170
- ]
213
+ model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
214
+ if isinstance(controlnet, RBLNMultiControlNetModel):
215
+ for c_model in controlnet.nets:
216
+ model.compiled_models.append(c_model.compiled_models[0])
217
+ else:
218
+ model.compiled_models.append(controlnet.compiled_models[0])
171
219
 
172
220
  return model
173
221