optimum-rbln 0.1.8__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +40 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +39 -32
- optimum/rbln/diffusers/models/controlnet.py +60 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +43 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +8 -4
- optimum/rbln/modeling_base.py +512 -238
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +37 -1
- optimum/rbln/transformers/models/__init__.py +21 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +128 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +32 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +406 -104
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +18 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +25 -16
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +97 -0
- optimum/rbln/utils/import_utils.py +37 -5
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +35 -1
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +15 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.8.dist-info/RECORD +0 -73
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
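The most visible addition in 0.1.11 is the set of `transformers`-style auto classes (`RBLNAutoModel`, `RBLNAutoModelForCausalLM`, etc., under `optimum/rbln/transformers/models/auto/`). A minimal usage sketch, assuming the usual optimum `export=True` compile flow and the `rbln_*` keyword names (`batch_size`, `max_seq_len`) that appear in this diff; the model id and kwarg values are illustrative, not taken from the release:

```python
from optimum.rbln import RBLNAutoModelForCausalLM

# Compile a HF causal LM for the RBLN NPU. The auto class dispatches to a
# concrete class such as RBLNLlamaForCausalLM; model id and rbln_* values
# below are illustrative.
model = RBLNAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    export=True,            # trace and compile instead of loading a prebuilt artifact
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)
model.save_pretrained("llama2-7b-rbln")  # reloadable later without re-export
```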
optimum/rbln/__init__.py
CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import _LazyModule
 
+from .__version__ import __version__
 from .utils import check_version_compats
 
 
@@ -32,6 +33,7 @@ _import_structure = {
     "modeling_alias": [
         "RBLNASTForAudioClassification",
         "RBLNBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
         "RBLNT5ForConditionalGeneration",
         "RBLNBartForConditionalGeneration",
@@ -53,14 +55,32 @@ _import_structure = {
     ],
     "transformers": [
         "BatchTextIteratorStreamer",
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
+        "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
         "RBLNXLMRobertaModel",
     ],
@@ -78,7 +98,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
-    "modeling_config": ["
+    "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
 }
 
 if TYPE_CHECKING:
@@ -115,17 +135,35 @@ if TYPE_CHECKING:
         RBLNModelForQuestionAnswering,
         RBLNModelForSequenceClassification,
     )
-    from .modeling_config import
+    from .modeling_config import RBLNCompileConfig, RBLNConfig
     from .modeling_seq2seq import RBLNModelForSeq2SeqLM
     from .transformers import (
         BatchTextIteratorStreamer,
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
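A direct consequence of the `from .__version__ import __version__` line added above is that the package version becomes introspectable at the top level:

```python
import optimum.rbln

print(optimum.rbln.__version__)  # "0.1.11" for this release
```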
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.1.8'
+__version__ = '0.1.11'
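The file diffs below repeat one refactor many times: `RBLNRuntimeConfig` is replaced by `RBLNCompileConfig`, each `_get_rbln_config` takes a single `rbln_kwargs` dict instead of explicit `rbln_*` parameters, and per-model settings move from the old `meta` dict into `RBLNConfig.model_cfg`. A condensed sketch of the new pattern, with names and shapes borrowed from the hunks below; the simplified signature and values are illustrative:

```python
from typing import Any, Dict

from optimum.rbln import RBLNCompileConfig, RBLNConfig


def _get_rbln_config(rbln_kwargs: Dict[str, Any] = {}) -> RBLNConfig:
    # Explicit rbln_* parameters are now read out of one dict.
    batch_size = rbln_kwargs.get("batch_size", None) or 1

    # RBLNCompileConfig replaces RBLNRuntimeConfig and owns the input_info.
    compile_cfg = RBLNCompileConfig(
        input_info=[("sample", [batch_size, 4, 64, 64], "float32")]
    )

    rbln_config = RBLNConfig(
        rbln_cls="RBLNUNet2DConditionModel",  # cls.__name__ in the real code
        compile_cfgs=[compile_cfg],
        rbln_kwargs=rbln_kwargs,
    )
    # Settings that used to live in `meta` now go into model_cfg.
    rbln_config.model_cfg.update({"batch_size": batch_size})
    return rbln_config
```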
optimum/rbln/diffusers/models/autoencoder_kl.py
CHANGED
@@ -23,10 +23,10 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List,
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 import rebel
-import torch
+import torch  # noqa: I001
 from diffusers import AutoencoderKL
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
@@ -34,16 +34,16 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     import torch
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
@@ -63,10 +63,9 @@ class RBLNAutoencoderKL(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-
-
-        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
+        super().__post_init__(**kwargs)
 
+        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
         if self.rbln_use_encode:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
@@ -81,20 +80,20 @@ class RBLNAutoencoderKL(RBLNModel):
             encoder_model.eval()
             decoder_model.eval()
 
-            enc_compiled_model = cls.compile(encoder_model,
-            dec_compiled_model = cls.compile(decoder_model,
+            enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
 
-            return enc_compiled_model, dec_compiled_model
+            return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
             decoder_model.eval()
 
-            dec_compiled_model = cls.compile(decoder_model,
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
 
             return dec_compiled_model
 
-        if rbln_config.
+        if rbln_config.model_cfg.get("use_encode", False):
             return compile_img2img()
         else:
             return compile_text2img()
@@ -133,23 +132,23 @@ class RBLNAutoencoderKL(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-
-        meta["rbln_batch_size"] = rbln_batch_size
+        model_cfg = {}
 
         if rbln_use_encode:
-
-
+            model_cfg["img_width"] = rbln_img_width
+            model_cfg["img_height"] = rbln_img_height
 
             vae_enc_input_info = [
                 ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
@@ -167,20 +166,23 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ]
 
-
-
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
+            dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
 
-
-
-
+            compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
+            rbln_config = RBLNConfig(
+                rbln_cls=cls.__name__,
+                compile_cfgs=compile_cfgs,
+                rbln_kwargs=rbln_kwargs,
             )
+            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
         if rbln_unet_sample_size is None:
             rbln_unet_sample_size = 64
 
-
-        vae_config =
+        model_cfg["unet_sample_size"] = rbln_unet_sample_size
+        vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
@@ -194,7 +196,12 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ],
         )
-        rbln_config = RBLNConfig
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[vae_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
optimum/rbln/diffusers/models/controlnet.py
CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
@@ -31,15 +31,16 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
+logger = logging.getLogger(__name__)
+
+
 class _ControlNetModel(torch.nn.Module):
     def __init__(self, controlnet: "ControlNetModel"):
         super().__init__()
@@ -108,21 +109,21 @@ class RBLNControlNetModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-
+        super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
-            item[0] == "encoder_hidden_states" for item in self.rbln_config[
+            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
+        if "subfolder" in kwargs:
+            del kwargs["subfolder"]
+
         def get_model_from_task(
             task: str,
             model_name_or_path: Union[str, Path],
             **kwargs,
         ):
-            if "subfolder" in kwargs:
-                del kwargs["subfolder"]
-
             return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
         tasktmp = TasksManager.get_model_from_task
@@ -138,7 +139,7 @@ class RBLNControlNetModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -154,14 +155,14 @@ class RBLNControlNetModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -169,28 +170,29 @@ class RBLNControlNetModel(RBLNModel):
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
 
+        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+            raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
+
         input_width = rbln_img_width // rbln_vae_scale_factor
         input_height = rbln_img_height // rbln_vae_scale_factor
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            batch_size=rbln_batch_size,
-        )
+        input_info = [
+            (
+                "sample",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    input_height,
+                    input_width,
+                ],
+                "float32",
+            ),
+            ("timestep", [], "float32"),
+        ]
+
         use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
         if use_encoder_hidden_states:
-
+            input_info.append(
                 (
                     "encoder_hidden_states",
                     [
@@ -201,19 +203,34 @@ class RBLNControlNetModel(RBLNModel):
                     "float32",
                 )
             )
-
-
-            )
-
+
+        input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
+        input_info.append(("conditioning_scale", [], "float32"))
+
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
-
-
-
-
+            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "img_width": rbln_img_width,
+                "img_height": rbln_img_height,
+                "vae_scale_factor": rbln_vae_scale_factor,
+            }
+        )
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
optimum/rbln/diffusers/models/unet_2d_condition.py
CHANGED
@@ -32,14 +32,14 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+
 
 class _UNet_SD(torch.nn.Module):
     def __init__(self, unet: "UNet2DConditionModel"):
@@ -130,8 +130,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-
-        self.in_features = self.rbln_config.
+        super().__post_init__(**kwargs)
+        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
         if self.in_features is not None:
 
         @dataclass
@@ -172,7 +172,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
         else:
@@ -183,31 +183,31 @@ class RBLNUNet2DConditionModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_in_features: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
-        rbln_is_controlnet: Optional[bool] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_in_features = rbln_kwargs.get("in_features", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
 
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
-
-
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
 
         if rbln_use_encode:
-
+            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+                raise ValueError(
+                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
+                )
             input_width = rbln_img_width // rbln_vae_scale_factor
             input_height = rbln_img_height // rbln_vae_scale_factor
         else:
-            # FIXME :: model_config.sample_size can be tuple or list
             input_width, input_height = model_config.sample_size, model_config.sample_size
 
         input_info = [
@@ -232,6 +232,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 "float32",
             ),
         ]
+
         if rbln_is_controlnet:
             if len(model_config.block_out_channels) > 0:
                 input_info.extend(
@@ -304,24 +305,35 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 )
             )
 
-
-            input_info=input_info,
-            batch_size=rbln_batch_size,
-        )
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            # In case of sdxl
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
             if rbln_in_features is None:
                 rbln_in_features = model_config.projection_class_embeddings_input_dim
-
-            rbln_runtime_config.input_info.append(
+            rbln_compile_config.input_info.append(
                 ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
             )
-
+            rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "use_encode": rbln_use_encode,
+            }
+        )
+
+        if rbln_in_features is not None:
+            rbln_config.model_cfg["in_features"] = rbln_in_features
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
CHANGED
@@ -37,11 +37,11 @@ from ....modeling_config import RBLNConfig
 from ...models.controlnet import RBLNControlNetModel
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNMultiControlNetModel(RBLNModel):
     def __init__(
@@ -79,7 +79,6 @@ class RBLNMultiControlNetModel(RBLNModel):
         model_id: Union[str, Path],
         **kwargs,
     ) -> RBLNModel:
-
         idx = 0
         controlnets = []
         model_path_to_load = model_id