optimum-rbln 0.2.1a4__py3-none-any.whl → 0.7.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -2
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +10 -0
- optimum/rbln/diffusers/modeling_diffusers.py +115 -23
- optimum/rbln/diffusers/models/__init__.py +7 -1
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +52 -2
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +159 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +174 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +57 -14
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +83 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +22 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +22 -0
- optimum/rbln/modeling_base.py +10 -9
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +12 -2
- optimum/rbln/transformers/models/clip/__init__.py +6 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +26 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/RECORD +27 -21
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -48,6 +48,7 @@ _import_structure = {
|
|
48
48
|
"RBLNCLIPTextModel",
|
49
49
|
"RBLNCLIPTextModelWithProjection",
|
50
50
|
"RBLNCLIPVisionModel",
|
51
|
+
"RBLNCLIPVisionModelWithProjection",
|
51
52
|
"RBLNDPTForDepthEstimation",
|
52
53
|
"RBLNExaoneForCausalLM",
|
53
54
|
"RBLNGemmaForCausalLM",
|
@@ -74,11 +75,15 @@ _import_structure = {
|
|
74
75
|
"RBLNBertForMaskedLM",
|
75
76
|
],
|
76
77
|
"diffusers": [
|
78
|
+
"RBLNAutoencoderKL",
|
79
|
+
"RBLNControlNetModel",
|
80
|
+
"RBLNPriorTransformer",
|
81
|
+
"RBLNKandinskyV22InpaintCombinedPipeline",
|
82
|
+
"RBLNKandinskyV22InpaintPipeline",
|
83
|
+
"RBLNKandinskyV22PriorPipeline",
|
77
84
|
"RBLNStableDiffusionPipeline",
|
78
85
|
"RBLNStableDiffusionXLPipeline",
|
79
|
-
"RBLNAutoencoderKL",
|
80
86
|
"RBLNUNet2DConditionModel",
|
81
|
-
"RBLNControlNetModel",
|
82
87
|
"RBLNStableDiffusionImg2ImgPipeline",
|
83
88
|
"RBLNStableDiffusionInpaintPipeline",
|
84
89
|
"RBLNStableDiffusionControlNetImg2ImgPipeline",
|
@@ -88,6 +93,7 @@ _import_structure = {
|
|
88
93
|
"RBLNStableDiffusionControlNetPipeline",
|
89
94
|
"RBLNStableDiffusionXLControlNetPipeline",
|
90
95
|
"RBLNStableDiffusionXLControlNetImg2ImgPipeline",
|
96
|
+
"RBLNVQModel",
|
91
97
|
"RBLNSD3Transformer2DModel",
|
92
98
|
"RBLNStableDiffusion3Img2ImgPipeline",
|
93
99
|
"RBLNStableDiffusion3InpaintPipeline",
|
@@ -101,7 +107,11 @@ if TYPE_CHECKING:
|
|
101
107
|
RBLNAutoencoderKL,
|
102
108
|
RBLNControlNetModel,
|
103
109
|
RBLNDiffusionMixin,
|
110
|
+
RBLNKandinskyV22InpaintCombinedPipeline,
|
111
|
+
RBLNKandinskyV22InpaintPipeline,
|
112
|
+
RBLNKandinskyV22PriorPipeline,
|
104
113
|
RBLNMultiControlNetModel,
|
114
|
+
RBLNPriorTransformer,
|
105
115
|
RBLNSD3Transformer2DModel,
|
106
116
|
RBLNStableDiffusion3Img2ImgPipeline,
|
107
117
|
RBLNStableDiffusion3InpaintPipeline,
|
@@ -117,6 +127,7 @@ if TYPE_CHECKING:
|
|
117
127
|
RBLNStableDiffusionXLInpaintPipeline,
|
118
128
|
RBLNStableDiffusionXLPipeline,
|
119
129
|
RBLNUNet2DConditionModel,
|
130
|
+
RBLNVQModel,
|
120
131
|
)
|
121
132
|
from .modeling import (
|
122
133
|
RBLNBaseModel,
|
@@ -148,6 +159,7 @@ if TYPE_CHECKING:
|
|
148
159
|
RBLNCLIPTextModel,
|
149
160
|
RBLNCLIPTextModelWithProjection,
|
150
161
|
RBLNCLIPVisionModel,
|
162
|
+
RBLNCLIPVisionModelWithProjection,
|
151
163
|
RBLNDistilBertForQuestionAnswering,
|
152
164
|
RBLNDPTForDepthEstimation,
|
153
165
|
RBLNExaoneForCausalLM,
|
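optimum/rbln/__version__.py
CHANGED
[hunk for this file (+2 -2) not captured in this extract]
optimum/rbln/diffusers/__init__.py
CHANGED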
@@ -24,6 +24,9 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
|
|
24
24
|
|
25
25
|
_import_structure = {
|
26
26
|
"pipelines": [
|
27
|
+
"RBLNKandinskyV22InpaintCombinedPipeline",
|
28
|
+
"RBLNKandinskyV22InpaintPipeline",
|
29
|
+
"RBLNKandinskyV22PriorPipeline",
|
27
30
|
"RBLNStableDiffusionPipeline",
|
28
31
|
"RBLNStableDiffusionXLPipeline",
|
29
32
|
"RBLNStableDiffusionImg2ImgPipeline",
|
@@ -44,6 +47,8 @@ _import_structure = {
|
|
44
47
|
"RBLNUNet2DConditionModel",
|
45
48
|
"RBLNControlNetModel",
|
46
49
|
"RBLNSD3Transformer2DModel",
|
50
|
+
"RBLNPriorTransformer",
|
51
|
+
"RBLNVQModel",
|
47
52
|
],
|
48
53
|
"modeling_diffusers": [
|
49
54
|
"RBLNDiffusionMixin",
|
@@ -55,10 +60,15 @@ if TYPE_CHECKING:
|
|
55
60
|
from .models import (
|
56
61
|
RBLNAutoencoderKL,
|
57
62
|
RBLNControlNetModel,
|
63
|
+
RBLNPriorTransformer,
|
58
64
|
RBLNSD3Transformer2DModel,
|
59
65
|
RBLNUNet2DConditionModel,
|
66
|
+
RBLNVQModel,
|
60
67
|
)
|
61
68
|
from .pipelines import (
|
69
|
+
RBLNKandinskyV22InpaintCombinedPipeline,
|
70
|
+
RBLNKandinskyV22InpaintPipeline,
|
71
|
+
RBLNKandinskyV22PriorPipeline,
|
62
72
|
RBLNMultiControlNetModel,
|
63
73
|
RBLNStableDiffusion3Img2ImgPipeline,
|
64
74
|
RBLNStableDiffusion3InpaintPipeline,
|
optimum/rbln/diffusers/modeling_diffusers.py
CHANGED
@@ -23,6 +23,7 @@ from ..modeling import RBLNModel
|
|
23
23
|
from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
|
24
24
|
from ..utils.decorator_utils import remove_compile_time_kwargs
|
25
25
|
from ..utils.logging import get_logger
|
26
|
+
from . import pipelines
|
26
27
|
|
27
28
|
|
28
29
|
logger = get_logger(__name__)
|
@@ -67,6 +68,7 @@ class RBLNDiffusionMixin:
|
|
67
68
|
"""
|
68
69
|
|
69
70
|
_submodules = []
|
71
|
+
_prefix = {}
|
70
72
|
|
71
73
|
@classmethod
|
72
74
|
@property
|
@@ -84,25 +86,50 @@ class RBLNDiffusionMixin:
|
|
84
86
|
) -> Dict[str, Any]:
|
85
87
|
submodule = getattr(model, submodule_name)
|
86
88
|
submodule_class_name = submodule.__class__.__name__
|
89
|
+
if isinstance(submodule, torch.nn.Module):
|
90
|
+
if submodule_class_name == "MultiControlNetModel":
|
91
|
+
submodule_class_name = "ControlNetModel"
|
87
92
|
|
88
|
-
|
89
|
-
submodule_class_name = "ControlNetModel"
|
93
|
+
submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
|
90
94
|
|
91
|
-
|
95
|
+
submodule_config = rbln_config.get(submodule_name, {})
|
96
|
+
submodule_config = copy.deepcopy(submodule_config)
|
92
97
|
|
93
|
-
|
94
|
-
submodule_config = copy.deepcopy(submodule_config)
|
98
|
+
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
|
95
99
|
|
96
|
-
|
100
|
+
submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
|
101
|
+
submodule_config.update(
|
102
|
+
{
|
103
|
+
"img2img_pipeline": cls.img2img_pipeline,
|
104
|
+
"inpaint_pipeline": cls.inpaint_pipeline,
|
105
|
+
}
|
106
|
+
)
|
107
|
+
submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
|
108
|
+
elif hasattr(pipelines, submodule_class_name):
|
109
|
+
submodule_config = rbln_config.get(submodule_name, {})
|
110
|
+
submodule_config = copy.deepcopy(submodule_config)
|
111
|
+
|
112
|
+
submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"{submodule_class_name}")
|
113
|
+
prefix = cls._prefix.get(submodule_name, "")
|
114
|
+
connected_submodules = cls._connected_classes.get(submodule_name)._submodules
|
115
|
+
for connected_submodule_name in connected_submodules:
|
116
|
+
connected_submodule_config = rbln_config.pop(prefix + connected_submodule_name, {})
|
117
|
+
if connected_submodule_name in submodule_config:
|
118
|
+
submodule_config[connected_submodule_name].update(connected_submodule_config)
|
119
|
+
else:
|
120
|
+
submodule_config[connected_submodule_name] = connected_submodule_config
|
97
121
|
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
122
|
+
submodules = copy.deepcopy(cls._submodules)
|
123
|
+
submodules += [prefix + connected_submodule_name for connected_submodule_name in connected_submodules]
|
124
|
+
|
125
|
+
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
|
126
|
+
for connected_submodule_name in connected_submodules:
|
127
|
+
submodule_config[connected_submodule_name].update(
|
128
|
+
{k: v for k, v in pipe_global_config.items() if k not in submodule_config}
|
129
|
+
)
|
130
|
+
rbln_config[submodule_name] = submodule_config
|
131
|
+
else:
|
132
|
+
raise ValueError(f"submodule {submodule_name} isn't supported")
|
106
133
|
return submodule_config
|
107
134
|
|
108
135
|
@staticmethod
|
@@ -165,8 +192,26 @@ class RBLNDiffusionMixin:
|
|
165
192
|
|
166
193
|
else:
|
167
194
|
# raise error if any of submodules are torch module.
|
168
|
-
model_index_config =
|
169
|
-
|
195
|
+
model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
|
196
|
+
if cls._load_connected_pipes:
|
197
|
+
submodules = []
|
198
|
+
for submodule in cls._submodules:
|
199
|
+
submodule_config = rbln_config.pop(submodule, {})
|
200
|
+
prefix = cls._prefix.get(submodule, "")
|
201
|
+
connected_submodules = cls._connected_classes.get(submodule)._submodules
|
202
|
+
for connected_submodule_name in connected_submodules:
|
203
|
+
connected_submodule_config = submodule_config.pop(connected_submodule_name, {})
|
204
|
+
if connected_submodule_config:
|
205
|
+
rbln_config[prefix + connected_submodule_name] = connected_submodule_config
|
206
|
+
submodules.append(prefix + connected_submodule_name)
|
207
|
+
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
|
208
|
+
for submodule in submodules:
|
209
|
+
if submodule in rbln_config:
|
210
|
+
rbln_config[submodule].update(pipe_global_config)
|
211
|
+
else:
|
212
|
+
submodules = cls._submodules
|
213
|
+
|
214
|
+
for submodule_name in submodules:
|
170
215
|
if isinstance(kwargs.get(submodule_name), torch.nn.Module):
|
171
216
|
raise AssertionError(
|
172
217
|
f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
|
@@ -181,9 +226,6 @@ class RBLNDiffusionMixin:
|
|
181
226
|
if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
|
182
227
|
continue
|
183
228
|
|
184
|
-
if model_index_config is None:
|
185
|
-
model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
|
186
|
-
|
187
229
|
module_name, class_name = model_index_config[submodule_name]
|
188
230
|
if module_name != "optimum.rbln":
|
189
231
|
raise ValueError(
|
@@ -228,6 +270,7 @@ class RBLNDiffusionMixin:
|
|
228
270
|
passed_submodules: Dict[str, RBLNModel],
|
229
271
|
model_save_dir: Optional[PathLike],
|
230
272
|
rbln_config: Dict[str, Any],
|
273
|
+
prefix: Optional[str] = "",
|
231
274
|
) -> Dict[str, RBLNModel]:
|
232
275
|
compiled_submodules = {}
|
233
276
|
|
@@ -245,17 +288,54 @@ class RBLNDiffusionMixin:
|
|
245
288
|
controlnets=submodule,
|
246
289
|
model_save_dir=model_save_dir,
|
247
290
|
controlnet_rbln_config=submodule_rbln_config,
|
291
|
+
prefix=prefix,
|
248
292
|
)
|
249
293
|
elif isinstance(submodule, torch.nn.Module):
|
250
294
|
submodule_cls: RBLNModel = getattr(
|
251
295
|
importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
|
252
296
|
)
|
297
|
+
subfolder = prefix + submodule_name
|
253
298
|
submodule = submodule_cls.from_model(
|
254
299
|
model=submodule,
|
255
|
-
subfolder=
|
300
|
+
subfolder=subfolder,
|
256
301
|
model_save_dir=model_save_dir,
|
257
302
|
rbln_config=submodule_rbln_config,
|
258
303
|
)
|
304
|
+
elif hasattr(pipelines, submodule.__class__.__name__):
|
305
|
+
connected_pipe = submodule
|
306
|
+
connected_pipe_model_save_dir = model_save_dir
|
307
|
+
connected_pipe_rbln_config = submodule_rbln_config
|
308
|
+
connected_pipe_cls: RBLNDiffusionMixin = getattr(
|
309
|
+
importlib.import_module("optimum.rbln"), connected_pipe.__class__.__name__
|
310
|
+
)
|
311
|
+
submodule_dict = {}
|
312
|
+
for name in connected_pipe.config.keys():
|
313
|
+
if hasattr(connected_pipe, name):
|
314
|
+
submodule_dict[name] = getattr(connected_pipe, name)
|
315
|
+
connected_pipe = connected_pipe_cls(**submodule_dict)
|
316
|
+
connected_pipe_submodules = {}
|
317
|
+
prefix = cls._prefix.get(submodule_name, "")
|
318
|
+
for name in connected_pipe_cls._submodules:
|
319
|
+
if prefix + name in passed_submodules:
|
320
|
+
connected_pipe_submodules[name] = passed_submodules.get(prefix + name)
|
321
|
+
|
322
|
+
connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
|
323
|
+
model=connected_pipe,
|
324
|
+
passed_submodules=connected_pipe_submodules,
|
325
|
+
model_save_dir=model_save_dir,
|
326
|
+
rbln_config=connected_pipe_rbln_config,
|
327
|
+
prefix=prefix,
|
328
|
+
)
|
329
|
+
connected_pipe = connected_pipe_cls._construct_pipe(
|
330
|
+
connected_pipe,
|
331
|
+
connected_pipe_compiled_submodules,
|
332
|
+
connected_pipe_model_save_dir,
|
333
|
+
connected_pipe_rbln_config,
|
334
|
+
)
|
335
|
+
|
336
|
+
for name in connected_pipe_cls._submodules:
|
337
|
+
compiled_submodules[prefix + name] = getattr(connected_pipe, name)
|
338
|
+
submodule = connected_pipe
|
259
339
|
else:
|
260
340
|
raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
|
261
341
|
|
@@ -268,6 +348,7 @@ class RBLNDiffusionMixin:
|
|
268
348
|
controlnets: "MultiControlNetModel",
|
269
349
|
model_save_dir: Optional[PathLike],
|
270
350
|
controlnet_rbln_config: Dict[str, Any],
|
351
|
+
prefix: Optional[str] = "",
|
271
352
|
):
|
272
353
|
# Compile multiple ControlNet models for a MultiControlNet setup
|
273
354
|
from .models.controlnet import RBLNControlNetModel
|
@@ -276,7 +357,7 @@ class RBLNDiffusionMixin:
|
|
276
357
|
compiled_controlnets = [
|
277
358
|
RBLNControlNetModel.from_model(
|
278
359
|
model=controlnet,
|
279
|
-
subfolder="controlnet" if i == 0 else f"controlnet_{i}",
|
360
|
+
subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
|
280
361
|
model_save_dir=model_save_dir,
|
281
362
|
rbln_config=controlnet_rbln_config,
|
282
363
|
)
|
@@ -287,10 +368,21 @@ class RBLNDiffusionMixin:
|
|
287
368
|
@classmethod
|
288
369
|
def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
|
289
370
|
# Construct finalize pipe setup with compiled submodules and configurations
|
371
|
+
submodule_names = []
|
372
|
+
for submodule_name in cls._submodules:
|
373
|
+
submodule = getattr(model, submodule_name)
|
374
|
+
if hasattr(pipelines, submodule.__class__.__name__):
|
375
|
+
prefix = cls._prefix.get(submodule_name, "")
|
376
|
+
connected_pipe_submodules = submodules[submodule_name].__class__._submodules
|
377
|
+
connected_pipe_submodules = [prefix + name for name in connected_pipe_submodules]
|
378
|
+
submodule_names += connected_pipe_submodules
|
379
|
+
setattr(model, submodule_name, submodules[submodule_name])
|
380
|
+
else:
|
381
|
+
submodule_names.append(submodule_name)
|
290
382
|
|
291
383
|
if model_save_dir is not None:
|
292
384
|
# To skip saving original pytorch modules
|
293
|
-
for submodule_name in
|
385
|
+
for submodule_name in submodule_names:
|
294
386
|
delattr(model, submodule_name)
|
295
387
|
|
296
388
|
# Direct calling of `save_pretrained` causes config.unet = (None, None).
|
@@ -300,7 +392,7 @@ class RBLNDiffusionMixin:
|
|
300
392
|
# Causing warning messeages.
|
301
393
|
|
302
394
|
update_dict = {}
|
303
|
-
for submodule_name in
|
395
|
+
for submodule_name in submodule_names:
|
304
396
|
# replace submodule
|
305
397
|
setattr(model, submodule_name, submodules[submodule_name])
|
306
398
|
update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
|
optimum/rbln/diffusers/models/__init__.py
CHANGED
@@ -20,20 +20,26 @@ from transformers.utils import _LazyModule
|
|
20
20
|
_import_structure = {
|
21
21
|
"autoencoders": [
|
22
22
|
"RBLNAutoencoderKL",
|
23
|
+
"RBLNVQModel",
|
23
24
|
],
|
24
25
|
"unets": [
|
25
26
|
"RBLNUNet2DConditionModel",
|
26
27
|
],
|
27
28
|
"controlnet": ["RBLNControlNetModel"],
|
28
|
-
"transformers": [
|
29
|
+
"transformers": [
|
30
|
+
"RBLNPriorTransformer",
|
31
|
+
"RBLNSD3Transformer2DModel",
|
32
|
+
],
|
29
33
|
}
|
30
34
|
|
31
35
|
if TYPE_CHECKING:
|
32
36
|
from .autoencoders import (
|
33
37
|
RBLNAutoencoderKL,
|
38
|
+
RBLNVQModel,
|
34
39
|
)
|
35
40
|
from .controlnet import RBLNControlNetModel
|
36
41
|
from .transformers import (
|
42
|
+
RBLNPriorTransformer,
|
37
43
|
RBLNSD3Transformer2DModel,
|
38
44
|
)
|
39
45
|
from .unets import (
|
optimum/rbln/diffusers/models/autoencoders/vae.py
CHANGED
@@ -12,11 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import TYPE_CHECKING
|
15
|
+
from typing import TYPE_CHECKING, List
|
16
16
|
|
17
17
|
import torch # noqa: I001
|
18
|
-
from diffusers import AutoencoderKL
|
18
|
+
from diffusers import AutoencoderKL, VQModel
|
19
19
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
20
|
+
from diffusers.models.autoencoders.vq_model import VQEncoderOutput
|
20
21
|
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
21
22
|
|
22
23
|
from ....utils.logging import get_logger
|
@@ -72,3 +73,52 @@ class _VAEEncoder(torch.nn.Module):
|
|
72
73
|
def forward(self, x):
|
73
74
|
vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
|
74
75
|
return vae_out
|
76
|
+
|
77
|
+
|
78
|
+
class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
|
79
|
+
def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
80
|
+
h = self.forward(x.contiguous())
|
81
|
+
return VQEncoderOutput(latents=h)
|
82
|
+
|
83
|
+
|
84
|
+
class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
|
85
|
+
def decode(self, h: torch.Tensor, force_not_quantize: bool = False, shape=None, **kwargs) -> List[torch.Tensor]:
|
86
|
+
if not (force_not_quantize and not self.lookup_from_codebook):
|
87
|
+
raise ValueError(
|
88
|
+
"Currently, the `decode` method of the class `RBLNVQModel` is executed successfully only if `force_not_quantize` is True and `config.lookup_from_codebook` is False"
|
89
|
+
)
|
90
|
+
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
|
91
|
+
dec = self.forward(h.contiguous())
|
92
|
+
return dec, commit_loss
|
93
|
+
|
94
|
+
|
95
|
+
class _VQEncoder(torch.nn.Module):
|
96
|
+
def __init__(self, vq_model: VQModel):
|
97
|
+
super().__init__()
|
98
|
+
self.vq_model = vq_model
|
99
|
+
|
100
|
+
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
101
|
+
h = self.vq_model.encoder(x)
|
102
|
+
h = self.vq_model.quant_conv(h)
|
103
|
+
return h
|
104
|
+
|
105
|
+
def forward(self, x: torch.Tensor):
|
106
|
+
vq_out = self.encode(x)
|
107
|
+
return vq_out
|
108
|
+
|
109
|
+
|
110
|
+
class _VQDecoder(torch.nn.Module):
|
111
|
+
def __init__(self, vq_model: VQModel):
|
112
|
+
super().__init__()
|
113
|
+
self.vq_model = vq_model
|
114
|
+
|
115
|
+
def decode(self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None):
|
116
|
+
quant = h
|
117
|
+
quant2 = self.vq_model.post_quant_conv(quant)
|
118
|
+
quant = quant if self.vq_model.config.norm_type == "spatial" else None
|
119
|
+
dec = self.vq_model.decoder(quant2, quant)
|
120
|
+
return dec
|
121
|
+
|
122
|
+
def forward(self, h: torch.Tensor):
|
123
|
+
vq_out = self.decode(h)
|
124
|
+
return vq_out
|
optimum/rbln/diffusers/models/autoencoders/vq_model.py
ADDED
@@ -0,0 +1,159 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
16
|
+
|
17
|
+
import rebel
|
18
|
+
import torch
|
19
|
+
from diffusers import VQModel
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
21
|
+
from diffusers.models.autoencoders.vq_model import VQEncoderOutput
|
22
|
+
from transformers import PretrainedConfig
|
23
|
+
|
24
|
+
from ....modeling import RBLNModel
|
25
|
+
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
26
|
+
from ....utils.logging import get_logger
|
27
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
28
|
+
from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
|
29
|
+
|
30
|
+
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
33
|
+
|
34
|
+
logger = get_logger(__name__)
|
35
|
+
|
36
|
+
|
37
|
+
class RBLNVQModel(RBLNModel):
|
38
|
+
auto_model_class = VQModel
|
39
|
+
config_name = "config.json"
|
40
|
+
hf_library_name = "diffusers"
|
41
|
+
|
42
|
+
def __post_init__(self, **kwargs):
|
43
|
+
super().__post_init__(**kwargs)
|
44
|
+
|
45
|
+
self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
|
46
|
+
self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[1], main_input_name="z")
|
47
|
+
self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
|
48
|
+
height = self.rbln_config.model_cfg.get("img_height", 512)
|
49
|
+
width = self.rbln_config.model_cfg.get("img_width", 512)
|
50
|
+
self.image_size = [height, width]
|
51
|
+
|
52
|
+
@classmethod
|
53
|
+
def get_compiled_model(cls, model, rbln_config: RBLNConfig):
|
54
|
+
encoder_model = _VQEncoder(model)
|
55
|
+
decoder_model = _VQDecoder(model)
|
56
|
+
encoder_model.eval()
|
57
|
+
decoder_model.eval()
|
58
|
+
|
59
|
+
enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
|
60
|
+
dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
|
61
|
+
|
62
|
+
return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
|
63
|
+
|
64
|
+
@classmethod
|
65
|
+
def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
|
66
|
+
batch_size = rbln_config.get("batch_size")
|
67
|
+
if batch_size is None:
|
68
|
+
batch_size = 1
|
69
|
+
img_height = rbln_config.get("img_height")
|
70
|
+
if img_height is None:
|
71
|
+
img_height = 512
|
72
|
+
img_width = rbln_config.get("img_width")
|
73
|
+
if img_width is None:
|
74
|
+
img_width = 512
|
75
|
+
|
76
|
+
rbln_config.update(
|
77
|
+
{
|
78
|
+
"batch_size": batch_size,
|
79
|
+
"img_height": img_height,
|
80
|
+
"img_width": img_width,
|
81
|
+
}
|
82
|
+
)
|
83
|
+
|
84
|
+
return rbln_config
|
85
|
+
|
86
|
+
@classmethod
|
87
|
+
def _get_rbln_config(
|
88
|
+
cls,
|
89
|
+
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
90
|
+
model_config: "PretrainedConfig",
|
91
|
+
rbln_kwargs: Dict[str, Any] = {},
|
92
|
+
) -> RBLNConfig:
|
93
|
+
batch_size = rbln_kwargs.get("batch_size") or 1
|
94
|
+
height = rbln_kwargs.get("img_height") or 512
|
95
|
+
width = rbln_kwargs.get("img_width") or 512
|
96
|
+
|
97
|
+
if hasattr(model_config, "block_out_channels"):
|
98
|
+
scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
|
99
|
+
else:
|
100
|
+
# image processor default value 8 (int)
|
101
|
+
scale_factor = 8
|
102
|
+
|
103
|
+
enc_shape = (height, width)
|
104
|
+
dec_shape = (height // scale_factor, width // scale_factor)
|
105
|
+
|
106
|
+
enc_input_info = [
|
107
|
+
(
|
108
|
+
"x",
|
109
|
+
[batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
|
110
|
+
"float32",
|
111
|
+
)
|
112
|
+
]
|
113
|
+
dec_input_info = [
|
114
|
+
(
|
115
|
+
"h",
|
116
|
+
[batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
|
117
|
+
"float32",
|
118
|
+
)
|
119
|
+
]
|
120
|
+
|
121
|
+
enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
122
|
+
dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
123
|
+
|
124
|
+
compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
|
125
|
+
rbln_config = RBLNConfig(
|
126
|
+
rbln_cls=cls.__name__,
|
127
|
+
compile_cfgs=compile_cfgs,
|
128
|
+
rbln_kwargs=rbln_kwargs,
|
129
|
+
)
|
130
|
+
return rbln_config
|
131
|
+
|
132
|
+
@classmethod
|
133
|
+
def _create_runtimes(
|
134
|
+
cls,
|
135
|
+
compiled_models: List[rebel.RBLNCompiledModel],
|
136
|
+
rbln_device_map: Dict[str, int],
|
137
|
+
activate_profiler: Optional[bool] = None,
|
138
|
+
) -> List[rebel.Runtime]:
|
139
|
+
if len(compiled_models) == 1:
|
140
|
+
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
141
|
+
return [
|
142
|
+
compiled_models[0].create_runtime(
|
143
|
+
tensor_type="pt", device=device_val, activate_profiler=activate_profiler
|
144
|
+
)
|
145
|
+
]
|
146
|
+
|
147
|
+
device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
|
148
|
+
return [
|
149
|
+
compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
|
150
|
+
for compiled_model, device_val in zip(compiled_models, device_vals)
|
151
|
+
]
|
152
|
+
|
153
|
+
def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
154
|
+
posterior = self.encoder.encode(x)
|
155
|
+
return VQEncoderOutput(latents=posterior)
|
156
|
+
|
157
|
+
def decode(self, h: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
158
|
+
dec, commit_loss = self.decoder.decode(h, **kwargs)
|
159
|
+
return DecoderOutput(sample=dec, commit_loss=commit_loss)
|