optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import _LazyModule
 
+from .__version__ import __version__
 from .utils import check_version_compats
 
 
@@ -54,13 +55,30 @@ _import_structure = {
     ],
     "transformers": [
         "BatchTextIteratorStreamer",
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
         "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
@@ -80,7 +98,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
-    "modeling_config": ["
+    "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
 }
 
 if TYPE_CHECKING:
@@ -117,18 +135,35 @@ if TYPE_CHECKING:
         RBLNModelForQuestionAnswering,
         RBLNModelForSequenceClassification,
     )
-    from .modeling_config import
+    from .modeling_config import RBLNCompileConfig, RBLNConfig
     from .modeling_seq2seq import RBLNModelForSeq2SeqLM
     from .transformers import (
         BatchTextIteratorStreamer,
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
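
Note on the new exports above: the RBLNAutoModel* classes mirror the transformers AutoModel naming. A minimal sketch of the intended call pattern, assuming the usual optimum export API; the checkpoint id below is a hypothetical placeholder, not taken from this diff:

from optimum.rbln import RBLNAutoModelForCausalLM

# export=True compiles the PyTorch checkpoint for the RBLN runtime;
# omit it (or pass export=False) to load an already-compiled model.
model = RBLNAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # hypothetical model id
    export=True,
)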
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.1.9'
+__version__ = '0.1.11'
optimum/rbln/diffusers/models/autoencoder_kl.py
CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List,
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 import rebel
 import torch # noqa: I001
@@ -34,7 +34,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
@@ -63,10 +63,9 @@ class RBLNAutoencoderKL(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-
-
-        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
+        super().__post_init__(**kwargs)
 
+        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
         if self.rbln_use_encode:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
@@ -81,20 +80,20 @@ class RBLNAutoencoderKL(RBLNModel):
             encoder_model.eval()
             decoder_model.eval()
 
-            enc_compiled_model = cls.compile(encoder_model,
-            dec_compiled_model = cls.compile(decoder_model,
+            enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
 
-            return enc_compiled_model, dec_compiled_model
+            return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
             decoder_model.eval()
 
-            dec_compiled_model = cls.compile(decoder_model,
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
 
             return dec_compiled_model
 
-        if rbln_config.
+        if rbln_config.model_cfg.get("use_encode", False):
             return compile_img2img()
         else:
             return compile_text2img()
@@ -133,23 +132,23 @@ class RBLNAutoencoderKL(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-
-        meta["rbln_batch_size"] = rbln_batch_size
+        model_cfg = {}
 
         if rbln_use_encode:
-
-
+            model_cfg["img_width"] = rbln_img_width
+            model_cfg["img_height"] = rbln_img_height
 
             vae_enc_input_info = [
                 ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
@@ -167,20 +166,23 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ]
 
-
-
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
+            dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
 
-
-
-
+            compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
+            rbln_config = RBLNConfig(
+                rbln_cls=cls.__name__,
+                compile_cfgs=compile_cfgs,
+                rbln_kwargs=rbln_kwargs,
             )
+            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
         if rbln_unet_sample_size is None:
             rbln_unet_sample_size = 64
 
-
-        vae_config =
+        model_cfg["unet_sample_size"] = rbln_unet_sample_size
+        vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
@@ -194,7 +196,12 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ],
         )
-        rbln_config = RBLNConfig
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[vae_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
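
Note on the hunks above: 0.1.11 moves per-graph settings out of the old runtime-config/meta pair and into RBLNCompileConfig objects, with model-level values kept in RBLNConfig.model_cfg. A minimal sketch of the new construction pattern, matching the constructor arguments shown in this diff; the latent shape is an illustrative assumption:

from optimum.rbln import RBLNCompileConfig, RBLNConfig

# input_info entries are (name, shape, dtype) tuples, as in the hunks above
input_info = [("z", [1, 4, 64, 64], "float32")]  # assumed latent shape
compile_cfg = RBLNCompileConfig(input_info=input_info)

rbln_config = RBLNConfig(
    rbln_cls="RBLNAutoencoderKL",
    compile_cfgs=[compile_cfg],
    rbln_kwargs={},
)
# model-level settings replace the old meta["rbln_*"] keys
rbln_config.model_cfg.update({"use_encode": False, "batch_size": 1})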
optimum/rbln/diffusers/models/controlnet.py
CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
@@ -31,7 +31,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -109,21 +109,21 @@ class RBLNControlNetModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-
+        super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
-            item[0] == "encoder_hidden_states" for item in self.rbln_config[
+            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
+        if "subfolder" in kwargs:
+            del kwargs["subfolder"]
+
         def get_model_from_task(
             task: str,
             model_name_or_path: Union[str, Path],
             **kwargs,
         ):
-            if "subfolder" in kwargs:
-                del kwargs["subfolder"]
-
             return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
         tasktmp = TasksManager.get_model_from_task
@@ -155,14 +155,14 @@ class RBLNControlNetModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -170,28 +170,29 @@ class RBLNControlNetModel(RBLNModel):
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
 
+        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+            raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
+
         input_width = rbln_img_width // rbln_vae_scale_factor
         input_height = rbln_img_height // rbln_vae_scale_factor
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            batch_size=rbln_batch_size,
-        )
+        input_info = [
+            (
+                "sample",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    input_height,
+                    input_width,
+                ],
+                "float32",
+            ),
+            ("timestep", [], "float32"),
+        ]
+
         use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
         if use_encoder_hidden_states:
-
+            input_info.append(
                 (
                     "encoder_hidden_states",
                     [
@@ -202,19 +203,34 @@ class RBLNControlNetModel(RBLNModel):
                     "float32",
                 )
             )
-
-
-        )
-
+
+        input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
+        input_info.append(("conditioning_scale", [], "float32"))
+
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
-
-
-
-
+            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "img_width": rbln_img_width,
+                "img_height": rbln_img_height,
+                "vae_scale_factor": rbln_vae_scale_factor,
+            }
+        )
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
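
Note: _get_rbln_config now receives a single rbln_kwargs dict keyed without the rbln_ prefix (img_width, max_seq_len, ...), while callers elsewhere in this diff still pass rbln_-prefixed keyword arguments that RBLNBaseModel.resolve_rbln_config collects. A hedged sketch of the prefix stripping this implies; the helper name is hypothetical and only illustrates the observed key mapping:

def collect_rbln_kwargs(kwargs: dict) -> dict:
    # rbln_img_width=512 arrives as rbln_kwargs["img_width"] == 512
    rbln_kwargs = {}
    for key in list(kwargs):
        if key.startswith("rbln_"):
            rbln_kwargs[key[len("rbln_"):]] = kwargs.pop(key)
    return rbln_kwargs

assert collect_rbln_kwargs({"rbln_img_width": 512}) == {"img_width": 512}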
optimum/rbln/diffusers/models/unet_2d_condition.py
CHANGED
@@ -32,7 +32,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -130,8 +130,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-
-        self.in_features = self.rbln_config.
+        super().__post_init__(**kwargs)
+        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
         if self.in_features is not None:
 
             @dataclass
@@ -183,31 +183,31 @@ class RBLNUNet2DConditionModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_in_features: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
-        rbln_is_controlnet: Optional[bool] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_in_features = rbln_kwargs.get("in_features", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
 
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
-
-
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
 
         if rbln_use_encode:
-
+            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+                raise ValueError(
+                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
+                )
             input_width = rbln_img_width // rbln_vae_scale_factor
             input_height = rbln_img_height // rbln_vae_scale_factor
         else:
-            # FIXME :: model_config.sample_size can be tuple or list
             input_width, input_height = model_config.sample_size, model_config.sample_size
 
         input_info = [
@@ -232,6 +232,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 "float32",
             ),
         ]
+
         if rbln_is_controlnet:
             if len(model_config.block_out_channels) > 0:
                 input_info.extend(
@@ -304,24 +305,35 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 )
             )
 
-
-            input_info=input_info,
-            batch_size=rbln_batch_size,
-        )
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            # In case of sdxl
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
             if rbln_in_features is None:
                 rbln_in_features = model_config.projection_class_embeddings_input_dim
-
-            rbln_runtime_config.input_info.append(
+            rbln_compile_config.input_info.append(
                 ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
             )
-
+            rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "use_encode": rbln_use_encode,
+            }
+        )
+
+        if rbln_in_features is not None:
+            rbln_config.model_cfg["in_features"] = rbln_in_features
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
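
Note: the UNet hunks show the same two-sided contract as the VAE and ControlNet changes: _get_rbln_config writes settings into rbln_config.model_cfg at compile time, and __post_init__ reads them back when the compiled artifact is loaded. A minimal sketch of that round trip with a stand-in object, not code from the wheel; the in_features value is an illustrative SDXL-like number:

from types import SimpleNamespace

# stand-in for the persisted RBLNConfig; only model_cfg matters here
rbln_config = SimpleNamespace(model_cfg={})

# compile time: _get_rbln_config records model-level settings
rbln_config.model_cfg.update({"in_features": 2816, "batch_size": 2})

# load time: __post_init__ reads them back, tolerating absent keys
in_features = rbln_config.model_cfg.get("in_features", None)
assert in_features == 2816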
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py
CHANGED
@@ -37,6 +37,7 @@ from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel
+from ....utils.runtime_utils import ContextRblnConfig
 from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
@@ -69,8 +70,13 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         text_encoder = kwargs.pop("text_encoder", None)
         controlnet = kwargs.pop("controlnet", None)
         model_save_dir = kwargs.pop("model_save_dir", None)
+        rbln_config = kwargs.pop("rbln_config", None)
+        rbln_kwargs, _ = RBLNBaseModel.resolve_rbln_config(rbln_config, kwargs)
 
-
+        device = rbln_kwargs.get("device", None)
+        device_map = rbln_kwargs.get("device_map", None)
+        create_runtimes = rbln_kwargs.get("create_runtimes", None)
+        optimize_host_memory = rbln_kwargs.get("optimize_host_memory", None)
 
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
@@ -98,13 +104,19 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
             }
         )
 
-
+        with ContextRblnConfig(
+            device=device,
+            device_map=device_map,
+            create_runtimes=create_runtimes,
+            optimze_host_mem=optimize_host_memory,
+        ):
+            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
         do_classifier_free_guidance = (
-
+            rbln_kwargs.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
         # compile model, create runtime
@@ -117,8 +129,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
             rbln_unet_sample_size=model.unet.config.sample_size,
             rbln_use_encode=False,
             rbln_vae_scale_factor=model.vae_scale_factor,
-            **
-            **rbln_constructor_kwargs,
+            rbln_config={**rbln_kwargs},
         )
 
         if not isinstance(text_encoder, RBLNCLIPTextModel):
@@ -127,11 +138,10 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                 subfolder="text_encoder",
                 export=True,
                 model_save_dir=model_save_dir,
-                **
-                **rbln_constructor_kwargs,
+                rbln_config={**rbln_kwargs},
             )
 
-        batch_size =
+        batch_size = rbln_kwargs.pop("batch_size", 1)
         unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
 
         if not isinstance(unet, RBLNUNet2DConditionModel):
@@ -145,8 +155,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                 rbln_use_encode=False,
                 rbln_vae_scale_factor=model.vae_scale_factor,
                 rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                **
-                **rbln_constructor_kwargs,
+                rbln_config={**rbln_kwargs},
            )
 
        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
@@ -162,8 +171,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                         model_save_dir=model_save_dir,
                         rbln_batch_size=unet_batch_size,
                         rbln_vae_scale_factor=model.vae_scale_factor,
-                        **
-                        **rbln_constructor_kwargs,
+                        rbln_config={**rbln_kwargs},
                     )
                 )
                 controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
@@ -176,8 +184,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                     model_save_dir=model_save_dir,
                     rbln_batch_size=unet_batch_size,
                     rbln_vae_scale_factor=model.vae_scale_factor,
-                    **
-                    **rbln_constructor_kwargs,
+                    rbln_config={**rbln_kwargs},
                 )
                 controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
@@ -209,7 +216,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         model.save_config(model_save_dir)
 
         # use for CI to access each compiled model
-        if
+        if optimize_host_memory is False:
             model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
             if isinstance(controlnet, RBLNMultiControlNetModel):
                 for c_model in controlnet.nets: