optimum-rbln 0.2.1a4__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. optimum/rbln/__init__.py +14 -2
  2. optimum/rbln/__version__.py +9 -4
  3. optimum/rbln/diffusers/__init__.py +10 -0
  4. optimum/rbln/diffusers/modeling_diffusers.py +132 -25
  5. optimum/rbln/diffusers/models/__init__.py +7 -1
  6. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +52 -2
  8. optimum/rbln/diffusers/models/autoencoders/vq_model.py +159 -0
  9. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  10. optimum/rbln/diffusers/models/transformers/prior_transformer.py +174 -0
  11. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +57 -14
  12. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  13. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +17 -0
  14. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +83 -0
  15. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +22 -0
  16. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +22 -0
  17. optimum/rbln/modeling_base.py +10 -9
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/models/__init__.py +12 -2
  20. optimum/rbln/transformers/models/clip/__init__.py +6 -1
  21. optimum/rbln/transformers/models/clip/modeling_clip.py +26 -1
  22. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
  23. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
  24. optimum/rbln/utils/import_utils.py +7 -0
  25. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/METADATA +1 -1
  26. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/RECORD +28 -22
  27. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/WHEEL +0 -0
  28. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -48,6 +48,7 @@ _import_structure = {
48
48
  "RBLNCLIPTextModel",
49
49
  "RBLNCLIPTextModelWithProjection",
50
50
  "RBLNCLIPVisionModel",
51
+ "RBLNCLIPVisionModelWithProjection",
51
52
  "RBLNDPTForDepthEstimation",
52
53
  "RBLNExaoneForCausalLM",
53
54
  "RBLNGemmaForCausalLM",
@@ -74,11 +75,15 @@ _import_structure = {
74
75
  "RBLNBertForMaskedLM",
75
76
  ],
76
77
  "diffusers": [
78
+ "RBLNAutoencoderKL",
79
+ "RBLNControlNetModel",
80
+ "RBLNPriorTransformer",
81
+ "RBLNKandinskyV22InpaintCombinedPipeline",
82
+ "RBLNKandinskyV22InpaintPipeline",
83
+ "RBLNKandinskyV22PriorPipeline",
77
84
  "RBLNStableDiffusionPipeline",
78
85
  "RBLNStableDiffusionXLPipeline",
79
- "RBLNAutoencoderKL",
80
86
  "RBLNUNet2DConditionModel",
81
- "RBLNControlNetModel",
82
87
  "RBLNStableDiffusionImg2ImgPipeline",
83
88
  "RBLNStableDiffusionInpaintPipeline",
84
89
  "RBLNStableDiffusionControlNetImg2ImgPipeline",
@@ -88,6 +93,7 @@ _import_structure = {
88
93
  "RBLNStableDiffusionControlNetPipeline",
89
94
  "RBLNStableDiffusionXLControlNetPipeline",
90
95
  "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
96
+ "RBLNVQModel",
91
97
  "RBLNSD3Transformer2DModel",
92
98
  "RBLNStableDiffusion3Img2ImgPipeline",
93
99
  "RBLNStableDiffusion3InpaintPipeline",
@@ -101,7 +107,11 @@ if TYPE_CHECKING:
101
107
  RBLNAutoencoderKL,
102
108
  RBLNControlNetModel,
103
109
  RBLNDiffusionMixin,
110
+ RBLNKandinskyV22InpaintCombinedPipeline,
111
+ RBLNKandinskyV22InpaintPipeline,
112
+ RBLNKandinskyV22PriorPipeline,
104
113
  RBLNMultiControlNetModel,
114
+ RBLNPriorTransformer,
105
115
  RBLNSD3Transformer2DModel,
106
116
  RBLNStableDiffusion3Img2ImgPipeline,
107
117
  RBLNStableDiffusion3InpaintPipeline,
@@ -117,6 +127,7 @@ if TYPE_CHECKING:
117
127
  RBLNStableDiffusionXLInpaintPipeline,
118
128
  RBLNStableDiffusionXLPipeline,
119
129
  RBLNUNet2DConditionModel,
130
+ RBLNVQModel,
120
131
  )
121
132
  from .modeling import (
122
133
  RBLNBaseModel,
@@ -148,6 +159,7 @@ if TYPE_CHECKING:
148
159
  RBLNCLIPTextModel,
149
160
  RBLNCLIPTextModelWithProjection,
150
161
  RBLNCLIPVisionModel,
162
+ RBLNCLIPVisionModelWithProjection,
151
163
  RBLNDistilBertForQuestionAnswering,
152
164
  RBLNDPTForDepthEstimation,
153
165
  RBLNExaoneForCausalLM,
optimum/rbln/__version__.py CHANGED
@@ -1,8 +1,13 @@
1
- # file generated by setuptools_scm
1
+ # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
+
4
+ __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
5
+
3
6
  TYPE_CHECKING = False
4
7
  if TYPE_CHECKING:
5
- from typing import Tuple, Union
8
+ from typing import Tuple
9
+ from typing import Union
10
+
6
11
  VERSION_TUPLE = Tuple[Union[int, str], ...]
7
12
  else:
8
13
  VERSION_TUPLE = object
@@ -12,5 +17,5 @@ __version__: str
12
17
  __version_tuple__: VERSION_TUPLE
13
18
  version_tuple: VERSION_TUPLE
14
19
 
15
- __version__ = version = '0.2.1a4'
16
- __version_tuple__ = version_tuple = (0, 2, 1)
20
+ __version__ = version = '0.7.2'
21
+ __version_tuple__ = version_tuple = (0, 7, 2)
optimum/rbln/diffusers/__init__.py CHANGED
@@ -24,6 +24,9 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
24
24
 
25
25
  _import_structure = {
26
26
  "pipelines": [
27
+ "RBLNKandinskyV22InpaintCombinedPipeline",
28
+ "RBLNKandinskyV22InpaintPipeline",
29
+ "RBLNKandinskyV22PriorPipeline",
27
30
  "RBLNStableDiffusionPipeline",
28
31
  "RBLNStableDiffusionXLPipeline",
29
32
  "RBLNStableDiffusionImg2ImgPipeline",
@@ -44,6 +47,8 @@ _import_structure = {
44
47
  "RBLNUNet2DConditionModel",
45
48
  "RBLNControlNetModel",
46
49
  "RBLNSD3Transformer2DModel",
50
+ "RBLNPriorTransformer",
51
+ "RBLNVQModel",
47
52
  ],
48
53
  "modeling_diffusers": [
49
54
  "RBLNDiffusionMixin",
@@ -55,10 +60,15 @@ if TYPE_CHECKING:
55
60
  from .models import (
56
61
  RBLNAutoencoderKL,
57
62
  RBLNControlNetModel,
63
+ RBLNPriorTransformer,
58
64
  RBLNSD3Transformer2DModel,
59
65
  RBLNUNet2DConditionModel,
66
+ RBLNVQModel,
60
67
  )
61
68
  from .pipelines import (
69
+ RBLNKandinskyV22InpaintCombinedPipeline,
70
+ RBLNKandinskyV22InpaintPipeline,
71
+ RBLNKandinskyV22PriorPipeline,
62
72
  RBLNMultiControlNetModel,
63
73
  RBLNStableDiffusion3Img2ImgPipeline,
64
74
  RBLNStableDiffusion3InpaintPipeline,
optimum/rbln/diffusers/modeling_diffusers.py CHANGED
@@ -23,6 +23,7 @@ from ..modeling import RBLNModel
23
23
  from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
24
24
  from ..utils.decorator_utils import remove_compile_time_kwargs
25
25
  from ..utils.logging import get_logger
26
+ from . import pipelines
26
27
 
27
28
 
28
29
  logger = get_logger(__name__)
@@ -67,6 +68,7 @@ class RBLNDiffusionMixin:
67
68
  """
68
69
 
69
70
  _submodules = []
71
+ _prefix = {}
70
72
 
71
73
  @classmethod
72
74
  @property
@@ -84,25 +86,58 @@ class RBLNDiffusionMixin:
84
86
  ) -> Dict[str, Any]:
85
87
  submodule = getattr(model, submodule_name)
86
88
  submodule_class_name = submodule.__class__.__name__
89
+ if isinstance(submodule, torch.nn.Module):
90
+ if submodule_class_name == "MultiControlNetModel":
91
+ submodule_class_name = "ControlNetModel"
87
92
 
88
- if submodule_class_name == "MultiControlNetModel":
89
- submodule_class_name = "ControlNetModel"
93
+ submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
90
94
 
91
- submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
95
+ submodule_config = rbln_config.get(submodule_name, {})
96
+ submodule_config = copy.deepcopy(submodule_config)
92
97
 
93
- submodule_config = rbln_config.get(submodule_name, {})
94
- submodule_config = copy.deepcopy(submodule_config)
98
+ pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
95
99
 
96
- pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
100
+ submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
101
+ submodule_config.update(
102
+ {
103
+ "img2img_pipeline": cls.img2img_pipeline,
104
+ "inpaint_pipeline": cls.inpaint_pipeline,
105
+ }
106
+ )
107
+ submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
108
+ elif hasattr(pipelines, submodule_class_name):
109
+ submodule_config = rbln_config.get(submodule_name, {})
110
+ submodule_config = copy.deepcopy(submodule_config)
111
+
112
+ submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"{submodule_class_name}")
113
+ prefix = cls._prefix.get(submodule_name, "")
114
+ connected_submodules = cls._connected_classes.get(submodule_name)._submodules
115
+ pipe_global_config = {k: v for k, v in submodule_config.items() if k not in connected_submodules}
116
+ submodule_config = {k: v for k, v in submodule_config.items() if k in connected_submodules}
117
+ for key in submodule_config.keys():
118
+ submodule_config[key].update(pipe_global_config)
119
+
120
+ for connected_submodule_name in connected_submodules:
121
+ connected_submodule_config = rbln_config.pop(prefix + connected_submodule_name, {})
122
+ if connected_submodule_name in submodule_config:
123
+ submodule_config[connected_submodule_name].update(connected_submodule_config)
124
+ else:
125
+ submodule_config[connected_submodule_name] = connected_submodule_config
97
126
 
98
- submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
99
- submodule_config.update(
100
- {
101
- "img2img_pipeline": cls.img2img_pipeline,
102
- "inpaint_pipeline": cls.inpaint_pipeline,
127
+ pipe_global_config = {
128
+ k: v for k, v in rbln_config.items() if k != submodule_class_name and not isinstance(v, dict)
103
129
  }
104
- )
105
- submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
130
+
131
+ for connected_submodule_name in connected_submodules:
132
+ for k, v in pipe_global_config.items():
133
+ if "guidance_scale" in k:
134
+ if prefix + "guidance_scale" == k:
135
+ submodule_config[connected_submodule_name]["guidance_scale"] = v
136
+ else:
137
+ submodule_config[connected_submodule_name][k] = v
138
+ rbln_config[submodule_name] = submodule_config
139
+ else:
140
+ raise ValueError(f"submodule {submodule_name} isn't supported")
106
141
  return submodule_config
107
142
 
108
143
  @staticmethod
@@ -165,8 +200,26 @@ class RBLNDiffusionMixin:
165
200
 
166
201
  else:
167
202
  # raise error if any of submodules are torch module.
168
- model_index_config = None
169
- for submodule_name in cls._submodules:
203
+ model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
204
+ if cls._load_connected_pipes:
205
+ submodules = []
206
+ for submodule in cls._submodules:
207
+ submodule_config = rbln_config.pop(submodule, {})
208
+ prefix = cls._prefix.get(submodule, "")
209
+ connected_submodules = cls._connected_classes.get(submodule)._submodules
210
+ for connected_submodule_name in connected_submodules:
211
+ connected_submodule_config = submodule_config.pop(connected_submodule_name, {})
212
+ if connected_submodule_config:
213
+ rbln_config[prefix + connected_submodule_name] = connected_submodule_config
214
+ submodules.append(prefix + connected_submodule_name)
215
+ pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
216
+ for submodule in submodules:
217
+ if submodule in rbln_config:
218
+ rbln_config[submodule].update(pipe_global_config)
219
+ else:
220
+ submodules = cls._submodules
221
+
222
+ for submodule_name in submodules:
170
223
  if isinstance(kwargs.get(submodule_name), torch.nn.Module):
171
224
  raise AssertionError(
172
225
  f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
@@ -181,9 +234,6 @@ class RBLNDiffusionMixin:
181
234
  if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
182
235
  continue
183
236
 
184
- if model_index_config is None:
185
- model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
186
-
187
237
  module_name, class_name = model_index_config[submodule_name]
188
238
  if module_name != "optimum.rbln":
189
239
  raise ValueError(
@@ -228,6 +278,7 @@ class RBLNDiffusionMixin:
228
278
  passed_submodules: Dict[str, RBLNModel],
229
279
  model_save_dir: Optional[PathLike],
230
280
  rbln_config: Dict[str, Any],
281
+ prefix: Optional[str] = "",
231
282
  ) -> Dict[str, RBLNModel]:
232
283
  compiled_submodules = {}
233
284
 
@@ -245,17 +296,54 @@ class RBLNDiffusionMixin:
245
296
  controlnets=submodule,
246
297
  model_save_dir=model_save_dir,
247
298
  controlnet_rbln_config=submodule_rbln_config,
299
+ prefix=prefix,
248
300
  )
249
301
  elif isinstance(submodule, torch.nn.Module):
250
302
  submodule_cls: RBLNModel = getattr(
251
303
  importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
252
304
  )
305
+ subfolder = prefix + submodule_name
253
306
  submodule = submodule_cls.from_model(
254
307
  model=submodule,
255
- subfolder=submodule_name,
308
+ subfolder=subfolder,
256
309
  model_save_dir=model_save_dir,
257
310
  rbln_config=submodule_rbln_config,
258
311
  )
312
+ elif hasattr(pipelines, submodule.__class__.__name__):
313
+ connected_pipe = submodule
314
+ connected_pipe_model_save_dir = model_save_dir
315
+ connected_pipe_rbln_config = submodule_rbln_config
316
+ connected_pipe_cls: RBLNDiffusionMixin = getattr(
317
+ importlib.import_module("optimum.rbln"), connected_pipe.__class__.__name__
318
+ )
319
+ submodule_dict = {}
320
+ for name in connected_pipe.config.keys():
321
+ if hasattr(connected_pipe, name):
322
+ submodule_dict[name] = getattr(connected_pipe, name)
323
+ connected_pipe = connected_pipe_cls(**submodule_dict)
324
+ connected_pipe_submodules = {}
325
+ prefix = cls._prefix.get(submodule_name, "")
326
+ for name in connected_pipe_cls._submodules:
327
+ if prefix + name in passed_submodules:
328
+ connected_pipe_submodules[name] = passed_submodules.get(prefix + name)
329
+
330
+ connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
331
+ model=connected_pipe,
332
+ passed_submodules=connected_pipe_submodules,
333
+ model_save_dir=model_save_dir,
334
+ rbln_config=connected_pipe_rbln_config,
335
+ prefix=prefix,
336
+ )
337
+ connected_pipe = connected_pipe_cls._construct_pipe(
338
+ connected_pipe,
339
+ connected_pipe_compiled_submodules,
340
+ connected_pipe_model_save_dir,
341
+ connected_pipe_rbln_config,
342
+ )
343
+
344
+ for name in connected_pipe_cls._submodules:
345
+ compiled_submodules[prefix + name] = getattr(connected_pipe, name)
346
+ submodule = connected_pipe
259
347
  else:
260
348
  raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
261
349
 
@@ -268,6 +356,7 @@ class RBLNDiffusionMixin:
268
356
  controlnets: "MultiControlNetModel",
269
357
  model_save_dir: Optional[PathLike],
270
358
  controlnet_rbln_config: Dict[str, Any],
359
+ prefix: Optional[str] = "",
271
360
  ):
272
361
  # Compile multiple ControlNet models for a MultiControlNet setup
273
362
  from .models.controlnet import RBLNControlNetModel
@@ -276,7 +365,7 @@ class RBLNDiffusionMixin:
276
365
  compiled_controlnets = [
277
366
  RBLNControlNetModel.from_model(
278
367
  model=controlnet,
279
- subfolder="controlnet" if i == 0 else f"controlnet_{i}",
368
+ subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
280
369
  model_save_dir=model_save_dir,
281
370
  rbln_config=controlnet_rbln_config,
282
371
  )
@@ -287,10 +376,21 @@ class RBLNDiffusionMixin:
287
376
  @classmethod
288
377
  def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
289
378
  # Construct finalize pipe setup with compiled submodules and configurations
379
+ submodule_names = []
380
+ for submodule_name in cls._submodules:
381
+ submodule = getattr(model, submodule_name)
382
+ if hasattr(pipelines, submodule.__class__.__name__):
383
+ prefix = cls._prefix.get(submodule_name, "")
384
+ connected_pipe_submodules = submodules[submodule_name].__class__._submodules
385
+ connected_pipe_submodules = [prefix + name for name in connected_pipe_submodules]
386
+ submodule_names += connected_pipe_submodules
387
+ setattr(model, submodule_name, submodules[submodule_name])
388
+ else:
389
+ submodule_names.append(submodule_name)
290
390
 
291
391
  if model_save_dir is not None:
292
392
  # To skip saving original pytorch modules
293
- for submodule_name in cls._submodules:
393
+ for submodule_name in submodule_names:
294
394
  delattr(model, submodule_name)
295
395
 
296
396
  # Direct calling of `save_pretrained` causes config.unet = (None, None).
@@ -300,7 +400,7 @@ class RBLNDiffusionMixin:
300
400
  # Causing warning messeages.
301
401
 
302
402
  update_dict = {}
303
- for submodule_name in cls._submodules:
403
+ for submodule_name in submodule_names:
304
404
  # replace submodule
305
405
  setattr(model, submodule_name, submodules[submodule_name])
306
406
  update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
@@ -322,9 +422,16 @@ class RBLNDiffusionMixin:
322
422
  if rbln_config.get("optimize_host_memory") is False:
323
423
  # Keep compiled_model objs to further analysis. -> TODO: remove soon...
324
424
  model.compiled_models = []
325
- for name in cls._submodules:
326
- submodule = getattr(model, name)
327
- model.compiled_models.extend(submodule.compiled_models)
425
+ if model._load_connected_pipes:
426
+ for name in cls._submodules:
427
+ connected_pipe = getattr(model, name)
428
+ for submodule_name in connected_pipe.__class__._submodules:
429
+ submodule = getattr(connected_pipe, submodule_name)
430
+ model.compiled_models.extend(submodule.compiled_models)
431
+ else:
432
+ for name in cls._submodules:
433
+ submodule = getattr(model, name)
434
+ model.compiled_models.extend(submodule.compiled_models)
328
435
 
329
436
  return model
330
437
 
optimum/rbln/diffusers/models/__init__.py CHANGED
@@ -20,20 +20,26 @@ from transformers.utils import _LazyModule
20
20
  _import_structure = {
21
21
  "autoencoders": [
22
22
  "RBLNAutoencoderKL",
23
+ "RBLNVQModel",
23
24
  ],
24
25
  "unets": [
25
26
  "RBLNUNet2DConditionModel",
26
27
  ],
27
28
  "controlnet": ["RBLNControlNetModel"],
28
- "transformers": ["RBLNSD3Transformer2DModel"],
29
+ "transformers": [
30
+ "RBLNPriorTransformer",
31
+ "RBLNSD3Transformer2DModel",
32
+ ],
29
33
  }
30
34
 
31
35
  if TYPE_CHECKING:
32
36
  from .autoencoders import (
33
37
  RBLNAutoencoderKL,
38
+ RBLNVQModel,
34
39
  )
35
40
  from .controlnet import RBLNControlNetModel
36
41
  from .transformers import (
42
+ RBLNPriorTransformer,
37
43
  RBLNSD3Transformer2DModel,
38
44
  )
39
45
  from .unets import (
optimum/rbln/diffusers/models/autoencoders/__init__.py CHANGED
@@ -13,3 +13,4 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from .autoencoder_kl import RBLNAutoencoderKL
16
+ from .vq_model import RBLNVQModel
optimum/rbln/diffusers/models/autoencoders/vae.py CHANGED
@@ -12,11 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING
15
+ from typing import TYPE_CHECKING, List
16
16
 
17
17
  import torch # noqa: I001
18
- from diffusers import AutoencoderKL
18
+ from diffusers import AutoencoderKL, VQModel
19
19
  from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
20
+ from diffusers.models.autoencoders.vq_model import VQEncoderOutput
20
21
  from diffusers.models.modeling_outputs import AutoencoderKLOutput
21
22
 
22
23
  from ....utils.logging import get_logger
@@ -72,3 +73,52 @@ class _VAEEncoder(torch.nn.Module):
72
73
  def forward(self, x):
73
74
  vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
74
75
  return vae_out
76
+
77
+
78
+ class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
79
+ def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
80
+ h = self.forward(x.contiguous())
81
+ return VQEncoderOutput(latents=h)
82
+
83
+
84
+ class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
85
+ def decode(self, h: torch.Tensor, force_not_quantize: bool = False, shape=None, **kwargs) -> List[torch.Tensor]:
86
+ if not (force_not_quantize and not self.lookup_from_codebook):
87
+ raise ValueError(
88
+ "Currently, the `decode` method of the class `RBLNVQModel` is executed successfully only if `force_not_quantize` is True and `config.lookup_from_codebook` is False"
89
+ )
90
+ commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
91
+ dec = self.forward(h.contiguous())
92
+ return dec, commit_loss
93
+
94
+
95
+ class _VQEncoder(torch.nn.Module):
96
+ def __init__(self, vq_model: VQModel):
97
+ super().__init__()
98
+ self.vq_model = vq_model
99
+
100
+ def encode(self, x: torch.Tensor, return_dict: bool = True):
101
+ h = self.vq_model.encoder(x)
102
+ h = self.vq_model.quant_conv(h)
103
+ return h
104
+
105
+ def forward(self, x: torch.Tensor):
106
+ vq_out = self.encode(x)
107
+ return vq_out
108
+
109
+
110
+ class _VQDecoder(torch.nn.Module):
111
+ def __init__(self, vq_model: VQModel):
112
+ super().__init__()
113
+ self.vq_model = vq_model
114
+
115
+ def decode(self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None):
116
+ quant = h
117
+ quant2 = self.vq_model.post_quant_conv(quant)
118
+ quant = quant if self.vq_model.config.norm_type == "spatial" else None
119
+ dec = self.vq_model.decoder(quant2, quant)
120
+ return dec
121
+
122
+ def forward(self, h: torch.Tensor):
123
+ vq_out = self.decode(h)
124
+ return vq_out
optimum/rbln/diffusers/models/autoencoders/vq_model.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
16
+
17
+ import rebel
18
+ import torch
19
+ from diffusers import VQModel
20
+ from diffusers.models.autoencoders.vae import DecoderOutput
21
+ from diffusers.models.autoencoders.vq_model import VQEncoderOutput
22
+ from transformers import PretrainedConfig
23
+
24
+ from ....modeling import RBLNModel
25
+ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
26
+ from ....utils.logging import get_logger
27
+ from ...modeling_diffusers import RBLNDiffusionMixin
28
+ from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
29
+
30
+
31
+ if TYPE_CHECKING:
32
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
33
+
34
+ logger = get_logger(__name__)
35
+
36
+
37
+ class RBLNVQModel(RBLNModel):
38
+ auto_model_class = VQModel
39
+ config_name = "config.json"
40
+ hf_library_name = "diffusers"
41
+
42
+ def __post_init__(self, **kwargs):
43
+ super().__post_init__(**kwargs)
44
+
45
+ self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
46
+ self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[1], main_input_name="z")
47
+ self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
48
+ height = self.rbln_config.model_cfg.get("img_height", 512)
49
+ width = self.rbln_config.model_cfg.get("img_width", 512)
50
+ self.image_size = [height, width]
51
+
52
+ @classmethod
53
+ def get_compiled_model(cls, model, rbln_config: RBLNConfig):
54
+ encoder_model = _VQEncoder(model)
55
+ decoder_model = _VQDecoder(model)
56
+ encoder_model.eval()
57
+ decoder_model.eval()
58
+
59
+ enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
60
+ dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
61
+
62
+ return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
63
+
64
+ @classmethod
65
+ def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
66
+ batch_size = rbln_config.get("batch_size")
67
+ if batch_size is None:
68
+ batch_size = 1
69
+ img_height = rbln_config.get("img_height")
70
+ if img_height is None:
71
+ img_height = 512
72
+ img_width = rbln_config.get("img_width")
73
+ if img_width is None:
74
+ img_width = 512
75
+
76
+ rbln_config.update(
77
+ {
78
+ "batch_size": batch_size,
79
+ "img_height": img_height,
80
+ "img_width": img_width,
81
+ }
82
+ )
83
+
84
+ return rbln_config
85
+
86
+ @classmethod
87
+ def _get_rbln_config(
88
+ cls,
89
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
90
+ model_config: "PretrainedConfig",
91
+ rbln_kwargs: Dict[str, Any] = {},
92
+ ) -> RBLNConfig:
93
+ batch_size = rbln_kwargs.get("batch_size") or 1
94
+ height = rbln_kwargs.get("img_height") or 512
95
+ width = rbln_kwargs.get("img_width") or 512
96
+
97
+ if hasattr(model_config, "block_out_channels"):
98
+ scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
99
+ else:
100
+ # image processor default value 8 (int)
101
+ scale_factor = 8
102
+
103
+ enc_shape = (height, width)
104
+ dec_shape = (height // scale_factor, width // scale_factor)
105
+
106
+ enc_input_info = [
107
+ (
108
+ "x",
109
+ [batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
110
+ "float32",
111
+ )
112
+ ]
113
+ dec_input_info = [
114
+ (
115
+ "h",
116
+ [batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
117
+ "float32",
118
+ )
119
+ ]
120
+
121
+ enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
122
+ dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
123
+
124
+ compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
125
+ rbln_config = RBLNConfig(
126
+ rbln_cls=cls.__name__,
127
+ compile_cfgs=compile_cfgs,
128
+ rbln_kwargs=rbln_kwargs,
129
+ )
130
+ return rbln_config
131
+
132
+ @classmethod
133
+ def _create_runtimes(
134
+ cls,
135
+ compiled_models: List[rebel.RBLNCompiledModel],
136
+ rbln_device_map: Dict[str, int],
137
+ activate_profiler: Optional[bool] = None,
138
+ ) -> List[rebel.Runtime]:
139
+ if len(compiled_models) == 1:
140
+ device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
141
+ return [
142
+ compiled_models[0].create_runtime(
143
+ tensor_type="pt", device=device_val, activate_profiler=activate_profiler
144
+ )
145
+ ]
146
+
147
+ device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
148
+ return [
149
+ compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
150
+ for compiled_model, device_val in zip(compiled_models, device_vals)
151
+ ]
152
+
153
+ def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
154
+ posterior = self.encoder.encode(x)
155
+ return VQEncoderOutput(latents=posterior)
156
+
157
+ def decode(self, h: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
158
+ dec, commit_loss = self.decoder.decode(h, **kwargs)
159
+ return DecoderOutput(sample=dec, commit_loss=commit_loss)
optimum/rbln/diffusers/models/transformers/__init__.py CHANGED
@@ -12,4 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from .prior_transformer import RBLNPriorTransformer
15
16
  from .transformer_sd3 import RBLNSD3Transformer2DModel