optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +5 -1
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
- optimum/rbln/diffusers/models/controlnet.py +36 -56
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
- optimum/rbln/modeling_base.py +12 -5
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +2 -2
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -38,7 +38,7 @@ _import_structure = {
|
|
38
38
|
"RBLNXLMRobertaForSequenceClassification",
|
39
39
|
"RBLNRobertaForSequenceClassification",
|
40
40
|
"RBLNRobertaForMaskedLM",
|
41
|
-
"RBLNViTForImageClassification"
|
41
|
+
"RBLNViTForImageClassification",
|
42
42
|
],
|
43
43
|
"modeling_base": [
|
44
44
|
"RBLNBaseModel",
|
@@ -76,6 +76,7 @@ _import_structure = {
|
|
76
76
|
"RBLNQwen2ForCausalLM",
|
77
77
|
"RBLNWav2Vec2ForCTC",
|
78
78
|
"RBLNLlamaForCausalLM",
|
79
|
+
"RBLNT5EncoderModel",
|
79
80
|
"RBLNT5ForConditionalGeneration",
|
80
81
|
"RBLNPhiForCausalLM",
|
81
82
|
"RBLNLlavaNextForConditionalGeneration",
|
@@ -99,6 +100,7 @@ _import_structure = {
|
|
99
100
|
"RBLNStableDiffusionXLControlNetImg2ImgPipeline",
|
100
101
|
],
|
101
102
|
"modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
|
103
|
+
"modeling_diffusers": ["RBLNDiffusionMixin"],
|
102
104
|
}
|
103
105
|
|
104
106
|
if TYPE_CHECKING:
|
@@ -136,6 +138,7 @@ if TYPE_CHECKING:
|
|
136
138
|
RBLNModelForSequenceClassification,
|
137
139
|
)
|
138
140
|
from .modeling_config import RBLNCompileConfig, RBLNConfig
|
141
|
+
from .modeling_diffusers import RBLNDiffusionMixin
|
139
142
|
from .transformers import (
|
140
143
|
BatchTextIteratorStreamer,
|
141
144
|
RBLNAutoModel,
|
@@ -166,6 +169,7 @@ if TYPE_CHECKING:
|
|
166
169
|
RBLNMistralForCausalLM,
|
167
170
|
RBLNPhiForCausalLM,
|
168
171
|
RBLNQwen2ForCausalLM,
|
172
|
+
RBLNT5EncoderModel,
|
169
173
|
RBLNT5ForConditionalGeneration,
|
170
174
|
RBLNWav2Vec2ForCTC,
|
171
175
|
RBLNWhisperForConditionalGeneration,
|
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = '0.1.
|
1
|
+
__version__ = '0.1.13'
|
@@ -22,7 +22,6 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import logging
|
25
|
-
from pathlib import Path
|
26
25
|
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
27
26
|
|
28
27
|
import rebel
|
@@ -30,11 +29,11 @@ import torch # noqa: I001
|
|
30
29
|
from diffusers import AutoencoderKL
|
31
30
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
32
31
|
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
33
|
-
from
|
34
|
-
from transformers import AutoConfig, AutoModel, PretrainedConfig
|
32
|
+
from transformers import PretrainedConfig
|
35
33
|
|
36
34
|
from ...modeling_base import RBLNModel
|
37
35
|
from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
36
|
+
from ...utils.context import override_auto_classes
|
38
37
|
from ...utils.runtime_utils import RBLNPytorchRuntime
|
39
38
|
|
40
39
|
|
@@ -63,8 +62,7 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
63
62
|
def __post_init__(self, **kwargs):
|
64
63
|
super().__post_init__(**kwargs)
|
65
64
|
|
66
|
-
|
67
|
-
if self.rbln_use_encode:
|
65
|
+
if self.rbln_config.model_cfg.get("img2img_pipeline"):
|
68
66
|
self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
|
69
67
|
self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
|
70
68
|
else:
|
@@ -91,38 +89,15 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
91
89
|
|
92
90
|
return dec_compiled_model
|
93
91
|
|
94
|
-
if rbln_config.model_cfg.get("
|
92
|
+
if rbln_config.model_cfg.get("img2img_pipeline"):
|
95
93
|
return compile_img2img()
|
96
94
|
else:
|
97
95
|
return compile_text2img()
|
98
96
|
|
99
97
|
@classmethod
|
100
98
|
def from_pretrained(cls, *args, **kwargs):
|
101
|
-
|
102
|
-
|
103
|
-
model_name_or_path: Union[str, Path],
|
104
|
-
**kwargs,
|
105
|
-
):
|
106
|
-
return AutoencoderKL.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
|
107
|
-
|
108
|
-
tasktmp = TasksManager.get_model_from_task
|
109
|
-
configtmp = AutoConfig.from_pretrained
|
110
|
-
modeltmp = AutoModel.from_pretrained
|
111
|
-
TasksManager.get_model_from_task = get_model_from_task
|
112
|
-
|
113
|
-
if kwargs.get("export", None):
|
114
|
-
# This is an ad-hoc to workaround save null values of the config.
|
115
|
-
# if export, pure optimum(not optimum-rbln) loads config using AutoConfig
|
116
|
-
# and diffusers model do not support loading by AutoConfig.
|
117
|
-
AutoConfig.from_pretrained = lambda *args, **kwargs: None
|
118
|
-
else:
|
119
|
-
AutoConfig.from_pretrained = AutoencoderKL.load_config
|
120
|
-
|
121
|
-
AutoModel.from_pretrained = AutoencoderKL.from_pretrained
|
122
|
-
rt = super().from_pretrained(*args, **kwargs)
|
123
|
-
AutoConfig.from_pretrained = configtmp
|
124
|
-
AutoModel.from_pretrained = modeltmp
|
125
|
-
TasksManager.get_model_from_task = tasktmp
|
99
|
+
with override_auto_classes(config_func=AutoencoderKL.load_config, model_func=AutoencoderKL.from_pretrained):
|
100
|
+
rt = super().from_pretrained(*args, **kwargs)
|
126
101
|
return rt
|
127
102
|
|
128
103
|
@classmethod
|
@@ -132,34 +107,39 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
132
107
|
model_config: "PretrainedConfig",
|
133
108
|
rbln_kwargs: Dict[str, Any] = {},
|
134
109
|
) -> RBLNConfig:
|
135
|
-
|
136
|
-
|
137
|
-
rbln_img_height = rbln_kwargs.get("img_height", None)
|
138
|
-
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
139
|
-
rbln_use_encode = rbln_kwargs.get("use_encode", None)
|
140
|
-
rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
|
110
|
+
rbln_batch_size = rbln_kwargs.get("batch_size")
|
111
|
+
sample_size = rbln_kwargs.get("sample_size")
|
141
112
|
|
142
113
|
if rbln_batch_size is None:
|
143
114
|
rbln_batch_size = 1
|
144
115
|
|
145
|
-
|
116
|
+
if sample_size is None:
|
117
|
+
sample_size = model_config.sample_size
|
118
|
+
|
119
|
+
if isinstance(sample_size, int):
|
120
|
+
sample_size = (sample_size, sample_size)
|
121
|
+
|
122
|
+
if hasattr(model_config, "block_out_channels"):
|
123
|
+
vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
|
124
|
+
else:
|
125
|
+
# vae image processor default value 8 (int)
|
126
|
+
vae_scale_factor = 8
|
146
127
|
|
147
|
-
|
148
|
-
|
149
|
-
model_cfg["img_height"] = rbln_img_height
|
128
|
+
dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
|
129
|
+
enc_shape = (sample_size[0], sample_size[1])
|
150
130
|
|
131
|
+
if rbln_kwargs["img2img_pipeline"]:
|
151
132
|
vae_enc_input_info = [
|
152
|
-
(
|
133
|
+
(
|
134
|
+
"x",
|
135
|
+
[rbln_batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
|
136
|
+
"float32",
|
137
|
+
)
|
153
138
|
]
|
154
139
|
vae_dec_input_info = [
|
155
140
|
(
|
156
141
|
"z",
|
157
|
-
[
|
158
|
-
rbln_batch_size,
|
159
|
-
model_config.latent_channels,
|
160
|
-
rbln_img_height // rbln_vae_scale_factor,
|
161
|
-
rbln_img_width // rbln_vae_scale_factor,
|
162
|
-
],
|
142
|
+
[rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
|
163
143
|
"float32",
|
164
144
|
)
|
165
145
|
]
|
@@ -173,33 +153,22 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
173
153
|
compile_cfgs=compile_cfgs,
|
174
154
|
rbln_kwargs=rbln_kwargs,
|
175
155
|
)
|
176
|
-
rbln_config.model_cfg.update(model_cfg)
|
177
156
|
return rbln_config
|
178
157
|
|
179
|
-
if rbln_unet_sample_size is None:
|
180
|
-
rbln_unet_sample_size = 64
|
181
|
-
|
182
|
-
model_cfg["unet_sample_size"] = rbln_unet_sample_size
|
183
158
|
vae_config = RBLNCompileConfig(
|
184
159
|
input_info=[
|
185
160
|
(
|
186
161
|
"z",
|
187
|
-
[
|
188
|
-
rbln_batch_size,
|
189
|
-
model_config.latent_channels,
|
190
|
-
rbln_unet_sample_size,
|
191
|
-
rbln_unet_sample_size,
|
192
|
-
],
|
162
|
+
[rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
|
193
163
|
"float32",
|
194
164
|
)
|
195
|
-
]
|
165
|
+
]
|
196
166
|
)
|
197
167
|
rbln_config = RBLNConfig(
|
198
168
|
rbln_cls=cls.__name__,
|
199
169
|
compile_cfgs=[vae_config],
|
200
170
|
rbln_kwargs=rbln_kwargs,
|
201
171
|
)
|
202
|
-
rbln_config.model_cfg.update(model_cfg)
|
203
172
|
return rbln_config
|
204
173
|
|
205
174
|
@classmethod
|
@@ -22,16 +22,15 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import logging
|
25
|
-
from pathlib import Path
|
26
25
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
27
26
|
|
28
27
|
import torch
|
29
28
|
from diffusers import ControlNetModel
|
30
|
-
from
|
31
|
-
from transformers import AutoConfig, AutoModel, PretrainedConfig
|
29
|
+
from transformers import PretrainedConfig
|
32
30
|
|
33
31
|
from ...modeling_base import RBLNModel
|
34
32
|
from ...modeling_config import RBLNCompileConfig, RBLNConfig
|
33
|
+
from ...utils.context import override_auto_classes
|
35
34
|
|
36
35
|
|
37
36
|
if TYPE_CHECKING:
|
@@ -113,23 +112,11 @@ class RBLNControlNetModel(RBLNModel):
|
|
113
112
|
|
114
113
|
@classmethod
|
115
114
|
def from_pretrained(cls, *args, **kwargs):
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
**kwargs,
|
115
|
+
with override_auto_classes(
|
116
|
+
config_func=ControlNetModel.load_config,
|
117
|
+
model_func=ControlNetModel.from_pretrained,
|
120
118
|
):
|
121
|
-
|
122
|
-
|
123
|
-
tasktmp = TasksManager.get_model_from_task
|
124
|
-
configtmp = AutoConfig.from_pretrained
|
125
|
-
modeltmp = AutoModel.from_pretrained
|
126
|
-
TasksManager.get_model_from_task = get_model_from_task
|
127
|
-
AutoConfig.from_pretrained = ControlNetModel.load_config
|
128
|
-
AutoModel.from_pretrained = ControlNetModel.from_pretrained
|
129
|
-
rt = super().from_pretrained(*args, **kwargs)
|
130
|
-
AutoConfig.from_pretrained = configtmp
|
131
|
-
AutoModel.from_pretrained = modeltmp
|
132
|
-
TasksManager.get_model_from_task = tasktmp
|
119
|
+
rt = super().from_pretrained(*args, **kwargs)
|
133
120
|
return rt
|
134
121
|
|
135
122
|
@classmethod
|
@@ -151,33 +138,35 @@ class RBLNControlNetModel(RBLNModel):
|
|
151
138
|
model_config: "PretrainedConfig",
|
152
139
|
rbln_kwargs: Dict[str, Any] = {},
|
153
140
|
) -> RBLNConfig:
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
rbln_img_height = rbln_kwargs.get("img_height", None)
|
159
|
-
rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
|
141
|
+
batch_size = rbln_kwargs.get("batch_size")
|
142
|
+
max_seq_len = rbln_kwargs.get("max_seq_len")
|
143
|
+
unet_sample_size = rbln_kwargs.get("unet_sample_size")
|
144
|
+
vae_sample_size = rbln_kwargs.get("vae_sample_size")
|
160
145
|
|
161
|
-
if
|
162
|
-
|
146
|
+
if batch_size is None:
|
147
|
+
batch_size = 1
|
163
148
|
|
164
|
-
if
|
165
|
-
|
149
|
+
if unet_sample_size is None:
|
150
|
+
raise ValueError(
|
151
|
+
"`rbln_unet_sample_size` (latent height, widht) must be specified (ex. unet's sample_size)"
|
152
|
+
)
|
166
153
|
|
167
|
-
if
|
168
|
-
raise ValueError(
|
154
|
+
if vae_sample_size is None:
|
155
|
+
raise ValueError(
|
156
|
+
"`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
|
157
|
+
)
|
169
158
|
|
170
|
-
|
171
|
-
|
159
|
+
if max_seq_len is None:
|
160
|
+
raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
|
172
161
|
|
173
162
|
input_info = [
|
174
163
|
(
|
175
164
|
"sample",
|
176
165
|
[
|
177
|
-
|
166
|
+
batch_size,
|
178
167
|
model_config.in_channels,
|
179
|
-
|
180
|
-
|
168
|
+
unet_sample_size[0],
|
169
|
+
unet_sample_size[1],
|
181
170
|
],
|
182
171
|
"float32",
|
183
172
|
),
|
@@ -189,23 +178,24 @@ class RBLNControlNetModel(RBLNModel):
|
|
189
178
|
input_info.append(
|
190
179
|
(
|
191
180
|
"encoder_hidden_states",
|
192
|
-
[
|
193
|
-
rbln_batch_size,
|
194
|
-
rbln_max_seq_len,
|
195
|
-
model_config.cross_attention_dim,
|
196
|
-
],
|
181
|
+
[batch_size, max_seq_len, model_config.cross_attention_dim],
|
197
182
|
"float32",
|
198
183
|
)
|
199
184
|
)
|
200
185
|
|
201
|
-
input_info.append(
|
186
|
+
input_info.append(
|
187
|
+
(
|
188
|
+
"controlnet_cond",
|
189
|
+
[batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
|
190
|
+
"float32",
|
191
|
+
)
|
192
|
+
)
|
202
193
|
input_info.append(("conditioning_scale", [], "float32"))
|
203
194
|
|
204
195
|
if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
|
205
|
-
|
206
|
-
|
207
|
-
input_info.append(("
|
208
|
-
input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
|
196
|
+
rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
|
197
|
+
input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
|
198
|
+
input_info.append(("time_ids", [batch_size, 6], "float32"))
|
209
199
|
|
210
200
|
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
211
201
|
|
@@ -215,16 +205,6 @@ class RBLNControlNetModel(RBLNModel):
|
|
215
205
|
rbln_kwargs=rbln_kwargs,
|
216
206
|
)
|
217
207
|
|
218
|
-
rbln_config.model_cfg.update(
|
219
|
-
{
|
220
|
-
"max_seq_len": rbln_max_seq_len,
|
221
|
-
"batch_size": rbln_batch_size,
|
222
|
-
"img_width": rbln_img_width,
|
223
|
-
"img_height": rbln_img_height,
|
224
|
-
"vae_scale_factor": rbln_vae_scale_factor,
|
225
|
-
}
|
226
|
-
)
|
227
|
-
|
228
208
|
return rbln_config
|
229
209
|
|
230
210
|
def forward(
|
@@ -23,16 +23,15 @@
|
|
23
23
|
|
24
24
|
import logging
|
25
25
|
from dataclasses import dataclass
|
26
|
-
from pathlib import Path
|
27
26
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
28
27
|
|
29
28
|
import torch
|
30
29
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
31
|
-
from
|
32
|
-
from transformers import AutoConfig, AutoModel, PretrainedConfig
|
30
|
+
from transformers import PretrainedConfig
|
33
31
|
|
34
32
|
from ...modeling_base import RBLNModel
|
35
33
|
from ...modeling_config import RBLNCompileConfig, RBLNConfig
|
34
|
+
from ...utils.context import override_auto_classes
|
36
35
|
|
37
36
|
|
38
37
|
if TYPE_CHECKING:
|
@@ -143,29 +142,11 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
143
142
|
|
144
143
|
@classmethod
|
145
144
|
def from_pretrained(cls, *args, **kwargs):
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
**kwargs,
|
145
|
+
with override_auto_classes(
|
146
|
+
config_func=UNet2DConditionModel.load_config,
|
147
|
+
model_func=UNet2DConditionModel.from_pretrained,
|
150
148
|
):
|
151
|
-
|
152
|
-
|
153
|
-
tasktmp = TasksManager.get_model_from_task
|
154
|
-
configtmp = AutoConfig.from_pretrained
|
155
|
-
modeltmp = AutoModel.from_pretrained
|
156
|
-
TasksManager.get_model_from_task = get_model_from_task
|
157
|
-
if kwargs.get("export", None):
|
158
|
-
# This is an ad-hoc to workaround save null values of the config.
|
159
|
-
# if export, pure optimum(not optimum-rbln) loads config using AutoConfig
|
160
|
-
# and diffusers model do not support loading by AutoConfig.
|
161
|
-
AutoConfig.from_pretrained = lambda *args, **kwargs: None
|
162
|
-
else:
|
163
|
-
AutoConfig.from_pretrained = UNet2DConditionModel.load_config
|
164
|
-
AutoModel.from_pretrained = UNet2DConditionModel.from_pretrained
|
165
|
-
rt = super().from_pretrained(*args, **kwargs)
|
166
|
-
AutoConfig.from_pretrained = configtmp
|
167
|
-
AutoModel.from_pretrained = modeltmp
|
168
|
-
TasksManager.get_model_from_task = tasktmp
|
149
|
+
rt = super().from_pretrained(*args, **kwargs)
|
169
150
|
return rt
|
170
151
|
|
171
152
|
@classmethod
|
@@ -182,137 +163,68 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
182
163
|
model_config: "PretrainedConfig",
|
183
164
|
rbln_kwargs: Dict[str, Any] = {},
|
184
165
|
) -> RBLNConfig:
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
if
|
198
|
-
|
199
|
-
|
200
|
-
if
|
201
|
-
|
202
|
-
raise ValueError(
|
203
|
-
"rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
|
204
|
-
)
|
205
|
-
input_width = rbln_img_width // rbln_vae_scale_factor
|
206
|
-
input_height = rbln_img_height // rbln_vae_scale_factor
|
207
|
-
else:
|
208
|
-
input_width, input_height = model_config.sample_size, model_config.sample_size
|
166
|
+
batch_size = rbln_kwargs.get("batch_size")
|
167
|
+
max_seq_len = rbln_kwargs.get("max_seq_len")
|
168
|
+
sample_size = rbln_kwargs.get("sample_size")
|
169
|
+
is_controlnet = rbln_kwargs.get("is_controlnet")
|
170
|
+
rbln_in_features = None
|
171
|
+
|
172
|
+
if batch_size is None:
|
173
|
+
batch_size = 1
|
174
|
+
|
175
|
+
if sample_size is None:
|
176
|
+
sample_size = model_config.sample_size
|
177
|
+
|
178
|
+
if isinstance(sample_size, int):
|
179
|
+
sample_size = (sample_size, sample_size)
|
180
|
+
|
181
|
+
if max_seq_len is None:
|
182
|
+
raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
|
209
183
|
|
210
184
|
input_info = [
|
211
|
-
(
|
212
|
-
"sample",
|
213
|
-
[
|
214
|
-
rbln_batch_size,
|
215
|
-
model_config.in_channels,
|
216
|
-
input_height,
|
217
|
-
input_width,
|
218
|
-
],
|
219
|
-
"float32",
|
220
|
-
),
|
185
|
+
("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
|
221
186
|
("timestep", [], "float32"),
|
222
|
-
(
|
223
|
-
"encoder_hidden_states",
|
224
|
-
[
|
225
|
-
rbln_batch_size,
|
226
|
-
rbln_max_seq_len,
|
227
|
-
model_config.cross_attention_dim,
|
228
|
-
],
|
229
|
-
"float32",
|
230
|
-
),
|
187
|
+
("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
|
231
188
|
]
|
232
189
|
|
233
|
-
if
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
)
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
]
|
262
|
-
)
|
263
|
-
if len(model_config.block_out_channels) > 2:
|
264
|
-
input_info.append(
|
265
|
-
(
|
266
|
-
f"down_block_additional_residuals_{6}",
|
267
|
-
[rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
|
268
|
-
"float32",
|
269
|
-
)
|
270
|
-
)
|
271
|
-
input_info.extend(
|
272
|
-
[
|
273
|
-
(
|
274
|
-
f"down_block_additional_residuals_{i}",
|
275
|
-
[rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
|
276
|
-
"float32",
|
277
|
-
)
|
278
|
-
for i in range(7, 9)
|
279
|
-
]
|
280
|
-
)
|
281
|
-
if len(model_config.block_out_channels) > 3:
|
282
|
-
input_info.extend(
|
283
|
-
[
|
284
|
-
(
|
285
|
-
f"down_block_additional_residuals_{i}",
|
286
|
-
[rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
|
287
|
-
"float32",
|
288
|
-
)
|
289
|
-
for i in range(9, 12)
|
290
|
-
]
|
291
|
-
)
|
292
|
-
input_info.append(
|
293
|
-
(
|
294
|
-
"mid_block_additional_residual",
|
295
|
-
[
|
296
|
-
rbln_batch_size,
|
297
|
-
model_config.block_out_channels[-1],
|
298
|
-
input_height // 2 ** (len(model_config.block_out_channels) - 1),
|
299
|
-
input_width // 2 ** (len(model_config.block_out_channels) - 1),
|
300
|
-
],
|
301
|
-
"float32",
|
302
|
-
)
|
303
|
-
)
|
190
|
+
if is_controlnet:
|
191
|
+
# down block addtional residuals
|
192
|
+
first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
|
193
|
+
height, width = sample_size[0], sample_size[1]
|
194
|
+
input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
|
195
|
+
name_idx = 1
|
196
|
+
for idx, _ in enumerate(model_config.down_block_types):
|
197
|
+
shape = [batch_size, model_config.block_out_channels[idx], height, width]
|
198
|
+
for _ in range(model_config.layers_per_block):
|
199
|
+
input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
|
200
|
+
name_idx += 1
|
201
|
+
if idx != len(model_config.down_block_types) - 1:
|
202
|
+
height = height // 2
|
203
|
+
width = width // 2
|
204
|
+
shape = [batch_size, model_config.block_out_channels[idx], height, width]
|
205
|
+
input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
|
206
|
+
name_idx += 1
|
207
|
+
|
208
|
+
# mid block addtional residual
|
209
|
+
num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
|
210
|
+
out_channels = model_config.block_out_channels[-1]
|
211
|
+
shape = [
|
212
|
+
batch_size,
|
213
|
+
out_channels,
|
214
|
+
sample_size[0] // 2**num_cross_attn_blocks,
|
215
|
+
sample_size[1] // 2**num_cross_attn_blocks,
|
216
|
+
]
|
217
|
+
input_info.append(("mid_block_additional_residual", shape, "float32"))
|
304
218
|
|
305
219
|
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
306
220
|
|
307
221
|
if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
|
308
|
-
|
309
|
-
|
310
|
-
if rbln_in_features is None:
|
311
|
-
rbln_in_features = model_config.projection_class_embeddings_input_dim
|
222
|
+
rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
|
223
|
+
rbln_in_features = model_config.projection_class_embeddings_input_dim
|
312
224
|
rbln_compile_config.input_info.append(
|
313
|
-
("text_embeds", [
|
225
|
+
("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32")
|
314
226
|
)
|
315
|
-
rbln_compile_config.input_info.append(("time_ids", [
|
227
|
+
rbln_compile_config.input_info.append(("time_ids", [batch_size, 6], "float32"))
|
316
228
|
|
317
229
|
rbln_config = RBLNConfig(
|
318
230
|
rbln_cls=cls.__name__,
|
@@ -320,14 +232,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
320
232
|
rbln_kwargs=rbln_kwargs,
|
321
233
|
)
|
322
234
|
|
323
|
-
rbln_config.model_cfg.update(
|
324
|
-
{
|
325
|
-
"max_seq_len": rbln_max_seq_len,
|
326
|
-
"batch_size": rbln_batch_size,
|
327
|
-
"use_encode": rbln_use_encode,
|
328
|
-
}
|
329
|
-
)
|
330
|
-
|
331
235
|
if rbln_in_features is not None:
|
332
236
|
rbln_config.model_cfg["in_features"] = rbln_in_features
|
333
237
|
|