optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
- optimum/rbln/diffusers/models/controlnet.py +36 -62
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +117 -144
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -28
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
- optimum_rbln-0.1.13.dist-info/RECORD +107 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/RECORD +0 -93
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -35,11 +35,10 @@ _import_structure = {
         "RBLNBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
-        "RBLNT5ForConditionalGeneration",
-        "RBLNBartForConditionalGeneration",
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNRobertaForSequenceClassification",
         "RBLNRobertaForMaskedLM",
+        "RBLNViTForImageClassification",
     ],
     "modeling_base": [
         "RBLNBaseModel",
@@ -50,9 +49,6 @@ _import_structure = {
         "RBLNModelForSequenceClassification",
         "RBLNModelForMaskedLM",
     ],
-    "modeling_seq2seq": [
-        "RBLNModelForSeq2SeqLM",
-    ],
     "transformers": [
         "BatchTextIteratorStreamer",
         "RBLNAutoModel",
@@ -67,16 +63,21 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNBartForConditionalGeneration",
         "RBLNBartModel",
         "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
         "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
+        "RBLNExaoneForCausalLM",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
+        "RBLNQwen2ForCausalLM",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNT5EncoderModel",
+        "RBLNT5ForConditionalGeneration",
         "RBLNPhiForCausalLM",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
@@ -99,6 +100,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
     "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
+    "modeling_diffusers": ["RBLNDiffusionMixin"],
 }

 if TYPE_CHECKING:
@@ -118,12 +120,12 @@ if TYPE_CHECKING:
     )
     from .modeling_alias import (
         RBLNASTForAudioClassification,
-        RBLNBartForConditionalGeneration,
         RBLNBertForQuestionAnswering,
         RBLNResNetForImageClassification,
         RBLNRobertaForMaskedLM,
         RBLNRobertaForSequenceClassification,
         RBLNT5ForConditionalGeneration,
+        RBLNViTForImageClassification,
         RBLNXLMRobertaForSequenceClassification,
     )
     from .modeling_base import (
@@ -136,7 +138,7 @@ if TYPE_CHECKING:
         RBLNModelForSequenceClassification,
     )
     from .modeling_config import RBLNCompileConfig, RBLNConfig
-    from .modeling_seq2seq import RBLNModelForSeq2SeqLM
+    from .modeling_diffusers import RBLNDiffusionMixin
     from .transformers import (
         BatchTextIteratorStreamer,
         RBLNAutoModel,
@@ -151,12 +153,14 @@ if TYPE_CHECKING:
         RBLNAutoModelForSequenceClassification,
         RBLNAutoModelForSpeechSeq2Seq,
         RBLNAutoModelForVision2Seq,
+        RBLNBartForConditionalGeneration,
         RBLNBartModel,
         RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
+        RBLNExaoneForCausalLM,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
@@ -164,6 +168,9 @@ if TYPE_CHECKING:
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
+        RBLNQwen2ForCausalLM,
+        RBLNT5EncoderModel,
+        RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
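Net effect of the `__init__.py` changes: `RBLNBartForConditionalGeneration` moves from `modeling_alias` into the `transformers` sub-structure, T5 gains an encoder-only wrapper, and Exaone, Qwen2, and the new `RBLNDiffusionMixin` join the public surface. A quick smoke test of the 0.1.13 top-level API, assuming the wheel is installed, might look like this:

# Import smoke test for the 0.1.13 public API shown in the diff above.
# All names come straight from the updated _import_structure.
from optimum.rbln import (
    RBLNBartForConditionalGeneration,  # moved from modeling_alias to transformers
    RBLNDiffusionMixin,                # new in modeling_diffusers
    RBLNExaoneForCausalLM,             # new model support
    RBLNQwen2ForCausalLM,              # new model support
    RBLNT5EncoderModel,                # new encoder-only T5 wrapper
    RBLNT5ForConditionalGeneration,    # now also exported via transformers
)

print("optimum-rbln 0.1.13 API imports resolved")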
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.1.11'
+__version__ = '0.1.13'
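To confirm which wheel is active in an environment before relying on the new API, a standard-library metadata lookup is enough:

# Check the installed optimum-rbln version (standard library only).
from importlib.metadata import version

assert version("optimum-rbln") == "0.1.13"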
optimum/rbln/diffusers/models/autoencoder_kl.py
CHANGED
@@ -22,7 +22,6 @@
 # from Rebellions Inc.

 import logging
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Union

 import rebel
@@ -30,11 +29,11 @@ import torch # noqa: I001
 from diffusers import AutoencoderKL
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig

 from ...modeling_base import RBLNModel
 from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ...utils.context import override_auto_classes
 from ...utils.runtime_utils import RBLNPytorchRuntime
@@ -58,15 +57,12 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):


 class RBLNAutoencoderKL(RBLNModel):
-    model_type = "rbln_model"
     config_name = "config.json"
-    auto_model_class = AutoModel # feature extraction

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

-
-        if self.rbln_use_encode:
+        if self.rbln_config.model_cfg.get("img2img_pipeline"):
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
@@ -93,38 +89,15 @@

         return dec_compiled_model

-        if rbln_config.model_cfg.get("use_encode"):
+        if rbln_config.model_cfg.get("img2img_pipeline"):
             return compile_img2img()
         else:
             return compile_text2img()

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
-        ):
-            return AutoencoderKL.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-
-        if kwargs.get("export", None):
-            # This is an ad-hoc to workaround save null values of the config.
-            # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
-            # and diffusers model do not support loading by AutoConfig.
-            AutoConfig.from_pretrained = lambda *args, **kwargs: None
-        else:
-            AutoConfig.from_pretrained = AutoencoderKL.load_config
-
-        AutoModel.from_pretrained = AutoencoderKL.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
+        with override_auto_classes(config_func=AutoencoderKL.load_config, model_func=AutoencoderKL.from_pretrained):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt

     @classmethod
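The save/patch/restore dance around `TasksManager`, `AutoConfig`, and `AutoModel` is replaced by the new `override_auto_classes` context manager from `optimum/rbln/utils/context.py` (+58 lines, implementation not shown in this diff). Judging only from the call sites and the code it replaces, a minimal sketch could look like the following; the body here is an assumption, not the shipped code:

# Hypothetical sketch of override_auto_classes; it only mirrors the behavior
# of the monkey-patching block removed above.
from contextlib import contextmanager

from optimum.exporters import TasksManager
from transformers import AutoConfig, AutoModel


@contextmanager
def override_auto_classes(config_func, model_func):
    """Temporarily route AutoConfig/AutoModel loading to diffusers-style loaders."""
    original_task = TasksManager.get_model_from_task
    original_config = AutoConfig.from_pretrained
    original_model = AutoModel.from_pretrained
    try:
        TasksManager.get_model_from_task = lambda task, model_name_or_path, **kw: model_func(
            pretrained_model_name_or_path=model_name_or_path, **kw
        )
        AutoConfig.from_pretrained = config_func
        AutoModel.from_pretrained = model_func
        yield
    finally:
        # Always restore the originals, even if loading raises.
        TasksManager.get_model_from_task = original_task
        AutoConfig.from_pretrained = original_config
        AutoModel.from_pretrained = original_model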
@@ -134,34 +107,39 @@ class RBLNAutoencoderKL(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        rbln_img_width = rbln_kwargs.get("img_width", None)
-        rbln_img_height = rbln_kwargs.get("img_height", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_use_encode = rbln_kwargs.get("use_encode", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size")
+        sample_size = rbln_kwargs.get("sample_size")

         if rbln_batch_size is None:
             rbln_batch_size = 1

+        if sample_size is None:
+            sample_size = model_config.sample_size
+
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        if hasattr(model_config, "block_out_channels"):
+            vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+        else:
+            # vae image processor default value 8 (int)
+            vae_scale_factor = 8

-        model_cfg["img_width"] = rbln_img_width
-        model_cfg["img_height"] = rbln_img_height
+        dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
+        enc_shape = (sample_size[0], sample_size[1])

+        if rbln_kwargs["img2img_pipeline"]:
             vae_enc_input_info = [
-                ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
+                (
+                    "x",
+                    [rbln_batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
+                    "float32",
+                )
             ]
             vae_dec_input_info = [
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_img_height // rbln_vae_scale_factor,
-                        rbln_img_width // rbln_vae_scale_factor,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
             ]
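The rewritten `_get_rbln_config` derives both compile shapes from a single `sample_size` instead of separate `img_height`/`img_width`/`vae_scale_factor` knobs. A worked example with typical Stable Diffusion 1.5 values (assumed here, not taken from the diff):

# Shape derivation in the new VAE config path, with assumed SD-1.5-style values.
sample_size = 512                                  # model_config.sample_size
block_out_channels = [128, 256, 512, 512]          # model_config.block_out_channels

if isinstance(sample_size, int):
    sample_size = (sample_size, sample_size)       # (512, 512)

vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # 2**3 == 8

enc_shape = (sample_size[0], sample_size[1])             # encoder "x": (512, 512)
dec_shape = (sample_size[0] // vae_scale_factor,
             sample_size[1] // vae_scale_factor)         # decoder "z": (64, 64)

print(enc_shape, dec_shape)  # (512, 512) (64, 64)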
@@ -175,33 +153,22 @@ class RBLNAutoencoderKL(RBLNModel):
                 compile_cfgs=compile_cfgs,
                 rbln_kwargs=rbln_kwargs,
             )
-            rbln_config.model_cfg.update(model_cfg)
             return rbln_config

-        if rbln_unet_sample_size is None:
-            rbln_unet_sample_size = 64
-
-        model_cfg["unet_sample_size"] = rbln_unet_sample_size
         vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_unet_sample_size,
-                        rbln_unet_sample_size,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
-                ]
+            ]
         )
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
             compile_cfgs=[vae_config],
             rbln_kwargs=rbln_kwargs,
         )
-        rbln_config.model_cfg.update(model_cfg)
         return rbln_config

     @classmethod
optimum/rbln/diffusers/models/controlnet.py
CHANGED
@@ -22,16 +22,15 @@
 # from Rebellions Inc.

 import logging
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import torch
 from diffusers import ControlNetModel
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig

 from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNCompileConfig, RBLNConfig
+from ...utils.context import override_auto_classes


 if TYPE_CHECKING:
@@ -105,9 +104,6 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):


 class RBLNControlNetModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel # feature extraction
-
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
@@ -116,26 +112,11 @@ class RBLNControlNetModel(RBLNModel):

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
+        with override_auto_classes(
+            config_func=ControlNetModel.load_config,
+            model_func=ControlNetModel.from_pretrained,
         ):
-            return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-        AutoConfig.from_pretrained = ControlNetModel.load_config
-        AutoModel.from_pretrained = ControlNetModel.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
+            rt = super().from_pretrained(*args, **kwargs)
         return rt

     @classmethod
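`RBLNControlNetModel.from_pretrained` gets the same context-manager treatment. Combined with the validation added in the next hunk, compiling a ControlNet checkpoint plausibly looks like the call below; the checkpoint id and exact keyword plumbing are assumptions based on the new error messages, not code from this diff:

# Hypothetical export call implied by the new validation messages.
from optimum.rbln import RBLNControlNetModel

controlnet = RBLNControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",  # example checkpoint (assumption)
    export=True,                       # compile for RBLN instead of loading a prebuilt model
    rbln_max_seq_len=77,               # text_encoder max_position_embeddings
    rbln_unet_sample_size=(64, 64),    # latent height/width
    rbln_vae_sample_size=(512, 512),   # input image height/width
)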
@@ -157,33 +138,35 @@ class RBLNControlNetModel(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_img_width = rbln_kwargs.get("img_width", None)
-        rbln_img_height = rbln_kwargs.get("img_height", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        batch_size = rbln_kwargs.get("batch_size")
+        max_seq_len = rbln_kwargs.get("max_seq_len")
+        unet_sample_size = rbln_kwargs.get("unet_sample_size")
+        vae_sample_size = rbln_kwargs.get("vae_sample_size")

-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        if batch_size is None:
+            batch_size = 1

-        if ...:
-            ...
+        if unet_sample_size is None:
+            raise ValueError(
+                "`rbln_unet_sample_size` (latent height, widht) must be specified (ex. unet's sample_size)"
+            )

-        if ...:
-            raise ValueError(...)
+        if vae_sample_size is None:
+            raise ValueError(
+                "`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
+            )

-        ...
+        if max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")

         input_info = [
             (
                 "sample",
                 [
-                    rbln_batch_size,
+                    batch_size,
                     model_config.in_channels,
-                    rbln_img_height // rbln_vae_scale_factor,
-                    rbln_img_width // rbln_vae_scale_factor,
+                    unet_sample_size[0],
+                    unet_sample_size[1],
                 ],
                 "float32",
             ),
@@ -195,23 +178,24 @@ class RBLNControlNetModel(RBLNModel):
         input_info.append(
             (
                 "encoder_hidden_states",
-                [
-                    rbln_batch_size,
-                    rbln_max_seq_len,
-                    model_config.cross_attention_dim,
-                ],
+                [batch_size, max_seq_len, model_config.cross_attention_dim],
                 "float32",
             )
         )

-        input_info.append(("controlnet_cond", ...))
+        input_info.append(
+            (
+                "controlnet_cond",
+                [batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
+                "float32",
+            )
+        )
         input_info.append(("conditioning_scale", [], "float32"))

         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
-            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
-            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+            input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [batch_size, 6], "float32"))

         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
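For an SD-1.5-style ControlNet (`cross_attention_dim=768`, no `addition_embed_type`), the rewritten branch assembles roughly the following `input_info`; the concrete numbers are illustrative assumptions, and entries not visible in these hunks (e.g. the timestep input) are omitted:

# Illustrative input_info for batch_size=1, max_seq_len=77,
# unet_sample_size=(64, 64), vae_sample_size=(512, 512); values assumed.
input_info = [
    ("sample", [1, 4, 64, 64], "float32"),               # latent input
    ("encoder_hidden_states", [1, 77, 768], "float32"),  # text embeddings
    ("controlnet_cond", [1, 3, 512, 512], "float32"),    # conditioning image
    ("conditioning_scale", [], "float32"),               # scalar
]
# SDXL-style configs (addition_embed_type == "text_time") additionally get
# ("text_embeds", [1, text_model_hidden_size]) and ("time_ids", [1, 6]).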
|
@@ -221,16 +205,6 @@ class RBLNControlNetModel(RBLNModel):
|
|
221
205
|
rbln_kwargs=rbln_kwargs,
|
222
206
|
)
|
223
207
|
|
224
|
-
rbln_config.model_cfg.update(
|
225
|
-
{
|
226
|
-
"max_seq_len": rbln_max_seq_len,
|
227
|
-
"batch_size": rbln_batch_size,
|
228
|
-
"img_width": rbln_img_width,
|
229
|
-
"img_height": rbln_img_height,
|
230
|
-
"vae_scale_factor": rbln_vae_scale_factor,
|
231
|
-
}
|
232
|
-
)
|
233
|
-
|
234
208
|
return rbln_config
|
235
209
|
|
236
210
|
def forward(
|