optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
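The renames above move the alias and generic model classes under `optimum.rbln.transformers`, while the top-level re-exports are preserved through the lazy import tables shown in the diffs below. A hedged before/after sketch of the deep-import change (paths taken from the file list; the top-level import is the stable surface on both versions):

```python
# 0.1.13 deep import (modeling_alias.py lived at the package root):
#     from optimum.rbln.modeling_alias import RBLNASTForAudioClassification
# 0.2.0 deep import (the module moved under transformers/):
#     from optimum.rbln.transformers.modeling_alias import RBLNASTForAudioClassification
# Version-independent import via the top-level lazy module:
from optimum.rbln import RBLNASTForAudioClassification
```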
optimum/rbln/__init__.py
CHANGED
```diff
@@ -30,27 +30,15 @@ from .utils import check_version_compats
 
 
 _import_structure = {
-    "modeling_alias": [
-        "RBLNASTForAudioClassification",
-        "RBLNBertForQuestionAnswering",
-        "RBLNDistilBertForQuestionAnswering",
-        "RBLNResNetForImageClassification",
-        "RBLNXLMRobertaForSequenceClassification",
-        "RBLNRobertaForSequenceClassification",
-        "RBLNRobertaForMaskedLM",
-        "RBLNViTForImageClassification",
-    ],
-    "modeling_base": [
+    "modeling": [
         "RBLNBaseModel",
         "RBLNModel",
-        "RBLNModelForAudioClassification",
-        "RBLNModelForImageClassification",
-        "RBLNModelForQuestionAnswering",
-        "RBLNModelForSequenceClassification",
-        "RBLNModelForMaskedLM",
+    ],
+    "modeling_config": [
+        "RBLNCompileConfig",
+        "RBLNConfig",
     ],
     "transformers": [
-        "BatchTextIteratorStreamer",
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
         "RBLNAutoModelForCausalLM",
@@ -84,6 +72,14 @@ _import_structure = {
         "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
         "RBLNXLMRobertaModel",
+        "RBLNASTForAudioClassification",
+        "RBLNBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnswering",
+        "RBLNResNetForImageClassification",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForMaskedLM",
+        "RBLNViTForImageClassification",
     ],
     "diffusers": [
         "RBLNStableDiffusionPipeline",
@@ -92,55 +88,54 @@ _import_structure = {
         "RBLNUNet2DConditionModel",
         "RBLNControlNetModel",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNSD3Transformer2DModel",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+        "RBLNStableDiffusion3Pipeline",
+        "RBLNDiffusionMixin",
     ],
-    "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
-    "modeling_diffusers": ["RBLNDiffusionMixin"],
 }
 
 if TYPE_CHECKING:
     from .diffusers import (
         RBLNAutoencoderKL,
         RBLNControlNetModel,
+        RBLNDiffusionMixin,
         RBLNMultiControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
         RBLNUNet2DConditionModel,
     )
-    from .modeling_alias import (
-        RBLNASTForAudioClassification,
-        RBLNBertForQuestionAnswering,
-        RBLNResNetForImageClassification,
-        RBLNRobertaForMaskedLM,
-        RBLNRobertaForSequenceClassification,
-        RBLNT5ForConditionalGeneration,
-        RBLNViTForImageClassification,
-    )
-    from .modeling_base import (
+    from .modeling import (
         RBLNBaseModel,
         RBLNModel,
-        RBLNModelForAudioClassification,
-        RBLNModelForImageClassification,
-        RBLNModelForMaskedLM,
-        RBLNModelForQuestionAnswering,
-        RBLNModelForSequenceClassification,
     )
-    from .modeling_config import RBLNCompileConfig, RBLNConfig
-    from .modeling_diffusers import RBLNDiffusionMixin
+    from .modeling_config import (
+        RBLNCompileConfig,
+        RBLNConfig,
+    )
     from .transformers import (
-        BatchTextIteratorStreamer,
+        RBLNASTForAudioClassification,
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
         RBLNAutoModelForCausalLM,
@@ -155,10 +150,12 @@ if TYPE_CHECKING:
         RBLNAutoModelForVision2Seq,
         RBLNBartForConditionalGeneration,
         RBLNBartModel,
+        RBLNBertForQuestionAnswering,
         RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNCLIPVisionModel,
+        RBLNDistilBertForQuestionAnswering,
         RBLNDPTForDepthEstimation,
         RBLNExaoneForCausalLM,
         RBLNGemmaForCausalLM,
@@ -169,12 +166,18 @@ if TYPE_CHECKING:
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
         RBLNQwen2ForCausalLM,
+        RBLNResNetForImageClassification,
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForSequenceClassification,
         RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
+        RBLNViTForImageClassification,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
+        RBLNXLMRobertaForSequenceClassification,
         RBLNXLMRobertaModel,
     )
+
 else:
     import sys
 
```
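Despite the internal reshuffle (`modeling_base` → `modeling`, `modeling_alias` → `transformers`), the `_import_structure` table keeps the public names importable from the package root. A minimal sanity check, assuming an optimum-rbln 0.2.0 install:

```python
# Each name resolves through transformers.utils._LazyModule on first access:
from optimum.rbln import (
    RBLNBaseModel,                     # re-exported via the "modeling" entry
    RBLNResNetForImageClassification,  # moved under the transformers subpackage
    RBLNStableDiffusion3Pipeline,      # new in 0.2.0, from the diffusers subpackage
)
```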
optimum/rbln/__version__.py
CHANGED
```diff
@@ -1 +1,16 @@
-__version__ = "0.1.13"
+# file generated by setuptools_scm
+# don't change, don't track in version control
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple, Union
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.2.0'
+__version_tuple__ = version_tuple = (0, 2, 0)
```
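The version module is now generated by setuptools_scm and exposes a comparable tuple alongside the string. A small sketch of how downstream code can gate on it, assuming the 0.2.0 layout:

```python
from optimum.rbln.__version__ import __version__, __version_tuple__

# The tuple form avoids string-comparison pitfalls ("0.10.0" < "0.2.0" lexically):
if __version_tuple__ >= (0, 2, 0):
    print(f"optimum-rbln {__version__}: SD3 pipelines and inpaint variants available")
```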
optimum/rbln/diffusers/__init__.py
CHANGED
```diff
@@ -36,27 +36,51 @@ _import_structure = {
         "RBLNStableDiffusionPipeline",
         "RBLNStableDiffusionXLPipeline",
         "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNStableDiffusion3Pipeline",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+    ],
+    "models": [
+        "RBLNAutoencoderKL",
+        "RBLNUNet2DConditionModel",
+        "RBLNControlNetModel",
+        "RBLNSD3Transformer2DModel",
+    ],
+    "modeling_diffusers": [
+        "RBLNDiffusionMixin",
     ],
-    "models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
 }
 
 if TYPE_CHECKING:
-    from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+    from .modeling_diffusers import RBLNDiffusionMixin
+    from .models import (
+        RBLNAutoencoderKL,
+        RBLNControlNetModel,
+        RBLNSD3Transformer2DModel,
+        RBLNUNet2DConditionModel,
+    )
     from .pipelines import (
         RBLNMultiControlNetModel,
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
         RBLNStableDiffusionControlNetPipeline,
         RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
         RBLNStableDiffusionPipeline,
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
         RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
     )
 else:
```
optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py}
RENAMED
```diff
@@ -20,16 +20,21 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
+import copy
 import importlib
 from os import PathLike
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import torch
 
-from .modeling_base import RBLNModel
-from .modeling_config import ContextRblnConfig, use_rbln_config
-from .utils.decorator_utils import remove_compile_time_kwargs
+from ..modeling import RBLNModel
+from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
+from ..utils.decorator_utils import remove_compile_time_kwargs
+from ..utils.logging import get_logger
+
 
+logger = get_logger(__name__)
 
 if TYPE_CHECKING:
     from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
```
```diff
@@ -74,127 +79,40 @@ class RBLNDiffusionMixin:
 
     @classmethod
     @property
-    def use_encode(cls):
+    def img2img_pipeline(cls):
         return "Img2Img" in cls.__name__
 
     @classmethod
-    def _get_unet_batch_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> int:
-        batch_size = rbln_config.get("batch_size", 1)
-
-        do_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        return batch_size * 2 if do_guidance else batch_size
-
-    @classmethod
-    def _get_vae_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
-        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
-        if (image_size[0] is None) != (image_size[1] is None):
-            raise ValueError("Both image height and image width must be given or not given")
-        elif image_size[0] is None and image_size[1] is None:
-            if cls.use_encode:
-                sample_size = model.vae.config.sample_size
-            else:
-                # In case of text2img, sample size of vae decoder is determined by unet.
-                unet_sample_size = model.unet.config.sample_size
-                if isinstance(unet_sample_size, int):
-                    sample_size = unet_sample_size * model.vae_scale_factor
-                else:
-                    sample_size = (
-                        unet_sample_size[0] * model.vae_scale_factor,
-                        unet_sample_size[1] * model.vae_scale_factor,
-                    )
-
-        else:
-            sample_size = (image_size[0], image_size[1])
-        return sample_size
-
-    @classmethod
-    def _get_unet_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
-        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
-        if (image_size[0] is None) != (image_size[1] is None):
-            raise ValueError("Both image height and image width must be given or not given")
-        elif image_size[0] is None and image_size[1] is None:
-            if cls.use_encode:
-                # In case of img2img, sample size of unet is determined by vae encoder.
-                vae_sample_size = model.vae.config.sample_size
-                if isinstance(vae_sample_size, int):
-                    sample_size = vae_sample_size // model.vae_scale_factor
-                else:
-                    sample_size = (
-                        vae_sample_size[0] // model.vae_scale_factor,
-                        vae_sample_size[1] // model.vae_scale_factor,
-                    )
-            else:
-                sample_size = model.unet.config.sample_size
-        else:
-            sample_size = (image_size[0] // model.vae_scale_factor, image_size[1] // model.vae_scale_factor)
-        return sample_size
-
-    @classmethod
-    def _get_default_config(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        # default configurations for each submodules
-        return {"img2img_pipeline": cls.use_encode}
+    @property
+    def inpaint_pipeline(cls):
+        return "Inpaint" in cls.__name__
 
     @classmethod
-    def
-        cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
+    def get_submodule_rbln_config(
+        cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
     ) -> Dict[str, Any]:
-
-
+        submodule = getattr(model, submodule_name)
+        submodule_class_name = submodule.__class__.__name__
 
-
-
-        cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        batch_size = rbln_config.get("batch_size", 1)
-        return {"batch_size": batch_size}
+        if submodule_class_name == "MultiControlNetModel":
+            submodule_class_name = "ControlNetModel"
 
-
-    def get_default_rbln_config_unet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        # configuration for unet
-        unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
-        text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
-        return {
-            **cls._get_default_config(model, rbln_config),
-            "max_seq_len": model.text_encoder.config.max_position_embeddings,
-            "text_model_hidden_size": text_model_hidden_size,
-            "batch_size": unet_batch_size,
-            "sample_size": cls._get_unet_sample_size(model, rbln_config),
-            "is_controlnet": "controlnet" in model.config.keys(),
-        }
+        submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
 
-
-
-        # configuration for vae
-        batch_size = rbln_config.get("batch_size", 1)
-        return {
-            **cls._get_default_config(model, rbln_config),
-            "sample_size": cls._get_vae_sample_size(model, rbln_config),
-            "batch_size": batch_size,
-        }
+        submodule_config = rbln_config.get(submodule_name, {})
+        submodule_config = copy.deepcopy(submodule_config)
 
-
-    def get_default_rbln_config_controlnet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        # configuration for controlnet
-        unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
-        text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
-        return {
-            **cls._get_default_config(model, rbln_config),
-            "max_seq_len": model.text_encoder.config.max_position_embeddings,
-            "vae_sample_size": cls._get_vae_sample_size(model, rbln_config),
-            "unet_sample_size": cls._get_unet_sample_size(model, rbln_config),
-            "batch_size": unet_batch_size,
-            "text_model_hidden_size": text_model_hidden_size,
-        }
+        pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
 
-
-
-
-
-
-
-
-
-
+        submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
+        submodule_config.update(
+            {
+                "img2img_pipeline": cls.img2img_pipeline,
+                "inpaint_pipeline": cls.inpaint_pipeline,
+            }
+        )
+        submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
+        return submodule_config
 
     @staticmethod
     def _maybe_apply_and_fuse_lora(
```
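The replacement `get_submodule_rbln_config` collapses the per-submodule `get_default_rbln_config_*` methods into one merge: deep-copy the submodule's own config entry, fill in pipeline-level keys only where the submodule has not set its own value, then let the submodule class adjust the result. A standalone sketch of just the precedence rules (plain dicts, no RBLN imports; function name is illustrative):

```python
import copy

def merge_submodule_config(rbln_config: dict, submodule_name: str, submodules: list) -> dict:
    # submodule-specific entries win over pipeline-wide keys
    submodule_config = copy.deepcopy(rbln_config.get(submodule_name, {}))
    pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
    submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
    return submodule_config

# "unet" keeps its own batch_size but inherits the pipeline-level img_height:
print(merge_submodule_config(
    {"img_height": 512, "unet": {"batch_size": 2}, "vae": {}},
    "unet",
    ["unet", "vae"],
))  # -> {'batch_size': 2, 'img_height': 512}
```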
```diff
@@ -256,17 +174,46 @@ class RBLNDiffusionMixin:
 
         else:
             # raise error if any of submodules are torch module.
-
-
+            model_index_config = None
+            for submodule_name in cls._submodules:
+                if isinstance(kwargs.get(submodule_name), torch.nn.Module):
                     raise AssertionError(
-                        f"{
+                        f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
                     )
 
+                submodule_config = rbln_config.get(submodule_name, {})
+
+                for key, value in rbln_config.items():
+                    if key in RUNTIME_KEYWORDS and key not in submodule_config:
+                        submodule_config[key] = value
+
+                if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
+                    continue
+
+                if model_index_config is None:
+                    model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
+
+                module_name, class_name = model_index_config[submodule_name]
+                if module_name != "optimum.rbln":
+                    raise ValueError(
+                        f"Invalid module_name '{module_name}' found in model_index.json for "
+                        f"submodule '{submodule_name}'. "
+                        "Expected 'optimum.rbln'. Please check the model_index.json configuration."
+                    )
+
+                submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
+
+                submodule = submodule_cls.from_pretrained(
+                    model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
+                )
+                kwargs[submodule_name] = submodule
+
         with ContextRblnConfig(
             device=rbln_config.get("device"),
             device_map=rbln_config.get("device_map"),
             create_runtimes=rbln_config.get("create_runtimes"),
             optimize_host_mem=rbln_config.get("optimize_host_memory"),
+            activate_profiler=rbln_config.get("activate_profiler"),
         ):
             model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
 
```
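The new `from_pretrained` path above lets a non-export load pull individual precompiled submodules by consulting `model_index.json` whenever runtime keywords appear in the config. A hedged sketch of what that enables; the model directory is hypothetical, and treating `device` as one of the `RUNTIME_KEYWORDS` is an assumption drawn from the `ContextRblnConfig` keys in this hunk:

```python
from optimum.rbln import RBLNStableDiffusionPipeline

pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "./compiled-sd15",  # hypothetical directory of precompiled submodules
    export=False,       # load, don't compile
    rbln_config={
        "device": 0,            # pipeline-wide runtime keyword (assumed)
        "unet": {"device": 1},  # per-submodule entry takes precedence
    },
)
```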
```diff
@@ -291,16 +238,11 @@ class RBLNDiffusionMixin:
         model_save_dir: Optional[PathLike],
         rbln_config: Dict[str, Any],
     ) -> Dict[str, RBLNModel]:
-        # Compile submodules based on rbln_config
         compiled_submodules = {}
 
-        # FIXME : Currently, optimum-rbln for transformer does not use base rbln config.
-        base_rbln_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
         for submodule_name in cls._submodules:
             submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
-            submodule_rbln_config = cls.
-            submodule_rbln_config.update(base_rbln_config)
-            submodule_rbln_config.update(rbln_config.get(submodule_name, {}))
+            submodule_rbln_config = cls.get_submodule_rbln_config(model, submodule_name, rbln_config)
 
             if submodule is None:
                 raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
```
```diff
@@ -337,8 +279,8 @@ class RBLNDiffusionMixin:
         controlnet_rbln_config: Dict[str, Any],
     ):
         # Compile multiple ControlNet models for a MultiControlNet setup
-        from .diffusers.models.controlnet import RBLNControlNetModel
-        from .diffusers.pipelines.controlnet import RBLNMultiControlNetModel
+        from .models.controlnet import RBLNControlNetModel
+        from .pipelines.controlnet import RBLNMultiControlNetModel
 
         compiled_controlnets = [
             RBLNControlNetModel.from_model(
```
```diff
@@ -349,7 +291,7 @@ class RBLNDiffusionMixin:
             )
             for i, controlnet in enumerate(controlnets.nets)
         ]
-        return RBLNMultiControlNetModel(compiled_controlnets
+        return RBLNMultiControlNetModel(compiled_controlnets)
 
     @classmethod
     def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
```
````diff
@@ -395,6 +337,35 @@ class RBLNDiffusionMixin:
 
         return model
 
+    def get_compiled_image_size(self):
+        if hasattr(self, "vae"):
+            compiled_image_size = self.vae.image_size
+        else:
+            compiled_image_size = None
+        return compiled_image_size
+
+    def handle_additional_kwargs(self, **kwargs):
+        """
+        Function to handle additional compile-time parameters during inference.
+
+        If the additional variable is determined by another module, this method should be overrided.
+
+        Example:
+            ```python
+            if hasattr(self, "movq"):
+                compiled_image_size = self.movq.image_size
+                kwargs["height"] = compiled_image_size[0]
+                kwargs["width"] = compiled_image_size[1]
+
+            compiled_num_frames = self.unet.rbln_config.model_cfg.get("num_frames", None)
+            if compiled_num_frames is not None:
+                kwargs["num_frames"] = self.unet.rbln_config.model_cfg.get("num_frames")
+            return kwargs
+            ```
+        """
+        return kwargs
+
     @remove_compile_time_kwargs
     def __call__(self, *args, **kwargs):
+        kwargs = self.handle_additional_kwargs(**kwargs)
         return super().__call__(*args, **kwargs)
````
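Together, `get_compiled_image_size` and `handle_additional_kwargs` let a pipeline pin inference-time kwargs to the shapes it was compiled with, while `remove_compile_time_kwargs` strips user-supplied overrides. A hedged sketch of a subclass wiring the two hooks together; the subclass itself is illustrative, not part of the package:

```python
from diffusers import StableDiffusionPipeline
from optimum.rbln import RBLNDiffusionMixin

class RBLNMyPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):  # illustrative
    _submodules = ["text_encoder", "unet", "vae"]

    def handle_additional_kwargs(self, **kwargs):
        # pin height/width to the VAE's compiled image size, if one is recorded
        compiled_image_size = self.get_compiled_image_size()
        if compiled_image_size is not None:
            kwargs["height"] = compiled_image_size[0]
            kwargs["width"] = compiled_image_size[1]
        return kwargs
```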
optimum/rbln/diffusers/models/__init__.py
CHANGED
```diff
@@ -21,6 +21,39 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from
-
-from .
+from typing import TYPE_CHECKING
+
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "autoencoders": [
+        "RBLNAutoencoderKL",
+    ],
+    "unets": [
+        "RBLNUNet2DConditionModel",
+    ],
+    "controlnet": ["RBLNControlNetModel"],
+    "transformers": ["RBLNSD3Transformer2DModel"],
+}
+
+if TYPE_CHECKING:
+    from .autoencoders import (
+        RBLNAutoencoderKL,
+    )
+    from .controlnet import RBLNControlNetModel
+    from .transformers import (
+        RBLNSD3Transformer2DModel,
+    )
+    from .unets import (
+        RBLNUNet2DConditionModel,
+    )
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
```
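The same `_LazyModule` pattern now backs the restructured `__init__.py` files, so importing `optimum.rbln.diffusers.models` no longer eagerly imports every model file. A quick way to observe the deferral, assuming a 0.2.0 install and standard `_LazyModule` behavior from transformers:

```python
import sys

import optimum.rbln.diffusers.models as rbln_models

print("optimum.rbln.diffusers.models.controlnet" in sys.modules)  # expected: False
_ = rbln_models.RBLNControlNetModel  # first attribute access imports .controlnet
print("optimum.rbln.diffusers.models.controlnet" in sys.modules)  # expected: True
```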