optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +282 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
- optimum_rbln-0.1.8.dist-info/RECORD +73 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
- optimum_rbln-0.1.4.dist-info/RECORD +0 -63
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import _LazyModule
 
+from .utils import check_version_compats
+
 
 _import_structure = {
     "modeling_alias": [
@@ -33,6 +35,9 @@ _import_structure = {
         "RBLNResNetForImageClassification",
         "RBLNT5ForConditionalGeneration",
         "RBLNBartForConditionalGeneration",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForMaskedLM",
     ],
     "modeling_base": [
         "RBLNBaseModel",
@@ -40,6 +45,8 @@ _import_structure = {
         "RBLNModelForQuestionAnswering",
         "RBLNModelForAudioClassification",
         "RBLNModelForImageClassification",
+        "RBLNModelForSequenceClassification",
+        "RBLNModelForMaskedLM",
     ],
     "modeling_seq2seq": [
         "RBLNModelForSeq2SeqLM",
@@ -48,11 +55,14 @@ _import_structure = {
         "BatchTextIteratorStreamer",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNDPTForDepthEstimation",
+        "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
         "RBLNWhisperForConditionalGeneration",
+        "RBLNXLMRobertaModel",
     ],
     "diffusers": [
         "RBLNStableDiffusionPipeline",
@@ -91,14 +101,19 @@ if TYPE_CHECKING:
         RBLNBartForConditionalGeneration,
         RBLNBertForQuestionAnswering,
         RBLNResNetForImageClassification,
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForSequenceClassification,
         RBLNT5ForConditionalGeneration,
+        RBLNXLMRobertaForSequenceClassification,
     )
     from .modeling_base import (
         RBLNBaseModel,
         RBLNModel,
         RBLNModelForAudioClassification,
         RBLNModelForImageClassification,
+        RBLNModelForMaskedLM,
         RBLNModelForQuestionAnswering,
+        RBLNModelForSequenceClassification,
     )
     from .modeling_config import RBLNConfig, RBLNRuntimeConfig
     from .modeling_seq2seq import RBLNModelForSeq2SeqLM
@@ -106,11 +121,14 @@ if TYPE_CHECKING:
         BatchTextIteratorStreamer,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNDPTForDepthEstimation,
+        RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
+        RBLNXLMRobertaModel,
     )
 else:
     import sys
@@ -121,3 +139,6 @@ else:
         _import_structure,
         module_spec=__spec__,
     )
+
+
+check_version_compats()
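With these `__init__.py` changes, the classes newly added to `_import_structure` (and mirrored in the `TYPE_CHECKING` block) become importable from the package root through the `_LazyModule` machinery, and `check_version_compats()` now runs once at import time. A minimal usage sketch; the class name and `export` flag appear in this diff, but the checkpoint id is illustrative:

```python
# RBLNGemmaForCausalLM is newly exported in 0.1.8 (see _import_structure above).
# The model id below is a placeholder, not taken from this diff.
from optimum.rbln import RBLNGemmaForCausalLM

# export=True follows the export/compile convention visible elsewhere in this
# diff (e.g. RBLNControlNetModel.from_pretrained(..., export=False)).
model = RBLNGemmaForCausalLM.from_pretrained("google/gemma-2b", export=True)
```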
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.1.4'
+__version__ = '0.1.8'
optimum/rbln/diffusers/models/autoencoder_kl.py
CHANGED
@@ -23,7 +23,6 @@
 
 import logging
 from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import rebel
@@ -37,7 +36,6 @@ from transformers import AutoConfig, AutoModel, PretrainedConfig
 from ...modeling_base import RBLNModel
 from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
-from ...utils.save_utils import maybe_save_preprocessors
 
 
 logger = logging.getLogger(__name__)
@@ -70,73 +68,13 @@ class RBLNAutoencoderKL(RBLNModel):
         self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
 
         if self.rbln_use_encode:
-            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.
-            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.
+            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
+            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
-            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.
+            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")
 
     @classmethod
-
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNAutoencoderKL":
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: AutoencoderKL = TasksManager.get_model_from_task(
-            task=None,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        if not isinstance(config, PretrainedConfig):  # diffusers config
-            config = PretrainedConfig(**config)
-
-        config.save_pretrained(save_dir_path / subfolder)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
+    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
         def compile_img2img():
             encoder_model = _VAEEncoder(model)
             decoder_model = _VAEDecoder(model)
@@ -146,12 +84,7 @@ class RBLNAutoencoderKL(RBLNModel):
             enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
             dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])
 
-            enc_compiled_model.save(
-                save_dir_path / subfolder / f"{rbln_config['encoder'][0].compiled_model_name}.rbln"
-            )
-            dec_compiled_model.save(
-                save_dir_path / subfolder / f"{rbln_config['decoder'][0].compiled_model_name}.rbln"
-            )
+            return enc_compiled_model, dec_compiled_model
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
@@ -159,30 +92,12 @@ class RBLNAutoencoderKL(RBLNModel):
 
             dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])
 
-            dec_compiled_model.save(
-                save_dir_path / subfolder / f"{rbln_config['compiled_model'][0].compiled_model_name}.rbln"
-            )
+            return dec_compiled_model
 
-        if
-            compile_img2img()
+        if rbln_config.meta.get("rbln_use_encode", False):
+            return compile_img2img()
         else:
-            compile_text2img()
-
-        rbln_config.save(save_dir_path / subfolder)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
+            return compile_text2img()
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
@@ -282,15 +197,18 @@ class RBLNAutoencoderKL(RBLNModel):
         rbln_config = RBLNConfig.from_rbln_runtime_configs([vae_config], _rbln_meta=meta)
         return rbln_config
 
-
-
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
+        if len(compiled_models) == 1:
             device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-            return [
+            return [compiled_models[0].create_runtime(tensor_type="pt", device=device_val)]
 
         device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
         return [
             compiled_model.create_runtime(tensor_type="pt", device=device_val)
-            for compiled_model, device_val in zip(
+            for compiled_model, device_val in zip(compiled_models, device_vals)
         ]
 
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
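The autoencoder refactor above replaces the monolithic `_export` (which compiled, saved to disk, and re-loaded inline) with two hooks the shared `RBLNModel` base class can drive: `get_compiled_model` returns the compiled encoder/decoder pair for img2img (or just the decoder for text2img), and `_create_runtimes` turns the compiled models into device-bound runtimes. A self-contained sketch of the `_create_runtimes` dispatch, with `rebel` compiled models and runtimes stood in by plain strings so the snippet runs anywhere:

```python
from typing import Dict, List

DEFAULT_COMPILED_MODEL_NAME = "compiled_model"  # key used on the single-model path

def create_runtimes(compiled_models: List[str], device_map: Dict[str, int]) -> List[str]:
    # text2img: only a decoder was compiled, keyed under the default name
    if len(compiled_models) == 1:
        return [f"{compiled_models[0]}@npu{device_map[DEFAULT_COMPILED_MODEL_NAME]}"]
    # img2img: encoder and decoder each look up their own device entry
    device_vals = [device_map["encoder"], device_map["decoder"]]
    return [f"{m}@npu{d}" for m, d in zip(compiled_models, device_vals)]

print(create_runtimes(["decoder"], {"compiled_model": 0}))
print(create_runtimes(["encoder", "decoder"], {"encoder": 0, "decoder": 1}))
```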
optimum/rbln/diffusers/models/controlnet.py
CHANGED
@@ -120,6 +120,9 @@ class RBLNControlNetModel(RBLNModel):
         model_name_or_path: Union[str, Path],
         **kwargs,
     ):
+        if "subfolder" in kwargs:
+            del kwargs["subfolder"]
+
         return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
     tasktmp = TasksManager.get_model_from_task
optimum/rbln/diffusers/models/unet_2d_condition.py
CHANGED
@@ -244,6 +244,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     for i in range(3)
                 ]
             )
+        if len(model_config.block_out_channels) > 1:
             input_info.append(
                 (
                     "down_block_additional_residuals_3",
@@ -251,7 +252,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     "float32",
                 )
             )
-        if len(model_config.block_out_channels) > 1:
             input_info.extend(
                 [
                     (
@@ -262,6 +262,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     for i in range(4, 6)
                 ]
             )
+        if len(model_config.block_out_channels) > 2:
             input_info.append(
                 (
                     f"down_block_additional_residuals_{6}",
@@ -269,7 +270,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     "float32",
                 )
             )
-        if len(model_config.block_out_channels) > 2:
             input_info.extend(
                 [
                     (
@@ -314,7 +314,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         if rbln_text_model_hidden_size is None:
             rbln_text_model_hidden_size = 768
         if rbln_in_features is None:
-            rbln_in_features =
+            rbln_in_features = model_config.projection_class_embeddings_input_dim
         meta["in_features"] = rbln_in_features
         rbln_runtime_config.input_info.append(
             ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
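The four `unet_2d_condition.py` hunks above move the boundary residual appends (`down_block_additional_residuals_3` and `..._6`) out of the shallower branch and into the next `len(model_config.block_out_channels)` guard, so each residual input is only declared when the UNet actually has the down block that produces it. A runnable sketch of the corrected control flow (only the input names from the diff are kept; shapes, dtypes, and the deeper-UNet tail are omitted):

```python
def residual_input_names(block_out_channels):
    # residuals 0-2 are declared unconditionally
    names = [f"down_block_additional_residuals_{i}" for i in range(3)]
    if len(block_out_channels) > 1:
        # "_3" used to be appended before this check; it now belongs here
        names.append("down_block_additional_residuals_3")
        names += [f"down_block_additional_residuals_{i}" for i in range(4, 6)]
    if len(block_out_channels) > 2:
        names.append("down_block_additional_residuals_6")
        # ...the real method keeps extending for deeper UNets
    return names

# A single-block config no longer declares residual inputs it will never receive.
assert "down_block_additional_residuals_3" not in residual_input_names([320])
```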
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
CHANGED
@@ -24,17 +24,15 @@
 import logging
 import os
 from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
-import rebel
 import torch
 from diffusers import ControlNetModel
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel
+from transformers import AutoConfig, AutoModel
 
-from ....modeling_base import
+from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNConfig
 from ...models.controlnet import RBLNControlNetModel
 
@@ -42,38 +40,16 @@ from ...models.controlnet import RBLNControlNetModel
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        PretrainedConfig,
-        PreTrainedModel,
-    )
+    pass
 
 
-class RBLNMultiControlNetModel(RBLNBaseModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel
-
+class RBLNMultiControlNetModel(RBLNModel):
     def __init__(
         self,
-        models: List[
-        config: PretrainedConfig = None,
-        preprocessors: Optional[List] = None,
-        rbln_config: Optional[RBLNConfig] = None,
+        models: List[RBLNControlNetModel],
         **kwargs,
     ):
-        super().__init__(
-            models,
-            config,
-            preprocessors,
-            rbln_config,
-            **kwargs,
-        )
-
-        if not isinstance(config, PretrainedConfig):
-            config = PretrainedConfig(**config)
-
-        for i in range(len(models)):
-            self.runtimes[i].config = config
-        self.nets = self.runtimes
+        self.nets = models
         self.dtype = torch.float32
 
     @classmethod
@@ -83,7 +59,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
         model_name_or_path: Union[str, Path],
         **kwargs,
     ):
-        return MultiControlNetModel.from_pretrained(
+        return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
     tasktmp = TasksManager.get_model_from_task
     configtmp = AutoConfig.from_pretrained
@@ -101,131 +77,31 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
     def _from_pretrained(
         cls,
         model_id: Union[str, Path],
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        file_name: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
         **kwargs,
-    ) ->
-
-        if isinstance(model_id, str):
-            model_path = Path(model_id)
-        else:
-            model_path = model_id / "controlnet"
+    ) -> RBLNModel:
 
-        rbln_files = []
-        rbln_config_filenames = []
         idx = 0
-
+        controlnets = []
+        model_path_to_load = model_id
 
-        while
-
-
+        while os.path.isdir(model_path_to_load):
+            controlnet = RBLNControlNetModel.from_pretrained(model_path_to_load, export=False, **kwargs)
+            controlnets.append(controlnet)
+            rbln_config = RBLNConfig.load(model_path_to_load)
             idx += 1
-
-
-        if len(rbln_files) == 0:
-            raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")
-
-        if len(rbln_config_filenames) == 0:
-            raise FileNotFoundError(f"Could not find `rbln_config.json` file in {model_path}")
-
-        models = []
-        for rconf, rfiles in zip(rbln_config_filenames, rbln_files):
-            rbln_config = RBLNConfig.load(str(rconf))
-            models.append(rebel.RBLNCompiledModel(rfiles))
-
-        preprocessors = []
+            model_path_to_load = model_id + f"_{idx}"
 
         return cls(
-            models,
-            config,
-            preprocessors,
+            controlnets,
             rbln_config=rbln_config,
             **kwargs,
         )
 
-    def
-        idx
-        real_save_dir_path = save_directory
-        for compiled_model in self.compiled_models:
-            dst_path = Path(real_save_dir_path) / "compiled_model.rbln"
-            if not os.path.exists(real_save_dir_path):
-                os.makedirs(real_save_dir_path)
-            compiled_model.save(dst_path)
-            self.rbln_config.save(real_save_dir_path)
-            idx += 1
-            real_save_dir_path = save_directory + f"_{idx}"
-
-    @classmethod
-    @torch.no_grad()
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        **kwargs,
-    ) -> "RBLNMultiControlNetModel":
-
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-        img_width = rbln_config_kwargs.pop("rbln_img_width", None)
-        img_height = rbln_config_kwargs.pop("rbln_img_height", None)
-        vae_scale_factor = rbln_config_kwargs.pop("rbln_vae_scale_factor", None)
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
-
-        model: MultiControlNetModel = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-        )
-
-        model_path_to_load = model_id
-        real_save_dir_path = save_dir_path / "controlnet"
-
-        for idx in range(len(model.nets)):
+    def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
+        for idx, model in enumerate(self.nets):
             suffix = "" if idx == 0 else f"_{idx}"
-
-
-                export=True,
-                rbln_batch_size=batch_size,
-                rbln_img_width=img_width,
-                rbln_img_height=img_height,
-                rbln_vae_scale_factor=vae_scale_factor,
-            )
-            controlnet.save_pretrained(real_save_dir_path)
-            real_save_dir_path = save_dir_path / f"controlnet_{idx+1}"
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
-
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-        device_val = rbln_device_map["compiled_model"]
-
-        return [
-            compiled_model.create_runtime(tensor_type="pt", device=device_val)
-            for compiled_model in self.compiled_models
-        ]
+            real_save_path = save_directory + suffix
+            model.save_pretrained(real_save_path)
 
     def forward(
         self,
@@ -243,9 +119,9 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
         return_dict: bool = True,
     ):
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
-            output = controlnet(
+            output = controlnet.model[0](
                 sample=sample.contiguous(),
-                timestep=timestep,
+                timestep=timestep.float(),
                 encoder_hidden_states=encoder_hidden_states,
                 controlnet_cond=image,
                 conditioning_scale=torch.tensor(scale),
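The rewritten `save_pretrained`/`_from_pretrained` pair above establishes a simple on-disk convention: net 0 is saved to the target directory itself and each further net to the same path plus a `_{idx}` suffix; loading walks those suffixes until a directory is missing. A sketch of the round trip, where the paths and the two ControlNet objects are illustrative, not taken from this diff:

```python
# Hypothetical round trip; cnet_a and cnet_b stand for already-compiled
# RBLNControlNetModel instances.
multi = RBLNMultiControlNetModel([cnet_a, cnet_b])
multi.save_pretrained("./multi_controlnet")
# -> writes ./multi_controlnet   (net 0, empty suffix)
#    and    ./multi_controlnet_1 (net 1)

# _from_pretrained probes ./multi_controlnet, ./multi_controlnet_1, ... in order.
restored = RBLNMultiControlNetModel._from_pretrained(model_id="./multi_controlnet")
```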