optimum-rbln 0.7.2rc1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +9 -4
- optimum/rbln/diffusers/__init__.py +8 -0
- optimum/rbln/diffusers/modeling_diffusers.py +103 -109
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
- optimum/rbln/diffusers/pipelines/__init__.py +8 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
- optimum/rbln/modeling.py +4 -1
- optimum/rbln/modeling_base.py +16 -3
- optimum/rbln/ops/__init__.py +6 -2
- optimum/rbln/ops/attn.py +94 -85
- optimum/rbln/ops/flash_attn.py +46 -25
- optimum/rbln/ops/kv_cache_update.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
- optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
- optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -78,9 +78,13 @@ _import_structure = {
|
|
78
78
|
"RBLNAutoencoderKL",
|
79
79
|
"RBLNControlNetModel",
|
80
80
|
"RBLNPriorTransformer",
|
81
|
+
"RBLNKandinskyV22CombinedPipeline",
|
82
|
+
"RBLNKandinskyV22Img2ImgCombinedPipeline",
|
81
83
|
"RBLNKandinskyV22InpaintCombinedPipeline",
|
82
84
|
"RBLNKandinskyV22InpaintPipeline",
|
85
|
+
"RBLNKandinskyV22Img2ImgPipeline",
|
83
86
|
"RBLNKandinskyV22PriorPipeline",
|
87
|
+
"RBLNKandinskyV22Pipeline",
|
84
88
|
"RBLNStableDiffusionPipeline",
|
85
89
|
"RBLNStableDiffusionXLPipeline",
|
86
90
|
"RBLNUNet2DConditionModel",
|
@@ -107,8 +111,12 @@ if TYPE_CHECKING:
|
|
107
111
|
RBLNAutoencoderKL,
|
108
112
|
RBLNControlNetModel,
|
109
113
|
RBLNDiffusionMixin,
|
114
|
+
RBLNKandinskyV22CombinedPipeline,
|
115
|
+
RBLNKandinskyV22Img2ImgCombinedPipeline,
|
116
|
+
RBLNKandinskyV22Img2ImgPipeline,
|
110
117
|
RBLNKandinskyV22InpaintCombinedPipeline,
|
111
118
|
RBLNKandinskyV22InpaintPipeline,
|
119
|
+
RBLNKandinskyV22Pipeline,
|
112
120
|
RBLNKandinskyV22PriorPipeline,
|
113
121
|
RBLNMultiControlNetModel,
|
114
122
|
RBLNPriorTransformer,
|
optimum/rbln/__version__.py
CHANGED
@@ -1,8 +1,13 @@
|
|
1
|
-
# file generated by
|
1
|
+
# file generated by setuptools-scm
|
2
2
|
# don't change, don't track in version control
|
3
|
+
|
4
|
+
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
|
5
|
+
|
3
6
|
TYPE_CHECKING = False
|
4
7
|
if TYPE_CHECKING:
|
5
|
-
from typing import Tuple
|
8
|
+
from typing import Tuple
|
9
|
+
from typing import Union
|
10
|
+
|
6
11
|
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
7
12
|
else:
|
8
13
|
VERSION_TUPLE = object
|
@@ -12,5 +17,5 @@ __version__: str
|
|
12
17
|
__version_tuple__: VERSION_TUPLE
|
13
18
|
version_tuple: VERSION_TUPLE
|
14
19
|
|
15
|
-
__version__ = version = '0.7.
|
16
|
-
__version_tuple__ = version_tuple = (0, 7,
|
20
|
+
__version__ = version = '0.7.3'
|
21
|
+
__version_tuple__ = version_tuple = (0, 7, 3)
|
@@ -24,9 +24,13 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
|
|
24
24
|
|
25
25
|
_import_structure = {
|
26
26
|
"pipelines": [
|
27
|
+
"RBLNKandinskyV22CombinedPipeline",
|
28
|
+
"RBLNKandinskyV22Img2ImgCombinedPipeline",
|
27
29
|
"RBLNKandinskyV22InpaintCombinedPipeline",
|
28
30
|
"RBLNKandinskyV22InpaintPipeline",
|
31
|
+
"RBLNKandinskyV22Img2ImgPipeline",
|
29
32
|
"RBLNKandinskyV22PriorPipeline",
|
33
|
+
"RBLNKandinskyV22Pipeline",
|
30
34
|
"RBLNStableDiffusionPipeline",
|
31
35
|
"RBLNStableDiffusionXLPipeline",
|
32
36
|
"RBLNStableDiffusionImg2ImgPipeline",
|
@@ -66,8 +70,12 @@ if TYPE_CHECKING:
|
|
66
70
|
RBLNVQModel,
|
67
71
|
)
|
68
72
|
from .pipelines import (
|
73
|
+
RBLNKandinskyV22CombinedPipeline,
|
74
|
+
RBLNKandinskyV22Img2ImgCombinedPipeline,
|
75
|
+
RBLNKandinskyV22Img2ImgPipeline,
|
69
76
|
RBLNKandinskyV22InpaintCombinedPipeline,
|
70
77
|
RBLNKandinskyV22InpaintPipeline,
|
78
|
+
RBLNKandinskyV22Pipeline,
|
71
79
|
RBLNKandinskyV22PriorPipeline,
|
72
80
|
RBLNMultiControlNetModel,
|
73
81
|
RBLNStableDiffusion3Img2ImgPipeline,
|
@@ -23,7 +23,6 @@ from ..modeling import RBLNModel
|
|
23
23
|
from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
|
24
24
|
from ..utils.decorator_utils import remove_compile_time_kwargs
|
25
25
|
from ..utils.logging import get_logger
|
26
|
-
from . import pipelines
|
27
26
|
|
28
27
|
|
29
28
|
logger = get_logger(__name__)
|
@@ -67,17 +66,16 @@ class RBLNDiffusionMixin:
|
|
67
66
|
as keys in rbln_config
|
68
67
|
"""
|
69
68
|
|
69
|
+
_connected_classes = {}
|
70
70
|
_submodules = []
|
71
71
|
_prefix = {}
|
72
72
|
|
73
73
|
@classmethod
|
74
|
-
|
75
|
-
def img2img_pipeline(cls):
|
74
|
+
def is_img2img_pipeline(cls):
|
76
75
|
return "Img2Img" in cls.__name__
|
77
76
|
|
78
77
|
@classmethod
|
79
|
-
|
80
|
-
def inpaint_pipeline(cls):
|
78
|
+
def is_inpaint_pipeline(cls):
|
81
79
|
return "Inpaint" in cls.__name__
|
82
80
|
|
83
81
|
@classmethod
|
@@ -100,34 +98,11 @@ class RBLNDiffusionMixin:
|
|
100
98
|
submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
|
101
99
|
submodule_config.update(
|
102
100
|
{
|
103
|
-
"img2img_pipeline": cls.
|
104
|
-
"inpaint_pipeline": cls.
|
101
|
+
"img2img_pipeline": cls.is_img2img_pipeline(),
|
102
|
+
"inpaint_pipeline": cls.is_inpaint_pipeline(),
|
105
103
|
}
|
106
104
|
)
|
107
105
|
submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
|
108
|
-
elif hasattr(pipelines, submodule_class_name):
|
109
|
-
submodule_config = rbln_config.get(submodule_name, {})
|
110
|
-
submodule_config = copy.deepcopy(submodule_config)
|
111
|
-
|
112
|
-
submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"{submodule_class_name}")
|
113
|
-
prefix = cls._prefix.get(submodule_name, "")
|
114
|
-
connected_submodules = cls._connected_classes.get(submodule_name)._submodules
|
115
|
-
for connected_submodule_name in connected_submodules:
|
116
|
-
connected_submodule_config = rbln_config.pop(prefix + connected_submodule_name, {})
|
117
|
-
if connected_submodule_name in submodule_config:
|
118
|
-
submodule_config[connected_submodule_name].update(connected_submodule_config)
|
119
|
-
else:
|
120
|
-
submodule_config[connected_submodule_name] = connected_submodule_config
|
121
|
-
|
122
|
-
submodules = copy.deepcopy(cls._submodules)
|
123
|
-
submodules += [prefix + connected_submodule_name for connected_submodule_name in connected_submodules]
|
124
|
-
|
125
|
-
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
|
126
|
-
for connected_submodule_name in connected_submodules:
|
127
|
-
submodule_config[connected_submodule_name].update(
|
128
|
-
{k: v for k, v in pipe_global_config.items() if k not in submodule_config}
|
129
|
-
)
|
130
|
-
rbln_config[submodule_name] = submodule_config
|
131
106
|
else:
|
132
107
|
raise ValueError(f"submodule {submodule_name} isn't supported")
|
133
108
|
return submodule_config
|
@@ -193,25 +168,8 @@ class RBLNDiffusionMixin:
|
|
193
168
|
else:
|
194
169
|
# raise error if any of submodules are torch module.
|
195
170
|
model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
|
196
|
-
|
197
|
-
|
198
|
-
for submodule in cls._submodules:
|
199
|
-
submodule_config = rbln_config.pop(submodule, {})
|
200
|
-
prefix = cls._prefix.get(submodule, "")
|
201
|
-
connected_submodules = cls._connected_classes.get(submodule)._submodules
|
202
|
-
for connected_submodule_name in connected_submodules:
|
203
|
-
connected_submodule_config = submodule_config.pop(connected_submodule_name, {})
|
204
|
-
if connected_submodule_config:
|
205
|
-
rbln_config[prefix + connected_submodule_name] = connected_submodule_config
|
206
|
-
submodules.append(prefix + connected_submodule_name)
|
207
|
-
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
|
208
|
-
for submodule in submodules:
|
209
|
-
if submodule in rbln_config:
|
210
|
-
rbln_config[submodule].update(pipe_global_config)
|
211
|
-
else:
|
212
|
-
submodules = cls._submodules
|
213
|
-
|
214
|
-
for submodule_name in submodules:
|
171
|
+
rbln_config = cls._flatten_rbln_config(rbln_config)
|
172
|
+
for submodule_name in cls._submodules:
|
215
173
|
if isinstance(kwargs.get(submodule_name), torch.nn.Module):
|
216
174
|
raise AssertionError(
|
217
175
|
f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
|
@@ -260,9 +218,89 @@ class RBLNDiffusionMixin:
|
|
260
218
|
lora_scales=lora_scales,
|
261
219
|
)
|
262
220
|
|
263
|
-
|
221
|
+
if cls._load_connected_pipes:
|
222
|
+
compiled_submodules = cls._compile_pipelines(model, passed_submodules, model_save_dir, rbln_config)
|
223
|
+
else:
|
224
|
+
compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
|
264
225
|
return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)
|
265
226
|
|
227
|
+
@classmethod
|
228
|
+
def _prepare_rbln_config(
|
229
|
+
cls,
|
230
|
+
rbln_config,
|
231
|
+
) -> Dict[str, Any]:
|
232
|
+
prepared_config = {}
|
233
|
+
for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
|
234
|
+
connected_pipe_config = rbln_config.pop(connected_pipe_name, {})
|
235
|
+
prefix = cls._prefix.get(connected_pipe_name, "")
|
236
|
+
guidance_scale = rbln_config.pop(f"{prefix}guidance_scale", None)
|
237
|
+
if "guidance_scale" not in connected_pipe_config and guidance_scale is not None:
|
238
|
+
connected_pipe_config["guidance_scale"] = guidance_scale
|
239
|
+
for submodule_name in connected_pipe_cls._submodules:
|
240
|
+
submodule_config = rbln_config.pop(prefix + submodule_name, {})
|
241
|
+
if submodule_name not in connected_pipe_config:
|
242
|
+
connected_pipe_config[submodule_name] = {}
|
243
|
+
connected_pipe_config[submodule_name].update(
|
244
|
+
{k: v for k, v in submodule_config.items() if k not in connected_pipe_config[submodule_name]}
|
245
|
+
)
|
246
|
+
prepared_config[connected_pipe_name] = connected_pipe_config
|
247
|
+
prepared_config.update(rbln_config)
|
248
|
+
return prepared_config
|
249
|
+
|
250
|
+
@classmethod
|
251
|
+
def _flatten_rbln_config(
|
252
|
+
cls,
|
253
|
+
rbln_config,
|
254
|
+
) -> Dict[str, Any]:
|
255
|
+
prepared_config = cls._prepare_rbln_config(rbln_config)
|
256
|
+
flattened_config = {}
|
257
|
+
pipe_global_config = {k: v for k, v in prepared_config.items() if k not in cls._connected_classes.keys()}
|
258
|
+
for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
|
259
|
+
connected_pipe_config = prepared_config.pop(connected_pipe_name)
|
260
|
+
prefix = cls._prefix.get(connected_pipe_name, "")
|
261
|
+
connected_pipe_global_config = {
|
262
|
+
k: v for k, v in connected_pipe_config.items() if k not in connected_pipe_cls._submodules
|
263
|
+
}
|
264
|
+
for submodule_name in connected_pipe_cls._submodules:
|
265
|
+
flattened_config[prefix + submodule_name] = connected_pipe_config[submodule_name]
|
266
|
+
flattened_config[prefix + submodule_name].update(
|
267
|
+
{
|
268
|
+
k: v
|
269
|
+
for k, v in connected_pipe_global_config.items()
|
270
|
+
if k not in flattened_config[prefix + submodule_name]
|
271
|
+
}
|
272
|
+
)
|
273
|
+
flattened_config.update(pipe_global_config)
|
274
|
+
return flattened_config
|
275
|
+
|
276
|
+
@classmethod
|
277
|
+
def _compile_pipelines(
|
278
|
+
cls,
|
279
|
+
model: torch.nn.Module,
|
280
|
+
passed_submodules: Dict[str, RBLNModel],
|
281
|
+
model_save_dir: Optional[PathLike],
|
282
|
+
rbln_config: Dict[str, Any],
|
283
|
+
) -> Dict[str, RBLNModel]:
|
284
|
+
compiled_submodules = {}
|
285
|
+
|
286
|
+
rbln_config = cls._prepare_rbln_config(rbln_config)
|
287
|
+
pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._connected_classes.keys()}
|
288
|
+
for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
|
289
|
+
connected_pipe_submodules = {}
|
290
|
+
prefix = cls._prefix.get(connected_pipe_name, "")
|
291
|
+
for submodule_name in connected_pipe_cls._submodules:
|
292
|
+
connected_pipe_submodules[submodule_name] = passed_submodules.get(prefix + submodule_name, None)
|
293
|
+
connected_pipe = getattr(model, connected_pipe_name)
|
294
|
+
connected_pipe_config = {}
|
295
|
+
connected_pipe_config.update(pipe_global_config)
|
296
|
+
connected_pipe_config.update(rbln_config[connected_pipe_name])
|
297
|
+
connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
|
298
|
+
connected_pipe, connected_pipe_submodules, model_save_dir, connected_pipe_config, prefix
|
299
|
+
)
|
300
|
+
for submodule_name, compiled_submodule in connected_pipe_compiled_submodules.items():
|
301
|
+
compiled_submodules[prefix + submodule_name] = compiled_submodule
|
302
|
+
return compiled_submodules
|
303
|
+
|
266
304
|
@classmethod
|
267
305
|
def _compile_submodules(
|
268
306
|
cls,
|
@@ -301,41 +339,6 @@ class RBLNDiffusionMixin:
|
|
301
339
|
model_save_dir=model_save_dir,
|
302
340
|
rbln_config=submodule_rbln_config,
|
303
341
|
)
|
304
|
-
elif hasattr(pipelines, submodule.__class__.__name__):
|
305
|
-
connected_pipe = submodule
|
306
|
-
connected_pipe_model_save_dir = model_save_dir
|
307
|
-
connected_pipe_rbln_config = submodule_rbln_config
|
308
|
-
connected_pipe_cls: RBLNDiffusionMixin = getattr(
|
309
|
-
importlib.import_module("optimum.rbln"), connected_pipe.__class__.__name__
|
310
|
-
)
|
311
|
-
submodule_dict = {}
|
312
|
-
for name in connected_pipe.config.keys():
|
313
|
-
if hasattr(connected_pipe, name):
|
314
|
-
submodule_dict[name] = getattr(connected_pipe, name)
|
315
|
-
connected_pipe = connected_pipe_cls(**submodule_dict)
|
316
|
-
connected_pipe_submodules = {}
|
317
|
-
prefix = cls._prefix.get(submodule_name, "")
|
318
|
-
for name in connected_pipe_cls._submodules:
|
319
|
-
if prefix + name in passed_submodules:
|
320
|
-
connected_pipe_submodules[name] = passed_submodules.get(prefix + name)
|
321
|
-
|
322
|
-
connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
|
323
|
-
model=connected_pipe,
|
324
|
-
passed_submodules=connected_pipe_submodules,
|
325
|
-
model_save_dir=model_save_dir,
|
326
|
-
rbln_config=connected_pipe_rbln_config,
|
327
|
-
prefix=prefix,
|
328
|
-
)
|
329
|
-
connected_pipe = connected_pipe_cls._construct_pipe(
|
330
|
-
connected_pipe,
|
331
|
-
connected_pipe_compiled_submodules,
|
332
|
-
connected_pipe_model_save_dir,
|
333
|
-
connected_pipe_rbln_config,
|
334
|
-
)
|
335
|
-
|
336
|
-
for name in connected_pipe_cls._submodules:
|
337
|
-
compiled_submodules[prefix + name] = getattr(connected_pipe, name)
|
338
|
-
submodule = connected_pipe
|
339
342
|
else:
|
340
343
|
raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
|
341
344
|
|
@@ -368,23 +371,16 @@ class RBLNDiffusionMixin:
|
|
368
371
|
@classmethod
|
369
372
|
def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
|
370
373
|
# Construct finalize pipe setup with compiled submodules and configurations
|
371
|
-
submodule_names = []
|
372
|
-
for submodule_name in cls._submodules:
|
373
|
-
submodule = getattr(model, submodule_name)
|
374
|
-
if hasattr(pipelines, submodule.__class__.__name__):
|
375
|
-
prefix = cls._prefix.get(submodule_name, "")
|
376
|
-
connected_pipe_submodules = submodules[submodule_name].__class__._submodules
|
377
|
-
connected_pipe_submodules = [prefix + name for name in connected_pipe_submodules]
|
378
|
-
submodule_names += connected_pipe_submodules
|
379
|
-
setattr(model, submodule_name, submodules[submodule_name])
|
380
|
-
else:
|
381
|
-
submodule_names.append(submodule_name)
|
382
|
-
|
383
374
|
if model_save_dir is not None:
|
384
375
|
# To skip saving original pytorch modules
|
385
|
-
for submodule_name in
|
376
|
+
for submodule_name in cls._submodules:
|
386
377
|
delattr(model, submodule_name)
|
387
378
|
|
379
|
+
if cls._load_connected_pipes:
|
380
|
+
for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
|
381
|
+
for submodule_name in connected_pipe_cls._submodules:
|
382
|
+
delattr(getattr(model, connected_pipe_name), submodule_name)
|
383
|
+
|
388
384
|
# Direct calling of `save_pretrained` causes config.unet = (None, None).
|
389
385
|
# So config must be saved again, later.
|
390
386
|
model.save_pretrained(model_save_dir)
|
@@ -392,10 +388,15 @@ class RBLNDiffusionMixin:
|
|
392
388
|
# Causing warning messeages.
|
393
389
|
|
394
390
|
update_dict = {}
|
395
|
-
for submodule_name in
|
391
|
+
for submodule_name in cls._submodules:
|
396
392
|
# replace submodule
|
397
393
|
setattr(model, submodule_name, submodules[submodule_name])
|
398
394
|
update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
|
395
|
+
if cls._load_connected_pipes:
|
396
|
+
for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
|
397
|
+
prefix = cls._prefix.get(connected_pipe_name, "")
|
398
|
+
for submodule_name in connected_pipe_cls._submodules:
|
399
|
+
setattr(getattr(model, connected_pipe_name), submodule_name, submodules[prefix + submodule_name])
|
399
400
|
|
400
401
|
# Update config to be able to load from model directory.
|
401
402
|
#
|
@@ -414,16 +415,9 @@ class RBLNDiffusionMixin:
|
|
414
415
|
if rbln_config.get("optimize_host_memory") is False:
|
415
416
|
# Keep compiled_model objs to further analysis. -> TODO: remove soon...
|
416
417
|
model.compiled_models = []
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
for submodule_name in connected_pipe.__class__._submodules:
|
421
|
-
submodule = getattr(connected_pipe, submodule_name)
|
422
|
-
model.compiled_models.extend(submodule.compiled_models)
|
423
|
-
else:
|
424
|
-
for name in cls._submodules:
|
425
|
-
submodule = getattr(model, name)
|
426
|
-
model.compiled_models.extend(submodule.compiled_models)
|
418
|
+
for name in cls._submodules:
|
419
|
+
submodule = getattr(model, name)
|
420
|
+
model.compiled_models.extend(submodule.compiled_models)
|
427
421
|
|
428
422
|
return model
|
429
423
|
|
@@ -90,9 +90,17 @@ class RBLNVQModel(RBLNModel):
|
|
90
90
|
model_config: "PretrainedConfig",
|
91
91
|
rbln_kwargs: Dict[str, Any] = {},
|
92
92
|
) -> RBLNConfig:
|
93
|
-
batch_size = rbln_kwargs.get("batch_size")
|
94
|
-
|
95
|
-
|
93
|
+
batch_size = rbln_kwargs.get("batch_size")
|
94
|
+
if batch_size is None:
|
95
|
+
batch_size = 1
|
96
|
+
|
97
|
+
height = rbln_kwargs.get("img_height")
|
98
|
+
if height is None:
|
99
|
+
height = 512
|
100
|
+
|
101
|
+
width = rbln_kwargs.get("img_width")
|
102
|
+
if width is None:
|
103
|
+
width = 512
|
96
104
|
|
97
105
|
if hasattr(model_config, "block_out_channels"):
|
98
106
|
scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
|
@@ -176,15 +176,22 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
176
176
|
raise ValueError("Both image height and image width must be given or not given")
|
177
177
|
elif image_size[0] is None and image_size[1] is None:
|
178
178
|
if rbln_config["img2img_pipeline"]:
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
179
|
+
if hasattr(pipe, "vae"):
|
180
|
+
# In case of img2img, sample size of unet is determined by vae encoder.
|
181
|
+
vae_sample_size = pipe.vae.config.sample_size
|
182
|
+
if isinstance(vae_sample_size, int):
|
183
|
+
sample_size = vae_sample_size // scale_factor
|
184
|
+
else:
|
185
|
+
sample_size = (
|
186
|
+
vae_sample_size[0] // scale_factor,
|
187
|
+
vae_sample_size[1] // scale_factor,
|
188
|
+
)
|
189
|
+
elif hasattr(pipe, "movq"):
|
190
|
+
logger.warning(
|
191
|
+
"RBLN config 'img_height' and 'img_width' should have been provided for this pipeline. "
|
192
|
+
"Both variable will be set 512 by default."
|
187
193
|
)
|
194
|
+
sample_size = (512 // scale_factor, 512 // scale_factor)
|
188
195
|
else:
|
189
196
|
sample_size = pipe.unet.config.sample_size
|
190
197
|
else:
|
@@ -26,9 +26,13 @@ _import_structure = {
|
|
26
26
|
"RBLNStableDiffusionXLControlNetPipeline",
|
27
27
|
],
|
28
28
|
"kandinsky2_2": [
|
29
|
+
"RBLNKandinskyV22CombinedPipeline",
|
30
|
+
"RBLNKandinskyV22Img2ImgCombinedPipeline",
|
29
31
|
"RBLNKandinskyV22InpaintCombinedPipeline",
|
30
32
|
"RBLNKandinskyV22InpaintPipeline",
|
33
|
+
"RBLNKandinskyV22Img2ImgPipeline",
|
31
34
|
"RBLNKandinskyV22PriorPipeline",
|
35
|
+
"RBLNKandinskyV22Pipeline",
|
32
36
|
],
|
33
37
|
"stable_diffusion": [
|
34
38
|
"RBLNStableDiffusionImg2ImgPipeline",
|
@@ -55,8 +59,12 @@ if TYPE_CHECKING:
|
|
55
59
|
RBLNStableDiffusionXLControlNetPipeline,
|
56
60
|
)
|
57
61
|
from .kandinsky2_2 import (
|
62
|
+
RBLNKandinskyV22CombinedPipeline,
|
63
|
+
RBLNKandinskyV22Img2ImgCombinedPipeline,
|
64
|
+
RBLNKandinskyV22Img2ImgPipeline,
|
58
65
|
RBLNKandinskyV22InpaintCombinedPipeline,
|
59
66
|
RBLNKandinskyV22InpaintPipeline,
|
67
|
+
RBLNKandinskyV22Pipeline,
|
60
68
|
RBLNKandinskyV22PriorPipeline,
|
61
69
|
)
|
62
70
|
from .stable_diffusion import (
|
@@ -12,6 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from .
|
15
|
+
from .pipeline_kandinsky2_2 import RBLNKandinskyV22Pipeline
|
16
|
+
from .pipeline_kandinsky2_2_combined import (
|
17
|
+
RBLNKandinskyV22CombinedPipeline,
|
18
|
+
RBLNKandinskyV22Img2ImgCombinedPipeline,
|
19
|
+
RBLNKandinskyV22InpaintCombinedPipeline,
|
20
|
+
)
|
21
|
+
from .pipeline_kandinsky2_2_img2img import RBLNKandinskyV22Img2ImgPipeline
|
16
22
|
from .pipeline_kandinsky2_2_inpaint import RBLNKandinskyV22InpaintPipeline
|
17
23
|
from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from diffusers import KandinskyV22Pipeline
|
16
|
+
|
17
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
18
|
+
|
19
|
+
|
20
|
+
class RBLNKandinskyV22Pipeline(RBLNDiffusionMixin, KandinskyV22Pipeline):
|
21
|
+
original_class = KandinskyV22Pipeline
|
22
|
+
_submodules = ["unet", "movq"]
|
23
|
+
|
24
|
+
def get_compiled_image_size(self):
|
25
|
+
return self.movq.image_size
|
@@ -14,6 +14,8 @@
|
|
14
14
|
|
15
15
|
from diffusers import (
|
16
16
|
DDPMScheduler,
|
17
|
+
KandinskyV22CombinedPipeline,
|
18
|
+
KandinskyV22Img2ImgCombinedPipeline,
|
17
19
|
KandinskyV22InpaintCombinedPipeline,
|
18
20
|
PriorTransformer,
|
19
21
|
UnCLIPScheduler,
|
@@ -28,14 +30,118 @@ from transformers import (
|
|
28
30
|
)
|
29
31
|
|
30
32
|
from ...modeling_diffusers import RBLNDiffusionMixin
|
33
|
+
from .pipeline_kandinsky2_2 import RBLNKandinskyV22Pipeline
|
34
|
+
from .pipeline_kandinsky2_2_img2img import RBLNKandinskyV22Img2ImgPipeline
|
31
35
|
from .pipeline_kandinsky2_2_inpaint import RBLNKandinskyV22InpaintPipeline
|
32
36
|
from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline
|
33
37
|
|
34
38
|
|
39
|
+
class RBLNKandinskyV22CombinedPipeline(RBLNDiffusionMixin, KandinskyV22CombinedPipeline):
|
40
|
+
original_class = KandinskyV22CombinedPipeline
|
41
|
+
_connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22Pipeline}
|
42
|
+
_submodules = ["prior_image_encoder", "prior_text_encoder", "prior_prior", "unet", "movq"]
|
43
|
+
_prefix = {"prior_pipe": "prior_"}
|
44
|
+
|
45
|
+
def __init__(
|
46
|
+
self,
|
47
|
+
unet: "UNet2DConditionModel",
|
48
|
+
scheduler: "DDPMScheduler",
|
49
|
+
movq: "VQModel",
|
50
|
+
prior_prior: "PriorTransformer",
|
51
|
+
prior_image_encoder: "CLIPVisionModelWithProjection",
|
52
|
+
prior_text_encoder: "CLIPTextModelWithProjection",
|
53
|
+
prior_tokenizer: "CLIPTokenizer",
|
54
|
+
prior_scheduler: "UnCLIPScheduler",
|
55
|
+
prior_image_processor: "CLIPImageProcessor",
|
56
|
+
):
|
57
|
+
RBLNDiffusionMixin.__init__(self)
|
58
|
+
super(KandinskyV22CombinedPipeline, self).__init__()
|
59
|
+
|
60
|
+
self.register_modules(
|
61
|
+
unet=unet,
|
62
|
+
scheduler=scheduler,
|
63
|
+
movq=movq,
|
64
|
+
prior_prior=prior_prior,
|
65
|
+
prior_image_encoder=prior_image_encoder,
|
66
|
+
prior_text_encoder=prior_text_encoder,
|
67
|
+
prior_tokenizer=prior_tokenizer,
|
68
|
+
prior_scheduler=prior_scheduler,
|
69
|
+
prior_image_processor=prior_image_processor,
|
70
|
+
)
|
71
|
+
|
72
|
+
self.prior_pipe = RBLNKandinskyV22PriorPipeline(
|
73
|
+
prior=prior_prior,
|
74
|
+
image_encoder=prior_image_encoder,
|
75
|
+
text_encoder=prior_text_encoder,
|
76
|
+
tokenizer=prior_tokenizer,
|
77
|
+
scheduler=prior_scheduler,
|
78
|
+
image_processor=prior_image_processor,
|
79
|
+
)
|
80
|
+
self.decoder_pipe = RBLNKandinskyV22Pipeline(
|
81
|
+
unet=unet,
|
82
|
+
scheduler=scheduler,
|
83
|
+
movq=movq,
|
84
|
+
)
|
85
|
+
|
86
|
+
def get_compiled_image_size(self):
|
87
|
+
return self.movq.image_size
|
88
|
+
|
89
|
+
|
90
|
+
class RBLNKandinskyV22Img2ImgCombinedPipeline(RBLNDiffusionMixin, KandinskyV22Img2ImgCombinedPipeline):
|
91
|
+
original_class = KandinskyV22Img2ImgCombinedPipeline
|
92
|
+
_connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22Img2ImgPipeline}
|
93
|
+
_submodules = ["prior_image_encoder", "prior_text_encoder", "prior_prior", "unet", "movq"]
|
94
|
+
_prefix = {"prior_pipe": "prior_"}
|
95
|
+
|
96
|
+
def __init__(
|
97
|
+
self,
|
98
|
+
unet: "UNet2DConditionModel",
|
99
|
+
scheduler: "DDPMScheduler",
|
100
|
+
movq: "VQModel",
|
101
|
+
prior_prior: "PriorTransformer",
|
102
|
+
prior_image_encoder: "CLIPVisionModelWithProjection",
|
103
|
+
prior_text_encoder: "CLIPTextModelWithProjection",
|
104
|
+
prior_tokenizer: "CLIPTokenizer",
|
105
|
+
prior_scheduler: "UnCLIPScheduler",
|
106
|
+
prior_image_processor: "CLIPImageProcessor",
|
107
|
+
):
|
108
|
+
RBLNDiffusionMixin.__init__(self)
|
109
|
+
super(KandinskyV22Img2ImgCombinedPipeline, self).__init__()
|
110
|
+
|
111
|
+
self.register_modules(
|
112
|
+
unet=unet,
|
113
|
+
scheduler=scheduler,
|
114
|
+
movq=movq,
|
115
|
+
prior_prior=prior_prior,
|
116
|
+
prior_image_encoder=prior_image_encoder,
|
117
|
+
prior_text_encoder=prior_text_encoder,
|
118
|
+
prior_tokenizer=prior_tokenizer,
|
119
|
+
prior_scheduler=prior_scheduler,
|
120
|
+
prior_image_processor=prior_image_processor,
|
121
|
+
)
|
122
|
+
|
123
|
+
self.prior_pipe = RBLNKandinskyV22PriorPipeline(
|
124
|
+
prior=prior_prior,
|
125
|
+
image_encoder=prior_image_encoder,
|
126
|
+
text_encoder=prior_text_encoder,
|
127
|
+
tokenizer=prior_tokenizer,
|
128
|
+
scheduler=prior_scheduler,
|
129
|
+
image_processor=prior_image_processor,
|
130
|
+
)
|
131
|
+
self.decoder_pipe = RBLNKandinskyV22Img2ImgPipeline(
|
132
|
+
unet=unet,
|
133
|
+
scheduler=scheduler,
|
134
|
+
movq=movq,
|
135
|
+
)
|
136
|
+
|
137
|
+
def get_compiled_image_size(self):
|
138
|
+
return self.movq.image_size
|
139
|
+
|
140
|
+
|
35
141
|
class RBLNKandinskyV22InpaintCombinedPipeline(RBLNDiffusionMixin, KandinskyV22InpaintCombinedPipeline):
|
36
142
|
original_class = KandinskyV22InpaintCombinedPipeline
|
37
143
|
_connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22InpaintPipeline}
|
38
|
-
_submodules = ["
|
144
|
+
_submodules = ["prior_image_encoder", "prior_text_encoder", "prior_prior", "unet", "movq"]
|
39
145
|
_prefix = {"prior_pipe": "prior_"}
|
40
146
|
|
41
147
|
def __init__(
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from diffusers import KandinskyV22Img2ImgPipeline
|
16
|
+
|
17
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
18
|
+
|
19
|
+
|
20
|
+
class RBLNKandinskyV22Img2ImgPipeline(RBLNDiffusionMixin, KandinskyV22Img2ImgPipeline):
|
21
|
+
original_class = KandinskyV22Img2ImgPipeline
|
22
|
+
_submodules = ["unet", "movq"]
|
23
|
+
|
24
|
+
def get_compiled_image_size(self):
|
25
|
+
return self.movq.image_size
|
@@ -20,3 +20,6 @@ from ...modeling_diffusers import RBLNDiffusionMixin
|
|
20
20
|
class RBLNKandinskyV22InpaintPipeline(RBLNDiffusionMixin, KandinskyV22InpaintPipeline):
|
21
21
|
original_class = KandinskyV22InpaintPipeline
|
22
22
|
_submodules = ["unet", "movq"]
|
23
|
+
|
24
|
+
def get_compiled_image_size(self):
|
25
|
+
return self.movq.image_size
|