optimum-rbln 0.2.1a4__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -2
- optimum/rbln/__version__.py +9 -4
- optimum/rbln/diffusers/__init__.py +10 -0
- optimum/rbln/diffusers/modeling_diffusers.py +132 -25
- optimum/rbln/diffusers/models/__init__.py +7 -1
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +52 -2
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +159 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +174 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +57 -14
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +83 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +22 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +22 -0
- optimum/rbln/modeling_base.py +10 -9
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +12 -2
- optimum/rbln/transformers/models/clip/__init__.py +6 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +26 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/RECORD +28 -22
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -48,6 +48,7 @@ _import_structure = {
|
|
48
48
|
"RBLNCLIPTextModel",
|
49
49
|
"RBLNCLIPTextModelWithProjection",
|
50
50
|
"RBLNCLIPVisionModel",
|
51
|
+
"RBLNCLIPVisionModelWithProjection",
|
51
52
|
"RBLNDPTForDepthEstimation",
|
52
53
|
"RBLNExaoneForCausalLM",
|
53
54
|
"RBLNGemmaForCausalLM",
|
@@ -74,11 +75,15 @@ _import_structure = {
|
|
74
75
|
"RBLNBertForMaskedLM",
|
75
76
|
],
|
76
77
|
"diffusers": [
|
78
|
+
"RBLNAutoencoderKL",
|
79
|
+
"RBLNControlNetModel",
|
80
|
+
"RBLNPriorTransformer",
|
81
|
+
"RBLNKandinskyV22InpaintCombinedPipeline",
|
82
|
+
"RBLNKandinskyV22InpaintPipeline",
|
83
|
+
"RBLNKandinskyV22PriorPipeline",
|
77
84
|
"RBLNStableDiffusionPipeline",
|
78
85
|
"RBLNStableDiffusionXLPipeline",
|
79
|
-
"RBLNAutoencoderKL",
|
80
86
|
"RBLNUNet2DConditionModel",
|
81
|
-
"RBLNControlNetModel",
|
82
87
|
"RBLNStableDiffusionImg2ImgPipeline",
|
83
88
|
"RBLNStableDiffusionInpaintPipeline",
|
84
89
|
"RBLNStableDiffusionControlNetImg2ImgPipeline",
|
@@ -88,6 +93,7 @@ _import_structure = {
|
|
88
93
|
"RBLNStableDiffusionControlNetPipeline",
|
89
94
|
"RBLNStableDiffusionXLControlNetPipeline",
|
90
95
|
"RBLNStableDiffusionXLControlNetImg2ImgPipeline",
|
96
|
+
"RBLNVQModel",
|
91
97
|
"RBLNSD3Transformer2DModel",
|
92
98
|
"RBLNStableDiffusion3Img2ImgPipeline",
|
93
99
|
"RBLNStableDiffusion3InpaintPipeline",
|
@@ -101,7 +107,11 @@ if TYPE_CHECKING:
|
|
101
107
|
RBLNAutoencoderKL,
|
102
108
|
RBLNControlNetModel,
|
103
109
|
RBLNDiffusionMixin,
|
110
|
+
RBLNKandinskyV22InpaintCombinedPipeline,
|
111
|
+
RBLNKandinskyV22InpaintPipeline,
|
112
|
+
RBLNKandinskyV22PriorPipeline,
|
104
113
|
RBLNMultiControlNetModel,
|
114
|
+
RBLNPriorTransformer,
|
105
115
|
RBLNSD3Transformer2DModel,
|
106
116
|
RBLNStableDiffusion3Img2ImgPipeline,
|
107
117
|
RBLNStableDiffusion3InpaintPipeline,
|
@@ -117,6 +127,7 @@ if TYPE_CHECKING:
|
|
117
127
|
RBLNStableDiffusionXLInpaintPipeline,
|
118
128
|
RBLNStableDiffusionXLPipeline,
|
119
129
|
RBLNUNet2DConditionModel,
|
130
|
+
RBLNVQModel,
|
120
131
|
)
|
121
132
|
from .modeling import (
|
122
133
|
RBLNBaseModel,
|
@@ -148,6 +159,7 @@ if TYPE_CHECKING:
|
|
148
159
|
RBLNCLIPTextModel,
|
149
160
|
RBLNCLIPTextModelWithProjection,
|
150
161
|
RBLNCLIPVisionModel,
|
162
|
+
RBLNCLIPVisionModelWithProjection,
|
151
163
|
RBLNDistilBertForQuestionAnswering,
|
152
164
|
RBLNDPTForDepthEstimation,
|
153
165
|
RBLNExaoneForCausalLM,
|
optimum/rbln/__version__.py
CHANGED
@@ -1,8 +1,13 @@
|
|
1
|
-
# file generated by
|
1
|
+
# file generated by setuptools-scm
|
2
2
|
# don't change, don't track in version control
|
3
|
+
|
4
|
+
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
|
5
|
+
|
3
6
|
TYPE_CHECKING = False
|
4
7
|
if TYPE_CHECKING:
|
5
|
-
from typing import Tuple
|
8
|
+
from typing import Tuple
|
9
|
+
from typing import Union
|
10
|
+
|
6
11
|
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
7
12
|
else:
|
8
13
|
VERSION_TUPLE = object
|
@@ -12,5 +17,5 @@ __version__: str
|
|
12
17
|
__version_tuple__: VERSION_TUPLE
|
13
18
|
version_tuple: VERSION_TUPLE
|
14
19
|
|
15
|
-
__version__ = version = '0.2
|
16
|
-
__version_tuple__ = version_tuple = (0,
|
20
|
+
__version__ = version = '0.7.2'
|
21
|
+
__version_tuple__ = version_tuple = (0, 7, 2)
|
@@ -24,6 +24,9 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
|
|
24
24
|
|
25
25
|
_import_structure = {
|
26
26
|
"pipelines": [
|
27
|
+
"RBLNKandinskyV22InpaintCombinedPipeline",
|
28
|
+
"RBLNKandinskyV22InpaintPipeline",
|
29
|
+
"RBLNKandinskyV22PriorPipeline",
|
27
30
|
"RBLNStableDiffusionPipeline",
|
28
31
|
"RBLNStableDiffusionXLPipeline",
|
29
32
|
"RBLNStableDiffusionImg2ImgPipeline",
|
@@ -44,6 +47,8 @@ _import_structure = {
|
|
44
47
|
"RBLNUNet2DConditionModel",
|
45
48
|
"RBLNControlNetModel",
|
46
49
|
"RBLNSD3Transformer2DModel",
|
50
|
+
"RBLNPriorTransformer",
|
51
|
+
"RBLNVQModel",
|
47
52
|
],
|
48
53
|
"modeling_diffusers": [
|
49
54
|
"RBLNDiffusionMixin",
|
@@ -55,10 +60,15 @@ if TYPE_CHECKING:
|
|
55
60
|
from .models import (
|
56
61
|
RBLNAutoencoderKL,
|
57
62
|
RBLNControlNetModel,
|
63
|
+
RBLNPriorTransformer,
|
58
64
|
RBLNSD3Transformer2DModel,
|
59
65
|
RBLNUNet2DConditionModel,
|
66
|
+
RBLNVQModel,
|
60
67
|
)
|
61
68
|
from .pipelines import (
|
69
|
+
RBLNKandinskyV22InpaintCombinedPipeline,
|
70
|
+
RBLNKandinskyV22InpaintPipeline,
|
71
|
+
RBLNKandinskyV22PriorPipeline,
|
62
72
|
RBLNMultiControlNetModel,
|
63
73
|
RBLNStableDiffusion3Img2ImgPipeline,
|
64
74
|
RBLNStableDiffusion3InpaintPipeline,
|
@@ -23,6 +23,7 @@ from ..modeling import RBLNModel
|
|
23
23
|
from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
|
24
24
|
from ..utils.decorator_utils import remove_compile_time_kwargs
|
25
25
|
from ..utils.logging import get_logger
|
26
|
+
from . import pipelines
|
26
27
|
|
27
28
|
|
28
29
|
logger = get_logger(__name__)
|
@@ -67,6 +68,7 @@ class RBLNDiffusionMixin:
|
|
67
68
|
"""
|
68
69
|
|
69
70
|
_submodules = []
|
71
|
+
_prefix = {}
|
70
72
|
|
71
73
|
@classmethod
|
72
74
|
@property
|
@@ -84,25 +86,58 @@ class RBLNDiffusionMixin:
|
|
84
86
|
) -> Dict[str, Any]:
|
85
87
|
submodule = getattr(model, submodule_name)
|
86
88
|
submodule_class_name = submodule.__class__.__name__
|
89
|
+
if isinstance(submodule, torch.nn.Module):
|
90
|
+
if submodule_class_name == "MultiControlNetModel":
|
91
|
+
submodule_class_name = "ControlNetModel"
|
87
92
|
|
88
|
-
|
89
|
-
submodule_class_name = "ControlNetModel"
|
93
|
+
submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
|
90
94
|
|
91
|
-
|
95
|
+
submodule_config = rbln_config.get(submodule_name, {})
|
96
|
+
submodule_config = copy.deepcopy(submodule_config)
|
92
97
|
|
93
|
-
|
94
|
-
submodule_config = copy.deepcopy(submodule_config)
|
98
|
+
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
|
95
99
|
|
96
|
-
|
100
|
+
submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
|
101
|
+
submodule_config.update(
|
102
|
+
{
|
103
|
+
"img2img_pipeline": cls.img2img_pipeline,
|
104
|
+
"inpaint_pipeline": cls.inpaint_pipeline,
|
105
|
+
}
|
106
|
+
)
|
107
|
+
submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
|
108
|
+
elif hasattr(pipelines, submodule_class_name):
|
109
|
+
submodule_config = rbln_config.get(submodule_name, {})
|
110
|
+
submodule_config = copy.deepcopy(submodule_config)
|
111
|
+
|
112
|
+
submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"{submodule_class_name}")
|
113
|
+
prefix = cls._prefix.get(submodule_name, "")
|
114
|
+
connected_submodules = cls._connected_classes.get(submodule_name)._submodules
|
115
|
+
pipe_global_config = {k: v for k, v in submodule_config.items() if k not in connected_submodules}
|
116
|
+
submodule_config = {k: v for k, v in submodule_config.items() if k in connected_submodules}
|
117
|
+
for key in submodule_config.keys():
|
118
|
+
submodule_config[key].update(pipe_global_config)
|
119
|
+
|
120
|
+
for connected_submodule_name in connected_submodules:
|
121
|
+
connected_submodule_config = rbln_config.pop(prefix + connected_submodule_name, {})
|
122
|
+
if connected_submodule_name in submodule_config:
|
123
|
+
submodule_config[connected_submodule_name].update(connected_submodule_config)
|
124
|
+
else:
|
125
|
+
submodule_config[connected_submodule_name] = connected_submodule_config
|
97
126
|
|
98
|
-
|
99
|
-
|
100
|
-
{
|
101
|
-
"img2img_pipeline": cls.img2img_pipeline,
|
102
|
-
"inpaint_pipeline": cls.inpaint_pipeline,
|
127
|
+
pipe_global_config = {
|
128
|
+
k: v for k, v in rbln_config.items() if k != submodule_class_name and not isinstance(v, dict)
|
103
129
|
}
|
104
|
-
|
105
|
-
|
130
|
+
|
131
|
+
for connected_submodule_name in connected_submodules:
|
132
|
+
for k, v in pipe_global_config.items():
|
133
|
+
if "guidance_scale" in k:
|
134
|
+
if prefix + "guidance_scale" == k:
|
135
|
+
submodule_config[connected_submodule_name]["guidance_scale"] = v
|
136
|
+
else:
|
137
|
+
submodule_config[connected_submodule_name][k] = v
|
138
|
+
rbln_config[submodule_name] = submodule_config
|
139
|
+
else:
|
140
|
+
raise ValueError(f"submodule {submodule_name} isn't supported")
|
106
141
|
return submodule_config
|
107
142
|
|
108
143
|
@staticmethod
|
@@ -165,8 +200,26 @@ class RBLNDiffusionMixin:
|
|
165
200
|
|
166
201
|
else:
|
167
202
|
# raise error if any of submodules are torch module.
|
168
|
-
model_index_config =
|
169
|
-
|
203
|
+
model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
|
204
|
+
if cls._load_connected_pipes:
|
205
|
+
submodules = []
|
206
|
+
for submodule in cls._submodules:
|
207
|
+
submodule_config = rbln_config.pop(submodule, {})
|
208
|
+
prefix = cls._prefix.get(submodule, "")
|
209
|
+
connected_submodules = cls._connected_classes.get(submodule)._submodules
|
210
|
+
for connected_submodule_name in connected_submodules:
|
211
|
+
connected_submodule_config = submodule_config.pop(connected_submodule_name, {})
|
212
|
+
if connected_submodule_config:
|
213
|
+
rbln_config[prefix + connected_submodule_name] = connected_submodule_config
|
214
|
+
submodules.append(prefix + connected_submodule_name)
|
215
|
+
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
|
216
|
+
for submodule in submodules:
|
217
|
+
if submodule in rbln_config:
|
218
|
+
rbln_config[submodule].update(pipe_global_config)
|
219
|
+
else:
|
220
|
+
submodules = cls._submodules
|
221
|
+
|
222
|
+
for submodule_name in submodules:
|
170
223
|
if isinstance(kwargs.get(submodule_name), torch.nn.Module):
|
171
224
|
raise AssertionError(
|
172
225
|
f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
|
@@ -181,9 +234,6 @@ class RBLNDiffusionMixin:
|
|
181
234
|
if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
|
182
235
|
continue
|
183
236
|
|
184
|
-
if model_index_config is None:
|
185
|
-
model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
|
186
|
-
|
187
237
|
module_name, class_name = model_index_config[submodule_name]
|
188
238
|
if module_name != "optimum.rbln":
|
189
239
|
raise ValueError(
|
@@ -228,6 +278,7 @@ class RBLNDiffusionMixin:
|
|
228
278
|
passed_submodules: Dict[str, RBLNModel],
|
229
279
|
model_save_dir: Optional[PathLike],
|
230
280
|
rbln_config: Dict[str, Any],
|
281
|
+
prefix: Optional[str] = "",
|
231
282
|
) -> Dict[str, RBLNModel]:
|
232
283
|
compiled_submodules = {}
|
233
284
|
|
@@ -245,17 +296,54 @@ class RBLNDiffusionMixin:
|
|
245
296
|
controlnets=submodule,
|
246
297
|
model_save_dir=model_save_dir,
|
247
298
|
controlnet_rbln_config=submodule_rbln_config,
|
299
|
+
prefix=prefix,
|
248
300
|
)
|
249
301
|
elif isinstance(submodule, torch.nn.Module):
|
250
302
|
submodule_cls: RBLNModel = getattr(
|
251
303
|
importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
|
252
304
|
)
|
305
|
+
subfolder = prefix + submodule_name
|
253
306
|
submodule = submodule_cls.from_model(
|
254
307
|
model=submodule,
|
255
|
-
subfolder=
|
308
|
+
subfolder=subfolder,
|
256
309
|
model_save_dir=model_save_dir,
|
257
310
|
rbln_config=submodule_rbln_config,
|
258
311
|
)
|
312
|
+
elif hasattr(pipelines, submodule.__class__.__name__):
|
313
|
+
connected_pipe = submodule
|
314
|
+
connected_pipe_model_save_dir = model_save_dir
|
315
|
+
connected_pipe_rbln_config = submodule_rbln_config
|
316
|
+
connected_pipe_cls: RBLNDiffusionMixin = getattr(
|
317
|
+
importlib.import_module("optimum.rbln"), connected_pipe.__class__.__name__
|
318
|
+
)
|
319
|
+
submodule_dict = {}
|
320
|
+
for name in connected_pipe.config.keys():
|
321
|
+
if hasattr(connected_pipe, name):
|
322
|
+
submodule_dict[name] = getattr(connected_pipe, name)
|
323
|
+
connected_pipe = connected_pipe_cls(**submodule_dict)
|
324
|
+
connected_pipe_submodules = {}
|
325
|
+
prefix = cls._prefix.get(submodule_name, "")
|
326
|
+
for name in connected_pipe_cls._submodules:
|
327
|
+
if prefix + name in passed_submodules:
|
328
|
+
connected_pipe_submodules[name] = passed_submodules.get(prefix + name)
|
329
|
+
|
330
|
+
connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
|
331
|
+
model=connected_pipe,
|
332
|
+
passed_submodules=connected_pipe_submodules,
|
333
|
+
model_save_dir=model_save_dir,
|
334
|
+
rbln_config=connected_pipe_rbln_config,
|
335
|
+
prefix=prefix,
|
336
|
+
)
|
337
|
+
connected_pipe = connected_pipe_cls._construct_pipe(
|
338
|
+
connected_pipe,
|
339
|
+
connected_pipe_compiled_submodules,
|
340
|
+
connected_pipe_model_save_dir,
|
341
|
+
connected_pipe_rbln_config,
|
342
|
+
)
|
343
|
+
|
344
|
+
for name in connected_pipe_cls._submodules:
|
345
|
+
compiled_submodules[prefix + name] = getattr(connected_pipe, name)
|
346
|
+
submodule = connected_pipe
|
259
347
|
else:
|
260
348
|
raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
|
261
349
|
|
@@ -268,6 +356,7 @@ class RBLNDiffusionMixin:
|
|
268
356
|
controlnets: "MultiControlNetModel",
|
269
357
|
model_save_dir: Optional[PathLike],
|
270
358
|
controlnet_rbln_config: Dict[str, Any],
|
359
|
+
prefix: Optional[str] = "",
|
271
360
|
):
|
272
361
|
# Compile multiple ControlNet models for a MultiControlNet setup
|
273
362
|
from .models.controlnet import RBLNControlNetModel
|
@@ -276,7 +365,7 @@ class RBLNDiffusionMixin:
|
|
276
365
|
compiled_controlnets = [
|
277
366
|
RBLNControlNetModel.from_model(
|
278
367
|
model=controlnet,
|
279
|
-
subfolder="controlnet" if i == 0 else f"controlnet_{i}",
|
368
|
+
subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
|
280
369
|
model_save_dir=model_save_dir,
|
281
370
|
rbln_config=controlnet_rbln_config,
|
282
371
|
)
|
@@ -287,10 +376,21 @@ class RBLNDiffusionMixin:
|
|
287
376
|
@classmethod
|
288
377
|
def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
|
289
378
|
# Construct finalize pipe setup with compiled submodules and configurations
|
379
|
+
submodule_names = []
|
380
|
+
for submodule_name in cls._submodules:
|
381
|
+
submodule = getattr(model, submodule_name)
|
382
|
+
if hasattr(pipelines, submodule.__class__.__name__):
|
383
|
+
prefix = cls._prefix.get(submodule_name, "")
|
384
|
+
connected_pipe_submodules = submodules[submodule_name].__class__._submodules
|
385
|
+
connected_pipe_submodules = [prefix + name for name in connected_pipe_submodules]
|
386
|
+
submodule_names += connected_pipe_submodules
|
387
|
+
setattr(model, submodule_name, submodules[submodule_name])
|
388
|
+
else:
|
389
|
+
submodule_names.append(submodule_name)
|
290
390
|
|
291
391
|
if model_save_dir is not None:
|
292
392
|
# To skip saving original pytorch modules
|
293
|
-
for submodule_name in
|
393
|
+
for submodule_name in submodule_names:
|
294
394
|
delattr(model, submodule_name)
|
295
395
|
|
296
396
|
# Direct calling of `save_pretrained` causes config.unet = (None, None).
|
@@ -300,7 +400,7 @@ class RBLNDiffusionMixin:
|
|
300
400
|
# Causing warning messeages.
|
301
401
|
|
302
402
|
update_dict = {}
|
303
|
-
for submodule_name in
|
403
|
+
for submodule_name in submodule_names:
|
304
404
|
# replace submodule
|
305
405
|
setattr(model, submodule_name, submodules[submodule_name])
|
306
406
|
update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
|
@@ -322,9 +422,16 @@ class RBLNDiffusionMixin:
|
|
322
422
|
if rbln_config.get("optimize_host_memory") is False:
|
323
423
|
# Keep compiled_model objs to further analysis. -> TODO: remove soon...
|
324
424
|
model.compiled_models = []
|
325
|
-
|
326
|
-
|
327
|
-
|
425
|
+
if model._load_connected_pipes:
|
426
|
+
for name in cls._submodules:
|
427
|
+
connected_pipe = getattr(model, name)
|
428
|
+
for submodule_name in connected_pipe.__class__._submodules:
|
429
|
+
submodule = getattr(connected_pipe, submodule_name)
|
430
|
+
model.compiled_models.extend(submodule.compiled_models)
|
431
|
+
else:
|
432
|
+
for name in cls._submodules:
|
433
|
+
submodule = getattr(model, name)
|
434
|
+
model.compiled_models.extend(submodule.compiled_models)
|
328
435
|
|
329
436
|
return model
|
330
437
|
|
@@ -20,20 +20,26 @@ from transformers.utils import _LazyModule
|
|
20
20
|
_import_structure = {
|
21
21
|
"autoencoders": [
|
22
22
|
"RBLNAutoencoderKL",
|
23
|
+
"RBLNVQModel",
|
23
24
|
],
|
24
25
|
"unets": [
|
25
26
|
"RBLNUNet2DConditionModel",
|
26
27
|
],
|
27
28
|
"controlnet": ["RBLNControlNetModel"],
|
28
|
-
"transformers": [
|
29
|
+
"transformers": [
|
30
|
+
"RBLNPriorTransformer",
|
31
|
+
"RBLNSD3Transformer2DModel",
|
32
|
+
],
|
29
33
|
}
|
30
34
|
|
31
35
|
if TYPE_CHECKING:
|
32
36
|
from .autoencoders import (
|
33
37
|
RBLNAutoencoderKL,
|
38
|
+
RBLNVQModel,
|
34
39
|
)
|
35
40
|
from .controlnet import RBLNControlNetModel
|
36
41
|
from .transformers import (
|
42
|
+
RBLNPriorTransformer,
|
37
43
|
RBLNSD3Transformer2DModel,
|
38
44
|
)
|
39
45
|
from .unets import (
|
@@ -12,11 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import TYPE_CHECKING
|
15
|
+
from typing import TYPE_CHECKING, List
|
16
16
|
|
17
17
|
import torch # noqa: I001
|
18
|
-
from diffusers import AutoencoderKL
|
18
|
+
from diffusers import AutoencoderKL, VQModel
|
19
19
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
20
|
+
from diffusers.models.autoencoders.vq_model import VQEncoderOutput
|
20
21
|
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
21
22
|
|
22
23
|
from ....utils.logging import get_logger
|
@@ -72,3 +73,52 @@ class _VAEEncoder(torch.nn.Module):
|
|
72
73
|
def forward(self, x):
|
73
74
|
vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
|
74
75
|
return vae_out
|
76
|
+
|
77
|
+
|
78
|
+
class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
|
79
|
+
def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
80
|
+
h = self.forward(x.contiguous())
|
81
|
+
return VQEncoderOutput(latents=h)
|
82
|
+
|
83
|
+
|
84
|
+
class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
|
85
|
+
def decode(self, h: torch.Tensor, force_not_quantize: bool = False, shape=None, **kwargs) -> List[torch.Tensor]:
|
86
|
+
if not (force_not_quantize and not self.lookup_from_codebook):
|
87
|
+
raise ValueError(
|
88
|
+
"Currently, the `decode` method of the class `RBLNVQModel` is executed successfully only if `force_not_quantize` is True and `config.lookup_from_codebook` is False"
|
89
|
+
)
|
90
|
+
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
|
91
|
+
dec = self.forward(h.contiguous())
|
92
|
+
return dec, commit_loss
|
93
|
+
|
94
|
+
|
95
|
+
class _VQEncoder(torch.nn.Module):
|
96
|
+
def __init__(self, vq_model: VQModel):
|
97
|
+
super().__init__()
|
98
|
+
self.vq_model = vq_model
|
99
|
+
|
100
|
+
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
101
|
+
h = self.vq_model.encoder(x)
|
102
|
+
h = self.vq_model.quant_conv(h)
|
103
|
+
return h
|
104
|
+
|
105
|
+
def forward(self, x: torch.Tensor):
|
106
|
+
vq_out = self.encode(x)
|
107
|
+
return vq_out
|
108
|
+
|
109
|
+
|
110
|
+
class _VQDecoder(torch.nn.Module):
|
111
|
+
def __init__(self, vq_model: VQModel):
|
112
|
+
super().__init__()
|
113
|
+
self.vq_model = vq_model
|
114
|
+
|
115
|
+
def decode(self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None):
|
116
|
+
quant = h
|
117
|
+
quant2 = self.vq_model.post_quant_conv(quant)
|
118
|
+
quant = quant if self.vq_model.config.norm_type == "spatial" else None
|
119
|
+
dec = self.vq_model.decoder(quant2, quant)
|
120
|
+
return dec
|
121
|
+
|
122
|
+
def forward(self, h: torch.Tensor):
|
123
|
+
vq_out = self.decode(h)
|
124
|
+
return vq_out
|
@@ -0,0 +1,159 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
16
|
+
|
17
|
+
import rebel
|
18
|
+
import torch
|
19
|
+
from diffusers import VQModel
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
21
|
+
from diffusers.models.autoencoders.vq_model import VQEncoderOutput
|
22
|
+
from transformers import PretrainedConfig
|
23
|
+
|
24
|
+
from ....modeling import RBLNModel
|
25
|
+
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
26
|
+
from ....utils.logging import get_logger
|
27
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
28
|
+
from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
|
29
|
+
|
30
|
+
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
33
|
+
|
34
|
+
logger = get_logger(__name__)
|
35
|
+
|
36
|
+
|
37
|
+
class RBLNVQModel(RBLNModel):
|
38
|
+
auto_model_class = VQModel
|
39
|
+
config_name = "config.json"
|
40
|
+
hf_library_name = "diffusers"
|
41
|
+
|
42
|
+
def __post_init__(self, **kwargs):
|
43
|
+
super().__post_init__(**kwargs)
|
44
|
+
|
45
|
+
self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
|
46
|
+
self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[1], main_input_name="z")
|
47
|
+
self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
|
48
|
+
height = self.rbln_config.model_cfg.get("img_height", 512)
|
49
|
+
width = self.rbln_config.model_cfg.get("img_width", 512)
|
50
|
+
self.image_size = [height, width]
|
51
|
+
|
52
|
+
@classmethod
|
53
|
+
def get_compiled_model(cls, model, rbln_config: RBLNConfig):
|
54
|
+
encoder_model = _VQEncoder(model)
|
55
|
+
decoder_model = _VQDecoder(model)
|
56
|
+
encoder_model.eval()
|
57
|
+
decoder_model.eval()
|
58
|
+
|
59
|
+
enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
|
60
|
+
dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
|
61
|
+
|
62
|
+
return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
|
63
|
+
|
64
|
+
@classmethod
|
65
|
+
def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
|
66
|
+
batch_size = rbln_config.get("batch_size")
|
67
|
+
if batch_size is None:
|
68
|
+
batch_size = 1
|
69
|
+
img_height = rbln_config.get("img_height")
|
70
|
+
if img_height is None:
|
71
|
+
img_height = 512
|
72
|
+
img_width = rbln_config.get("img_width")
|
73
|
+
if img_width is None:
|
74
|
+
img_width = 512
|
75
|
+
|
76
|
+
rbln_config.update(
|
77
|
+
{
|
78
|
+
"batch_size": batch_size,
|
79
|
+
"img_height": img_height,
|
80
|
+
"img_width": img_width,
|
81
|
+
}
|
82
|
+
)
|
83
|
+
|
84
|
+
return rbln_config
|
85
|
+
|
86
|
+
@classmethod
|
87
|
+
def _get_rbln_config(
|
88
|
+
cls,
|
89
|
+
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
90
|
+
model_config: "PretrainedConfig",
|
91
|
+
rbln_kwargs: Dict[str, Any] = {},
|
92
|
+
) -> RBLNConfig:
|
93
|
+
batch_size = rbln_kwargs.get("batch_size") or 1
|
94
|
+
height = rbln_kwargs.get("img_height") or 512
|
95
|
+
width = rbln_kwargs.get("img_width") or 512
|
96
|
+
|
97
|
+
if hasattr(model_config, "block_out_channels"):
|
98
|
+
scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
|
99
|
+
else:
|
100
|
+
# image processor default value 8 (int)
|
101
|
+
scale_factor = 8
|
102
|
+
|
103
|
+
enc_shape = (height, width)
|
104
|
+
dec_shape = (height // scale_factor, width // scale_factor)
|
105
|
+
|
106
|
+
enc_input_info = [
|
107
|
+
(
|
108
|
+
"x",
|
109
|
+
[batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
|
110
|
+
"float32",
|
111
|
+
)
|
112
|
+
]
|
113
|
+
dec_input_info = [
|
114
|
+
(
|
115
|
+
"h",
|
116
|
+
[batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
|
117
|
+
"float32",
|
118
|
+
)
|
119
|
+
]
|
120
|
+
|
121
|
+
enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
122
|
+
dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
123
|
+
|
124
|
+
compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
|
125
|
+
rbln_config = RBLNConfig(
|
126
|
+
rbln_cls=cls.__name__,
|
127
|
+
compile_cfgs=compile_cfgs,
|
128
|
+
rbln_kwargs=rbln_kwargs,
|
129
|
+
)
|
130
|
+
return rbln_config
|
131
|
+
|
132
|
+
@classmethod
|
133
|
+
def _create_runtimes(
|
134
|
+
cls,
|
135
|
+
compiled_models: List[rebel.RBLNCompiledModel],
|
136
|
+
rbln_device_map: Dict[str, int],
|
137
|
+
activate_profiler: Optional[bool] = None,
|
138
|
+
) -> List[rebel.Runtime]:
|
139
|
+
if len(compiled_models) == 1:
|
140
|
+
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
141
|
+
return [
|
142
|
+
compiled_models[0].create_runtime(
|
143
|
+
tensor_type="pt", device=device_val, activate_profiler=activate_profiler
|
144
|
+
)
|
145
|
+
]
|
146
|
+
|
147
|
+
device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
|
148
|
+
return [
|
149
|
+
compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
|
150
|
+
for compiled_model, device_val in zip(compiled_models, device_vals)
|
151
|
+
]
|
152
|
+
|
153
|
+
def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
154
|
+
posterior = self.encoder.encode(x)
|
155
|
+
return VQEncoderOutput(latents=posterior)
|
156
|
+
|
157
|
+
def decode(self, h: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
158
|
+
dec, commit_loss = self.decoder.decode(h, **kwargs)
|
159
|
+
return DecoderOutput(sample=dec, commit_loss=commit_loss)
|