optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +47 -9
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
- optimum/rbln/diffusers/models/controlnet.py +53 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
- optimum/rbln/modeling_alias.py +6 -11
- optimum/rbln/modeling_base.py +467 -261
- optimum/rbln/modeling_config.py +199 -73
- optimum/rbln/transformers/__init__.py +43 -1
- optimum/rbln/transformers/models/__init__.py +23 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
- optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +50 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +43 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
- optimum_rbln-0.1.12.dist-info/RECORD +103 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import _LazyModule
 
+from .__version__ import __version__
 from .utils import check_version_compats
 
 
@@ -34,11 +35,10 @@ _import_structure = {
         "RBLNBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
-        "RBLNT5ForConditionalGeneration",
-        "RBLNBartForConditionalGeneration",
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNRobertaForSequenceClassification",
         "RBLNRobertaForMaskedLM",
+        "RBLNViTForImageClassification"
     ],
     "modeling_base": [
         "RBLNBaseModel",
@@ -49,18 +49,36 @@ _import_structure = {
         "RBLNModelForSequenceClassification",
         "RBLNModelForMaskedLM",
     ],
-    "modeling_seq2seq": [
-        "RBLNModelForSeq2SeqLM",
-    ],
     "transformers": [
         "BatchTextIteratorStreamer",
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartForConditionalGeneration",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
+        "RBLNExaoneForCausalLM",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
+        "RBLNQwen2ForCausalLM",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNT5ForConditionalGeneration",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
         "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
@@ -80,7 +98,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
-    "modeling_config": ["
+    "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
 }
 
 if TYPE_CHECKING:
@@ -100,12 +118,12 @@ if TYPE_CHECKING:
     )
     from .modeling_alias import (
         RBLNASTForAudioClassification,
-        RBLNBartForConditionalGeneration,
         RBLNBertForQuestionAnswering,
         RBLNResNetForImageClassification,
         RBLNRobertaForMaskedLM,
         RBLNRobertaForSequenceClassification,
         RBLNT5ForConditionalGeneration,
+        RBLNViTForImageClassification,
         RBLNXLMRobertaForSequenceClassification,
     )
     from .modeling_base import (
@@ -117,18 +135,38 @@ if TYPE_CHECKING:
         RBLNModelForQuestionAnswering,
         RBLNModelForSequenceClassification,
     )
-    from .modeling_config import
-    from .modeling_seq2seq import RBLNModelForSeq2SeqLM
+    from .modeling_config import RBLNCompileConfig, RBLNConfig
     from .transformers import (
         BatchTextIteratorStreamer,
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartForConditionalGeneration,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
+        RBLNExaoneForCausalLM,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
+        RBLNQwen2ForCausalLM,
+        RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
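The `transformers` block of `_import_structure` above now exposes the `RBLNAutoModel*` family added in 0.1.12 alongside the model-specific classes. A minimal usage sketch follows, assuming these auto classes follow the usual optimum `from_pretrained` convention; the model id and keyword arguments are illustrative, not taken from this diff:

    # Hypothetical usage of an auto class exported by optimum/rbln/__init__.py in 0.1.12.
    # Assumes the optimum-style from_pretrained(..., export=True) interface; the model id
    # and rbln_* arguments below are placeholders.
    from optimum.rbln import RBLNAutoModelForCausalLM

    model = RBLNAutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # any architecture with an RBLN causal-LM implementation
        export=True,                 # compile from the original checkpoint
        rbln_batch_size=1,
    )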
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.1.9'
+__version__ = '0.1.12'
optimum/rbln/diffusers/models/autoencoder_kl.py
CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List,
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 import rebel
 import torch # noqa: I001
@@ -34,7 +34,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
@@ -58,15 +58,12 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
 
 
 class RBLNAutoencoderKL(RBLNModel):
-    model_type = "rbln_model"
     config_name = "config.json"
-    auto_model_class = AutoModel # feature extraction
 
     def __post_init__(self, **kwargs):
-
-
-        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
+        super().__post_init__(**kwargs)
 
+        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
         if self.rbln_use_encode:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
@@ -81,20 +78,20 @@ class RBLNAutoencoderKL(RBLNModel):
             encoder_model.eval()
             decoder_model.eval()
 
-            enc_compiled_model = cls.compile(encoder_model,
-            dec_compiled_model = cls.compile(decoder_model,
+            enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
 
-            return enc_compiled_model, dec_compiled_model
+            return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
             decoder_model.eval()
 
-            dec_compiled_model = cls.compile(decoder_model,
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
 
             return dec_compiled_model
 
-        if rbln_config.
+        if rbln_config.model_cfg.get("use_encode", False):
             return compile_img2img()
         else:
             return compile_text2img()
@@ -133,23 +130,23 @@ class RBLNAutoencoderKL(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-
-        meta["rbln_batch_size"] = rbln_batch_size
+        model_cfg = {}
 
         if rbln_use_encode:
-
-
+            model_cfg["img_width"] = rbln_img_width
+            model_cfg["img_height"] = rbln_img_height
 
             vae_enc_input_info = [
                 ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
@@ -167,20 +164,23 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ]
 
-
-
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
+            dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
 
-
-
-
+            compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
+            rbln_config = RBLNConfig(
+                rbln_cls=cls.__name__,
+                compile_cfgs=compile_cfgs,
+                rbln_kwargs=rbln_kwargs,
             )
+            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
         if rbln_unet_sample_size is None:
             rbln_unet_sample_size = 64
 
-
-        vae_config =
+        model_cfg["unet_sample_size"] = rbln_unet_sample_size
+        vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
@@ -194,7 +194,12 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ],
         )
-        rbln_config = RBLNConfig
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[vae_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
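The autoencoder hunks above illustrate the configuration change that repeats through the rest of this diff: `_get_rbln_config` now takes a single `rbln_kwargs` dict instead of individual `rbln_*` parameters, and compilation settings are carried by `RBLNCompileConfig` objects inside an `RBLNConfig` rather than by the old runtime-config and `meta` pair. A condensed sketch of the new flow, using only calls visible in these hunks (shapes and values are illustrative placeholders):

    # Condensed from the hunks above; the input shape and batch size are placeholders.
    input_info = [("z", [1, 4, 64, 64], "float32")]          # one (name, shape, dtype) tuple per model input
    compile_cfg = RBLNCompileConfig(input_info=input_info)   # one compile config per compiled submodel
    rbln_config = RBLNConfig(
        rbln_cls=cls.__name__,
        compile_cfgs=[compile_cfg],
        rbln_kwargs=rbln_kwargs,
    )
    rbln_config.model_cfg.update({"batch_size": 1})          # runtime settings later read via rbln_config.model_cfg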
optimum/rbln/diffusers/models/controlnet.py
CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
@@ -31,7 +31,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -105,13 +105,10 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
 
 
 class RBLNControlNetModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel # feature extraction
-
     def __post_init__(self, **kwargs):
-
+        super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
-            item[0] == "encoder_hidden_states" for item in self.rbln_config[
+            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )
 
     @classmethod
|
@@ -121,9 +118,6 @@ class RBLNControlNetModel(RBLNModel):
|
|
121
118
|
model_name_or_path: Union[str, Path],
|
122
119
|
**kwargs,
|
123
120
|
):
|
124
|
-
if "subfolder" in kwargs:
|
125
|
-
del kwargs["subfolder"]
|
126
|
-
|
127
121
|
return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
|
128
122
|
|
129
123
|
tasktmp = TasksManager.get_model_from_task
|
@@ -155,14 +149,14 @@ class RBLNControlNetModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -170,28 +164,29 @@ class RBLNControlNetModel(RBLNModel):
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
 
+        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+            raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
+
         input_width = rbln_img_width // rbln_vae_scale_factor
         input_height = rbln_img_height // rbln_vae_scale_factor
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            batch_size=rbln_batch_size,
-        )
+        input_info = [
+            (
+                "sample",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    input_height,
+                    input_width,
+                ],
+                "float32",
+            ),
+            ("timestep", [], "float32"),
+        ]
+
         use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
         if use_encoder_hidden_states:
-
+            input_info.append(
                 (
                     "encoder_hidden_states",
                     [
@@ -202,19 +197,34 @@ class RBLNControlNetModel(RBLNModel):
                     "float32",
                 )
             )
-
-
-            )
-
+
+        input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
+        input_info.append(("conditioning_scale", [], "float32"))
+
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
-
-
-
-
+            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "img_width": rbln_img_width,
+                "img_height": rbln_img_height,
+                "vae_scale_factor": rbln_vae_scale_factor,
+            }
+        )
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
optimum/rbln/diffusers/models/unet_2d_condition.py
CHANGED
@@ -32,7 +32,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -126,12 +126,9 @@ class _UNet_SDXL(torch.nn.Module):
 
 
 class RBLNUNet2DConditionModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel # feature extraction
-
     def __post_init__(self, **kwargs):
-
-        self.in_features = self.rbln_config.
+        super().__post_init__(**kwargs)
+        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
         if self.in_features is not None:
 
 @dataclass
@@ -183,31 +180,31 @@ class RBLNUNet2DConditionModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_in_features: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
-        rbln_is_controlnet: Optional[bool] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_in_features = rbln_kwargs.get("in_features", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
 
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
-
-
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
 
         if rbln_use_encode:
-
+            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+                raise ValueError(
+                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
+                )
             input_width = rbln_img_width // rbln_vae_scale_factor
             input_height = rbln_img_height // rbln_vae_scale_factor
         else:
-            # FIXME :: model_config.sample_size can be tuple or list
             input_width, input_height = model_config.sample_size, model_config.sample_size
 
         input_info = [
@@ -232,6 +229,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 "float32",
             ),
         ]
+
         if rbln_is_controlnet:
             if len(model_config.block_out_channels) > 0:
                 input_info.extend(
@@ -304,24 +302,35 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     )
                 )
 
-
-            input_info=input_info,
-            batch_size=rbln_batch_size,
-        )
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            # In case of sdxl
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
             if rbln_in_features is None:
                 rbln_in_features = model_config.projection_class_embeddings_input_dim
-
-            rbln_runtime_config.input_info.append(
+            rbln_compile_config.input_info.append(
                 ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
             )
-
+            rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "use_encode": rbln_use_encode,
+            }
+        )
+
+        if rbln_in_features is not None:
+            rbln_config.model_cfg["in_features"] = rbln_in_features
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
CHANGED
@@ -102,6 +102,10 @@ class RBLNMultiControlNetModel(RBLNModel):
             real_save_path = save_directory + suffix
             model.save_pretrained(real_save_path)
 
+    @classmethod
+    def _get_rbln_config(cls, **rbln_config_kwargs):
+        pass
+
     def forward(
         self,
         sample: torch.FloatTensor,