optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING
25
25
 
26
26
  from transformers.utils import _LazyModule
27
27
 
28
+ from .utils import check_version_compats
29
+
28
30
 
29
31
  _import_structure = {
30
32
  "modeling_alias": [
@@ -33,6 +35,9 @@ _import_structure = {
33
35
  "RBLNResNetForImageClassification",
34
36
  "RBLNT5ForConditionalGeneration",
35
37
  "RBLNBartForConditionalGeneration",
38
+ "RBLNXLMRobertaForSequenceClassification",
39
+ "RBLNRobertaForSequenceClassification",
40
+ "RBLNRobertaForMaskedLM",
36
41
  ],
37
42
  "modeling_base": [
38
43
  "RBLNBaseModel",
@@ -40,6 +45,8 @@ _import_structure = {
40
45
  "RBLNModelForQuestionAnswering",
41
46
  "RBLNModelForAudioClassification",
42
47
  "RBLNModelForImageClassification",
48
+ "RBLNModelForSequenceClassification",
49
+ "RBLNModelForMaskedLM",
43
50
  ],
44
51
  "modeling_seq2seq": [
45
52
  "RBLNModelForSeq2SeqLM",
@@ -48,11 +55,14 @@ _import_structure = {
48
55
  "BatchTextIteratorStreamer",
49
56
  "RBLNCLIPTextModel",
50
57
  "RBLNCLIPTextModelWithProjection",
58
+ "RBLNDPTForDepthEstimation",
59
+ "RBLNGemmaForCausalLM",
51
60
  "RBLNGPT2LMHeadModel",
52
61
  "RBLNWav2Vec2ForCTC",
53
62
  "RBLNLlamaForCausalLM",
54
63
  "RBLNMidmLMHeadModel",
55
64
  "RBLNWhisperForConditionalGeneration",
65
+ "RBLNXLMRobertaModel",
56
66
  ],
57
67
  "diffusers": [
58
68
  "RBLNStableDiffusionPipeline",
@@ -91,14 +101,19 @@ if TYPE_CHECKING:
91
101
  RBLNBartForConditionalGeneration,
92
102
  RBLNBertForQuestionAnswering,
93
103
  RBLNResNetForImageClassification,
104
+ RBLNRobertaForMaskedLM,
105
+ RBLNRobertaForSequenceClassification,
94
106
  RBLNT5ForConditionalGeneration,
107
+ RBLNXLMRobertaForSequenceClassification,
95
108
  )
96
109
  from .modeling_base import (
97
110
  RBLNBaseModel,
98
111
  RBLNModel,
99
112
  RBLNModelForAudioClassification,
100
113
  RBLNModelForImageClassification,
114
+ RBLNModelForMaskedLM,
101
115
  RBLNModelForQuestionAnswering,
116
+ RBLNModelForSequenceClassification,
102
117
  )
103
118
  from .modeling_config import RBLNConfig, RBLNRuntimeConfig
104
119
  from .modeling_seq2seq import RBLNModelForSeq2SeqLM
@@ -106,11 +121,14 @@ if TYPE_CHECKING:
106
121
  BatchTextIteratorStreamer,
107
122
  RBLNCLIPTextModel,
108
123
  RBLNCLIPTextModelWithProjection,
124
+ RBLNDPTForDepthEstimation,
125
+ RBLNGemmaForCausalLM,
109
126
  RBLNGPT2LMHeadModel,
110
127
  RBLNLlamaForCausalLM,
111
128
  RBLNMidmLMHeadModel,
112
129
  RBLNWav2Vec2ForCTC,
113
130
  RBLNWhisperForConditionalGeneration,
131
+ RBLNXLMRobertaModel,
114
132
  )
115
133
  else:
116
134
  import sys
@@ -121,3 +139,6 @@ else:
121
139
  _import_structure,
122
140
  module_spec=__spec__,
123
141
  )
142
+
143
+
144
+ check_version_compats()
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.1.4'
1
+ __version__ = '0.1.8'
optimum/rbln/diffusers/__init__.py CHANGED
@@ -47,7 +47,6 @@ _import_structure = {
47
47
  }
48
48
 
49
49
  if TYPE_CHECKING:
50
-
51
50
  from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
52
51
  from .pipelines import (
53
52
  RBLNMultiControlNetModel,
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -23,7 +23,6 @@
23
23
 
24
24
  import logging
25
25
  from pathlib import Path
26
- from tempfile import TemporaryDirectory
27
26
  from typing import TYPE_CHECKING, Dict, List, Optional, Union
28
27
 
29
28
  import rebel
@@ -37,7 +36,6 @@ from transformers import AutoConfig, AutoModel, PretrainedConfig
37
36
  from ...modeling_base import RBLNModel
38
37
  from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
39
38
  from ...utils.runtime_utils import RBLNPytorchRuntime
40
- from ...utils.save_utils import maybe_save_preprocessors
41
39
 
42
40
 
43
41
  logger = logging.getLogger(__name__)
@@ -70,73 +68,13 @@ class RBLNAutoencoderKL(RBLNModel):
70
68
  self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
71
69
 
72
70
  if self.rbln_use_encode:
73
- self.encoder = RBLNRuntimeVAEEncoder(runtime=self.runtimes[0], main_input_name="x")
74
- self.decoder = RBLNRuntimeVAEDecoder(runtime=self.runtimes[1], main_input_name="z")
71
+ self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
72
+ self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
75
73
  else:
76
- self.decoder = RBLNRuntimeVAEDecoder(runtime=self.runtimes[0], main_input_name="z")
74
+ self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")
77
75
 
78
76
  @classmethod
79
- @torch.no_grad()
80
- def _export(
81
- cls,
82
- model_id: str,
83
- config: "PretrainedConfig",
84
- use_auth_token: Optional[Union[bool, str]] = None,
85
- revision: Optional[str] = None,
86
- force_download: bool = False,
87
- cache_dir: Optional[str] = None,
88
- subfolder: str = "",
89
- local_files_only: bool = False,
90
- trust_remote_code: bool = False,
91
- model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
92
- **kwargs,
93
- ) -> "RBLNAutoencoderKL":
94
- task = kwargs.pop("task", None)
95
- if task is None:
96
- task = TasksManager.infer_task_from_model(cls.auto_model_class)
97
-
98
- if model_save_dir is None:
99
- save_dir = TemporaryDirectory()
100
- save_dir_path = Path(save_dir.name)
101
- else:
102
- save_dir = model_save_dir
103
- if isinstance(save_dir, TemporaryDirectory):
104
- save_dir_path = Path(model_save_dir.name)
105
- else:
106
- save_dir_path = Path(model_save_dir)
107
- save_dir_path.mkdir(exist_ok=True)
108
-
109
- rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
110
-
111
- model: AutoencoderKL = TasksManager.get_model_from_task(
112
- task=None,
113
- model_name_or_path=model_id,
114
- subfolder=subfolder,
115
- revision=revision,
116
- framework="pt",
117
- cache_dir=cache_dir,
118
- use_auth_token=use_auth_token,
119
- local_files_only=local_files_only,
120
- force_download=force_download,
121
- trust_remote_code=trust_remote_code,
122
- **kwargs,
123
- )
124
-
125
- if config is None:
126
- config = model.config
127
-
128
- if not isinstance(config, PretrainedConfig): # diffusers config
129
- config = PretrainedConfig(**config)
130
-
131
- config.save_pretrained(save_dir_path / subfolder)
132
- preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
133
-
134
- # Get compilation arguments
135
- if rbln_config_kwargs.get("rbln_config", None) is None:
136
- rbln_config = cls.get_rbln_config(
137
- preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
138
- )
139
-
77
+ def get_compiled_model(cls, model, rbln_config: RBLNConfig):
140
78
  def compile_img2img():
141
79
  encoder_model = _VAEEncoder(model)
142
80
  decoder_model = _VAEDecoder(model)
@@ -146,12 +84,7 @@ class RBLNAutoencoderKL(RBLNModel):
146
84
  enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
147
85
  dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])
148
86
 
149
- enc_compiled_model.save(
150
- save_dir_path / subfolder / f"{rbln_config['encoder'][0].compiled_model_name}.rbln"
151
- )
152
- dec_compiled_model.save(
153
- save_dir_path / subfolder / f"{rbln_config['decoder'][0].compiled_model_name}.rbln"
154
- )
87
+ return enc_compiled_model, dec_compiled_model
155
88
 
156
89
  def compile_text2img():
157
90
  decoder_model = _VAEDecoder(model)
@@ -159,30 +92,12 @@ class RBLNAutoencoderKL(RBLNModel):
159
92
 
160
93
  dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])
161
94
 
162
- dec_compiled_model.save(
163
- save_dir_path / subfolder / f"{rbln_config['compiled_model'][0].compiled_model_name}.rbln"
164
- )
95
+ return dec_compiled_model
165
96
 
166
- if rbln_config_kwargs.get("rbln_use_encode"):
167
- compile_img2img()
97
+ if rbln_config.meta.get("rbln_use_encode", False):
98
+ return compile_img2img()
168
99
  else:
169
- compile_text2img()
170
-
171
- rbln_config.save(save_dir_path / subfolder)
172
-
173
- return cls._from_pretrained(
174
- model_id=save_dir_path,
175
- config=config,
176
- model_save_dir=save_dir,
177
- use_auth_token=use_auth_token,
178
- revision=revision,
179
- force_download=force_download,
180
- cache_dir=cache_dir,
181
- subfolder=subfolder,
182
- local_files_only=local_files_only,
183
- **rbln_constructor_kwargs,
184
- **kwargs,
185
- )
100
+ return compile_text2img()
186
101
 
187
102
  @classmethod
188
103
  def from_pretrained(cls, *args, **kwargs):
@@ -282,15 +197,18 @@ class RBLNAutoencoderKL(RBLNModel):
282
197
  rbln_config = RBLNConfig.from_rbln_runtime_configs([vae_config], _rbln_meta=meta)
283
198
  return rbln_config
284
199
 
285
- def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
286
- if len(self.compiled_models) == 1:
200
+ @classmethod
201
+ def _create_runtimes(
202
+ cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
203
+ ) -> List[rebel.Runtime]:
204
+ if len(compiled_models) == 1:
287
205
  device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
288
- return [self.compiled_models[0].create_runtime(tensor_type="pt", device=device_val)]
206
+ return [compiled_models[0].create_runtime(tensor_type="pt", device=device_val)]
289
207
 
290
208
  device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
291
209
  return [
292
210
  compiled_model.create_runtime(tensor_type="pt", device=device_val)
293
- for compiled_model, device_val in zip(self.compiled_models, device_vals)
211
+ for compiled_model, device_val in zip(compiled_models, device_vals)
294
212
  ]
295
213
 
296
214
  def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -120,6 +120,9 @@ class RBLNControlNetModel(RBLNModel):
120
120
  model_name_or_path: Union[str, Path],
121
121
  **kwargs,
122
122
  ):
123
+ if "subfolder" in kwargs:
124
+ del kwargs["subfolder"]
125
+
123
126
  return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
124
127
 
125
128
  tasktmp = TasksManager.get_model_from_task
optimum/rbln/diffusers/models/unet_2d_condition.py CHANGED
@@ -244,6 +244,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
244
244
  for i in range(3)
245
245
  ]
246
246
  )
247
+ if len(model_config.block_out_channels) > 1:
247
248
  input_info.append(
248
249
  (
249
250
  "down_block_additional_residuals_3",
@@ -251,7 +252,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
251
252
  "float32",
252
253
  )
253
254
  )
254
- if len(model_config.block_out_channels) > 1:
255
255
  input_info.extend(
256
256
  [
257
257
  (
@@ -262,6 +262,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
262
262
  for i in range(4, 6)
263
263
  ]
264
264
  )
265
+ if len(model_config.block_out_channels) > 2:
265
266
  input_info.append(
266
267
  (
267
268
  f"down_block_additional_residuals_{6}",
@@ -269,7 +270,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
269
270
  "float32",
270
271
  )
271
272
  )
272
- if len(model_config.block_out_channels) > 2:
273
273
  input_info.extend(
274
274
  [
275
275
  (
@@ -314,7 +314,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
314
314
  if rbln_text_model_hidden_size is None:
315
315
  rbln_text_model_hidden_size = 768
316
316
  if rbln_in_features is None:
317
- rbln_in_features = 2816
317
+ rbln_in_features = model_config.projection_class_embeddings_input_dim
318
318
  meta["in_features"] = rbln_in_features
319
319
  rbln_runtime_config.input_info.append(
320
320
  ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py CHANGED
@@ -24,17 +24,15 @@
24
24
  import logging
25
25
  import os
26
26
  from pathlib import Path
27
- from tempfile import TemporaryDirectory
28
27
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
29
28
 
30
- import rebel
31
29
  import torch
32
30
  from diffusers import ControlNetModel
33
31
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
34
32
  from optimum.exporters import TasksManager
35
- from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
33
+ from transformers import AutoConfig, AutoModel
36
34
 
37
- from ....modeling_base import RBLNBaseModel
35
+ from ....modeling_base import RBLNModel
38
36
  from ....modeling_config import RBLNConfig
39
37
  from ...models.controlnet import RBLNControlNetModel
40
38
 
@@ -42,38 +40,16 @@ from ...models.controlnet import RBLNControlNetModel
42
40
  logger = logging.getLogger(__name__)
43
41
 
44
42
  if TYPE_CHECKING:
45
- from transformers import (
46
- PretrainedConfig,
47
- PreTrainedModel,
48
- )
43
+ pass
49
44
 
50
45
 
51
- class RBLNMultiControlNetModel(RBLNBaseModel):
52
- model_type = "rbln_model"
53
- auto_model_class = AutoModel
54
-
46
+ class RBLNMultiControlNetModel(RBLNModel):
55
47
  def __init__(
56
48
  self,
57
- models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
58
- config: PretrainedConfig = None,
59
- preprocessors: Optional[List] = None,
60
- rbln_config: Optional[RBLNConfig] = None,
49
+ models: List[RBLNControlNetModel],
61
50
  **kwargs,
62
51
  ):
63
- super().__init__(
64
- models,
65
- config,
66
- preprocessors,
67
- rbln_config,
68
- **kwargs,
69
- )
70
-
71
- if not isinstance(config, PretrainedConfig):
72
- config = PretrainedConfig(**config)
73
-
74
- for i in range(len(models)):
75
- self.runtimes[i].config = config
76
- self.nets = self.runtimes
52
+ self.nets = models
77
53
  self.dtype = torch.float32
78
54
 
79
55
  @classmethod
@@ -83,7 +59,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
83
59
  model_name_or_path: Union[str, Path],
84
60
  **kwargs,
85
61
  ):
86
- return MultiControlNetModel.from_pretrained(pretrained_model_path=model_name_or_path, **kwargs)
62
+ return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
87
63
 
88
64
  tasktmp = TasksManager.get_model_from_task
89
65
  configtmp = AutoConfig.from_pretrained
@@ -101,131 +77,31 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
101
77
  def _from_pretrained(
102
78
  cls,
103
79
  model_id: Union[str, Path],
104
- config: "PretrainedConfig",
105
- use_auth_token: Optional[Union[bool, str]] = None,
106
- revision: Optional[str] = None,
107
- force_download: bool = False,
108
- cache_dir: Optional[str] = None,
109
- file_name: Optional[str] = None,
110
- subfolder: str = "",
111
- local_files_only: bool = False,
112
80
  **kwargs,
113
- ) -> RBLNBaseModel:
114
-
115
- if isinstance(model_id, str):
116
- model_path = Path(model_id)
117
- else:
118
- model_path = model_id / "controlnet"
81
+ ) -> RBLNModel:
119
82
 
120
- rbln_files = []
121
- rbln_config_filenames = []
122
83
  idx = 0
123
- model_load_path = model_path
84
+ controlnets = []
85
+ model_path_to_load = model_id
124
86
 
125
- while model_load_path.is_dir():
126
- rbln_files.append(list(model_load_path.glob("**/*.rbln"))[0])
127
- rbln_config_filenames.append(model_load_path)
87
+ while os.path.isdir(model_path_to_load):
88
+ controlnet = RBLNControlNetModel.from_pretrained(model_path_to_load, export=False, **kwargs)
89
+ controlnets.append(controlnet)
90
+ rbln_config = RBLNConfig.load(model_path_to_load)
128
91
  idx += 1
129
- model_load_path = Path(str(model_path) + f"_{idx}")
130
-
131
- if len(rbln_files) == 0:
132
- raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")
133
-
134
- if len(rbln_config_filenames) == 0:
135
- raise FileNotFoundError(f"Could not find `rbln_config.json` file in {model_path}")
136
-
137
- models = []
138
- for rconf, rfiles in zip(rbln_config_filenames, rbln_files):
139
- rbln_config = RBLNConfig.load(str(rconf))
140
- models.append(rebel.RBLNCompiledModel(rfiles))
141
-
142
- preprocessors = []
92
+ model_path_to_load = model_id + f"_{idx}"
143
93
 
144
94
  return cls(
145
- models,
146
- config,
147
- preprocessors,
95
+ controlnets,
148
96
  rbln_config=rbln_config,
149
97
  **kwargs,
150
98
  )
151
99
 
152
- def _save_pretrained(self, save_directory: Union[str, Path]):
153
- idx = 0
154
- real_save_dir_path = save_directory
155
- for compiled_model in self.compiled_models:
156
- dst_path = Path(real_save_dir_path) / "compiled_model.rbln"
157
- if not os.path.exists(real_save_dir_path):
158
- os.makedirs(real_save_dir_path)
159
- compiled_model.save(dst_path)
160
- self.rbln_config.save(real_save_dir_path)
161
- idx += 1
162
- real_save_dir_path = save_directory + f"_{idx}"
163
-
164
- @classmethod
165
- @torch.no_grad()
166
- def _export(
167
- cls,
168
- model_id: str,
169
- config: "PretrainedConfig",
170
- use_auth_token: Optional[Union[bool, str]] = None,
171
- revision: Optional[str] = None,
172
- force_download: bool = False,
173
- cache_dir: Optional[str] = None,
174
- subfolder: str = "",
175
- local_files_only: bool = False,
176
- trust_remote_code: bool = False,
177
- **kwargs,
178
- ) -> "RBLNMultiControlNetModel":
179
-
180
- task = kwargs.pop("task", None)
181
- if task is None:
182
- task = TasksManager.infer_task_from_model(cls.auto_model_class)
183
-
184
- save_dir = TemporaryDirectory()
185
- save_dir_path = Path(save_dir.name)
186
-
187
- rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
188
- img_width = rbln_config_kwargs.pop("rbln_img_width", None)
189
- img_height = rbln_config_kwargs.pop("rbln_img_height", None)
190
- vae_scale_factor = rbln_config_kwargs.pop("rbln_vae_scale_factor", None)
191
- batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
192
-
193
- model: MultiControlNetModel = TasksManager.get_model_from_task(
194
- task=task,
195
- model_name_or_path=model_id,
196
- )
197
-
198
- model_path_to_load = model_id
199
- real_save_dir_path = save_dir_path / "controlnet"
200
-
201
- for idx in range(len(model.nets)):
100
+ def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
101
+ for idx, model in enumerate(self.nets):
202
102
  suffix = "" if idx == 0 else f"_{idx}"
203
- controlnet = RBLNControlNetModel.from_pretrained(
204
- model_path_to_load + suffix,
205
- export=True,
206
- rbln_batch_size=batch_size,
207
- rbln_img_width=img_width,
208
- rbln_img_height=img_height,
209
- rbln_vae_scale_factor=vae_scale_factor,
210
- )
211
- controlnet.save_pretrained(real_save_dir_path)
212
- real_save_dir_path = save_dir_path / f"controlnet_{idx+1}"
213
-
214
- return cls._from_pretrained(
215
- model_id=save_dir_path,
216
- config=config,
217
- model_save_dir=save_dir,
218
- **rbln_constructor_kwargs,
219
- **kwargs,
220
- )
221
-
222
- def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
223
- device_val = rbln_device_map["compiled_model"]
224
-
225
- return [
226
- compiled_model.create_runtime(tensor_type="pt", device=device_val)
227
- for compiled_model in self.compiled_models
228
- ]
103
+ real_save_path = save_directory + suffix
104
+ model.save_pretrained(real_save_path)
229
105
 
230
106
  def forward(
231
107
  self,
@@ -243,9 +119,9 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
243
119
  return_dict: bool = True,
244
120
  ):
245
121
  for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
246
- output = controlnet(
122
+ output = controlnet.model[0](
247
123
  sample=sample.contiguous(),
248
- timestep=timestep,
124
+ timestep=timestep.float(),
249
125
  encoder_hidden_states=encoder_hidden_states,
250
126
  controlnet_cond=image,
251
127
  conditioning_scale=torch.tensor(scale),