optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- optimum/rbln/__init__.py +17 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +7 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +19 -1
- optimum/rbln/modeling_base.py +162 -18
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
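For orientation, the model classes added in this release (Gemma, Mistral, XLM-RoBERTa, and the RoBERTa aliases) follow the same compile-on-export workflow as the existing classes. A minimal usage sketch, assuming the usual optimum-style `from_pretrained(..., export=True)` interface; the checkpoint id and `rbln_*` values are illustrative, not taken from this diff:

```python
# Hypothetical sketch: compiling one of the newly added model classes
# for the RBLN NPU via the optimum-style export flow.
from optimum.rbln import RBLNMistralForCausalLM

model = RBLNMistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # example checkpoint id (assumption)
    export=True,                  # compile from the PyTorch checkpoint
    rbln_batch_size=1,            # rbln_* kwargs as seen throughout this diff
    rbln_max_seq_len=4096,
)
model.save_pretrained("mistral-7b-rbln")  # persist the compiled artifacts
```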
optimum/rbln/__init__.py
CHANGED
```diff
@@ -32,9 +32,13 @@ _import_structure = {
     "modeling_alias": [
         "RBLNASTForAudioClassification",
         "RBLNBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
         "RBLNT5ForConditionalGeneration",
         "RBLNBartForConditionalGeneration",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForMaskedLM",
     ],
     "modeling_base": [
         "RBLNBaseModel",
@@ -42,6 +46,8 @@ _import_structure = {
         "RBLNModelForQuestionAnswering",
         "RBLNModelForAudioClassification",
         "RBLNModelForImageClassification",
+        "RBLNModelForSequenceClassification",
+        "RBLNModelForMaskedLM",
     ],
     "modeling_seq2seq": [
         "RBLNModelForSeq2SeqLM",
@@ -51,11 +57,14 @@ _import_structure = {
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
         "RBLNDPTForDepthEstimation",
+        "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
+        "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
+        "RBLNXLMRobertaModel",
     ],
     "diffusers": [
         "RBLNStableDiffusionPipeline",
@@ -94,14 +103,19 @@ if TYPE_CHECKING:
         RBLNBartForConditionalGeneration,
         RBLNBertForQuestionAnswering,
         RBLNResNetForImageClassification,
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForSequenceClassification,
         RBLNT5ForConditionalGeneration,
+        RBLNXLMRobertaForSequenceClassification,
     )
     from .modeling_base import (
         RBLNBaseModel,
         RBLNModel,
         RBLNModelForAudioClassification,
         RBLNModelForImageClassification,
+        RBLNModelForMaskedLM,
         RBLNModelForQuestionAnswering,
+        RBLNModelForSequenceClassification,
     )
     from .modeling_config import RBLNConfig, RBLNRuntimeConfig
     from .modeling_seq2seq import RBLNModelForSeq2SeqLM
@@ -110,11 +124,14 @@ if TYPE_CHECKING:
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNDPTForDepthEstimation,
+        RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
+        RBLNXLMRobertaModel,
     )
 else:
     import sys
```
optimum/rbln/__version__.py
CHANGED
```diff
@@ -1 +1 @@
-__version__ = '0.1.7'
+__version__ = '0.1.9'
```
optimum/rbln/diffusers/models/autoencoder_kl.py
CHANGED
```diff
@@ -26,7 +26,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import rebel
-import torch
+import torch  # noqa: I001
 from diffusers import AutoencoderKL
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
@@ -38,12 +38,12 @@ from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRunt
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     import torch
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
```
optimum/rbln/diffusers/models/controlnet.py
CHANGED
```diff
@@ -34,12 +34,13 @@ from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
+logger = logging.getLogger(__name__)
+
+
 class _ControlNetModel(torch.nn.Module):
     def __init__(self, controlnet: "ControlNetModel"):
         super().__init__()
@@ -120,6 +121,9 @@ class RBLNControlNetModel(RBLNModel):
         model_name_or_path: Union[str, Path],
         **kwargs,
     ):
+        if "subfolder" in kwargs:
+            del kwargs["subfolder"]
+
         return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
     tasktmp = TasksManager.get_model_from_task
@@ -135,7 +139,7 @@ class RBLNControlNetModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
```
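Two behavioral notes on the hunks above: `get_pytorch_model` now drops `subfolder` before delegating to diffusers, presumably because the RBLN side uses that kwarg for its own artifact layout, and `wrap_model_if_needed` gains an `rbln_config` parameter so wrappers can consult compile-time options. A sketch of the kwarg-stripping in isolation (helper name hypothetical, written non-destructively for illustration):

```python
def strip_rbln_only_kwargs(kwargs: dict) -> dict:
    # Equivalent to the new `if "subfolder" in kwargs: del kwargs["subfolder"]`
    # guard, but returns a copy instead of mutating the caller's dict.
    cleaned = dict(kwargs)
    cleaned.pop("subfolder", None)
    return cleaned
```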
optimum/rbln/diffusers/models/unet_2d_condition.py
CHANGED
```diff
@@ -35,11 +35,11 @@ from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+
 
 class _UNet_SD(torch.nn.Module):
     def __init__(self, unet: "UNet2DConditionModel"):
@@ -172,7 +172,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
         else:
@@ -244,6 +244,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 for i in range(3)
             ]
         )
+        if len(model_config.block_out_channels) > 1:
             input_info.append(
                 (
                     "down_block_additional_residuals_3",
@@ -251,7 +252,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     "float32",
                 )
             )
-        if len(model_config.block_out_channels) > 1:
             input_info.extend(
                 [
                     (
@@ -262,6 +262,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     for i in range(4, 6)
                 ]
             )
+        if len(model_config.block_out_channels) > 2:
             input_info.append(
                 (
                     f"down_block_additional_residuals_{6}",
@@ -269,7 +270,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     "float32",
                 )
             )
-        if len(model_config.block_out_channels) > 2:
             input_info.extend(
                 [
                     (
```
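The last four hunks relocate each `block_out_channels` guard so that it precedes the `input_info.append(...)` it is meant to protect; previously that residual input was declared unconditionally and only the subsequent `extend` was guarded. A standalone sketch of the apparent intent (function and shapes hypothetical):

```python
def build_residual_input_info(block_out_channels: list, shape: tuple) -> list:
    # Residual inputs for the first block level always exist; deeper levels
    # are declared only when the UNet actually has that many down blocks.
    info = [(f"down_block_additional_residuals_{i}", shape, "float32") for i in range(3)]
    if len(block_out_channels) > 1:
        info.append(("down_block_additional_residuals_3", shape, "float32"))
        info.extend((f"down_block_additional_residuals_{i}", shape, "float32") for i in range(4, 6))
    if len(block_out_channels) > 2:
        info.append(("down_block_additional_residuals_6", shape, "float32"))
        # ...followed by the extend for the remaining indices, as in the hunk
    return info
```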
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
CHANGED
```diff
@@ -24,56 +24,32 @@
 import logging
 import os
 from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
-import rebel
 import torch
 from diffusers import ControlNetModel
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel
+from transformers import AutoConfig, AutoModel
 
-from ....modeling_base import RBLNBaseModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNConfig
 from ...models.controlnet import RBLNControlNetModel
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
-    from transformers import (
-        PretrainedConfig,
-        PreTrainedModel,
-    )
+    pass
 
+logger = logging.getLogger(__name__)
 
-class RBLNMultiControlNetModel(RBLNBaseModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel
 
+class RBLNMultiControlNetModel(RBLNModel):
     def __init__(
         self,
-        models: List[rebel.Runtime],
-        config: PretrainedConfig = None,
-        preprocessors: Optional[List] = None,
-        rbln_config: Optional[RBLNConfig] = None,
+        models: List[RBLNControlNetModel],
         **kwargs,
     ):
-        super().__init__(
-            models,
-            config,
-            preprocessors,
-            rbln_config,
-            **kwargs,
-        )
-
-        if not isinstance(config, PretrainedConfig):
-            config = PretrainedConfig(**config)
-
-        for i in range(len(models)):
-            self.runtimes[i].config = config
-        self.nets = self.runtimes
+        self.nets = models
         self.dtype = torch.float32
 
     @classmethod
@@ -83,7 +59,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
         model_name_or_path: Union[str, Path],
         **kwargs,
     ):
-        return MultiControlNetModel.from_pretrained(
+        return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
     tasktmp = TasksManager.get_model_from_task
     configtmp = AutoConfig.from_pretrained
@@ -101,129 +77,30 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
     def _from_pretrained(
         cls,
         model_id: Union[str, Path],
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        file_name: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
         **kwargs,
-    ) ->
-        if isinstance(model_id, str):
-            model_path = Path(model_id)
-        else:
-            model_path = model_id / "controlnet"
-
-        rbln_files = []
-        rbln_config_filenames = []
+    ) -> RBLNModel:
         idx = 0
-
+        controlnets = []
+        model_path_to_load = model_id
 
-        while
-
-
+        while os.path.isdir(model_path_to_load):
+            controlnet = RBLNControlNetModel.from_pretrained(model_path_to_load, export=False, **kwargs)
+            controlnets.append(controlnet)
+            rbln_config = RBLNConfig.load(model_path_to_load)
             idx += 1
-
-
-        if len(rbln_files) == 0:
-            raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")
-
-        if len(rbln_config_filenames) == 0:
-            raise FileNotFoundError(f"Could not find `rbln_config.json` file in {model_path}")
-
-        models = []
-        for rconf, rfiles in zip(rbln_config_filenames, rbln_files):
-            rbln_config = RBLNConfig.load(str(rconf))
-            models.append(rebel.RBLNCompiledModel(rfiles))
-
-        preprocessors = []
+            model_path_to_load = model_id + f"_{idx}"
 
         return cls(
-            models,
-            config,
-            preprocessors,
+            controlnets,
             rbln_config=rbln_config,
             **kwargs,
         )
 
-    def
-
-        idx = 0
-        real_save_dir_path = save_directory
-        for compiled_model in self.compiled_models:
-            dst_path = Path(real_save_dir_path) / "compiled_model.rbln"
-            if not os.path.exists(real_save_dir_path):
-                os.makedirs(real_save_dir_path)
-            compiled_model.save(dst_path)
-            self.rbln_config.save(real_save_dir_path)
-            idx += 1
-            real_save_dir_path = save_directory + f"_{idx}"
-
-    @classmethod
-    @torch.no_grad()
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        **kwargs,
-    ) -> "RBLNMultiControlNetModel":
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-        img_width = rbln_config_kwargs.pop("rbln_img_width", None)
-        img_height = rbln_config_kwargs.pop("rbln_img_height", None)
-        vae_scale_factor = rbln_config_kwargs.pop("rbln_vae_scale_factor", None)
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
-
-        model: MultiControlNetModel = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-        )
-
-        model_path_to_load = model_id
-        real_save_dir_path = save_dir_path / "controlnet"
-
-        for idx in range(len(model.nets)):
+    def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
+        for idx, model in enumerate(self.nets):
             suffix = "" if idx == 0 else f"_{idx}"
-
-
-                export=True,
-                rbln_batch_size=batch_size,
-                rbln_img_width=img_width,
-                rbln_img_height=img_height,
-                rbln_vae_scale_factor=vae_scale_factor,
-            )
-            controlnet.save_pretrained(real_save_dir_path)
-            real_save_dir_path = save_dir_path / f"controlnet_{idx+1}"
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
-
-    @classmethod
-    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
-    ) -> List[rebel.Runtime]:
-        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
+            real_save_path = save_directory + suffix
+            model.save_pretrained(real_save_path)
 
     def forward(
         self,
@@ -241,7 +118,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
         return_dict: bool = True,
     ):
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
-            output = controlnet(
+            output = controlnet.model[0](
                 sample=sample.contiguous(),
                 timestep=timestep.float(),
                 encoder_hidden_states=encoder_hidden_states,
```
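The rewritten save/load pair above replaces the compiled-file bookkeeping with a simple sibling-directory convention: net 0 lives at the given path, and each additional net at the same path with an `_{idx}` suffix. A sketch of the discovery loop in isolation (helper name hypothetical):

```python
import os

def discover_controlnet_dirs(model_id: str) -> list:
    # Mirrors the while-loop in _from_pretrained: "controlnet", "controlnet_1",
    # "controlnet_2", ... until a directory is missing.
    paths, idx, path = [], 0, model_id
    while os.path.isdir(path):
        paths.append(path)
        idx += 1
        path = f"{model_id}_{idx}"
    return paths

# e.g. discover_controlnet_dirs("exported/controlnet")
# -> ["exported/controlnet", "exported/controlnet_1"]
```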
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py
CHANGED
```diff
@@ -22,18 +22,18 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionControlNetPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel
@@ -64,18 +64,40 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
             - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
+        vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
         text_encoder = kwargs.pop("text_encoder", None)
-        controlnets = kwargs.pop("controlnet", None)
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
 
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "text_encoder": text_encoder,
-            "controlnet": controlnets,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
@@ -85,64 +107,87 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         # compile model, create runtime
-        vae = RBLNAutoencoderKL.from_pretrained(
-
-
-
-            rbln_use_encode=False,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        text_encoder = RBLNCLIPTextModel.from_pretrained(
-            model_id=save_dir_path / "text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        unet = RBLNUNet2DConditionModel.from_pretrained(
-            model_id=save_dir_path / "unet",
-            export=True,
-            rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-            rbln_batch_size=unet_batch_size,
-            rbln_use_encode=False,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        if isinstance(controlnets, (list, tuple)):
-            controlnet = RBLNMultiControlNetModel.from_pretrained(
-                model_id=str(save_dir_path / "controlnet"),
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
                 export=True,
-
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=False,
                 rbln_vae_scale_factor=model.vae_scale_factor,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-
-
-
-                model_id=
+
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
                 export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
+                export=True,
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
                 rbln_batch_size=unet_batch_size,
+                rbln_use_encode=False,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-
+
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
 
         # replace modules
         model.vae = vae
@@ -159,15 +204,18 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         }
         model.register_to_config(**update_dict)
 
-
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
 
+        # use for CI to access each compiled model
         if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
-            model.compiled_models = [
-
-
-
-
-
+            model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
```
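Taken together, the refactor lets callers hand original diffusers/transformers modules to `from_pretrained`, which swaps each for its compiled RBLN counterpart and, when `model_save_dir` is given, persists everything there instead of a temporary directory. An end-to-end usage sketch; the model ids and option values are assumptions, and the pipeline class is assumed to be exported at the package root like the other RBLN pipelines:

```python
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionControlNetPipeline

# Load the original ControlNet; from_pretrained compiles it for the NPU.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

pipe = RBLNStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",       # example base checkpoint (assumption)
    controlnet=controlnet,
    export=True,                            # compile vae/text_encoder/unet/controlnet
    model_save_dir="sd15-controlnet-rbln",  # persist compiled artifacts (new in this diff)
    rbln_img_width=512,                     # rbln_* options seen elsewhere in the diff
    rbln_img_height=512,
)
```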