optimum-rbln 0.2.1a4__py3-none-any.whl → 0.7.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. optimum/rbln/__init__.py +14 -2
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +10 -0
  4. optimum/rbln/diffusers/modeling_diffusers.py +115 -23
  5. optimum/rbln/diffusers/models/__init__.py +7 -1
  6. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +52 -2
  8. optimum/rbln/diffusers/models/autoencoders/vq_model.py +159 -0
  9. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  10. optimum/rbln/diffusers/models/transformers/prior_transformer.py +174 -0
  11. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +57 -14
  12. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  13. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +17 -0
  14. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +83 -0
  15. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +22 -0
  16. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +22 -0
  17. optimum/rbln/modeling_base.py +10 -9
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/models/__init__.py +12 -2
  20. optimum/rbln/transformers/models/clip/__init__.py +6 -1
  21. optimum/rbln/transformers/models/clip/modeling_clip.py +26 -1
  22. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
  23. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
  24. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/METADATA +1 -1
  25. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/RECORD +27 -21
  26. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/WHEEL +0 -0
  27. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -48,6 +48,7 @@ _import_structure = {
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
         "RBLNCLIPVisionModel",
+        "RBLNCLIPVisionModelWithProjection",
         "RBLNDPTForDepthEstimation",
         "RBLNExaoneForCausalLM",
         "RBLNGemmaForCausalLM",
@@ -74,11 +75,15 @@ _import_structure = {
         "RBLNBertForMaskedLM",
     ],
     "diffusers": [
+        "RBLNAutoencoderKL",
+        "RBLNControlNetModel",
+        "RBLNPriorTransformer",
+        "RBLNKandinskyV22InpaintCombinedPipeline",
+        "RBLNKandinskyV22InpaintPipeline",
+        "RBLNKandinskyV22PriorPipeline",
         "RBLNStableDiffusionPipeline",
         "RBLNStableDiffusionXLPipeline",
-        "RBLNAutoencoderKL",
         "RBLNUNet2DConditionModel",
-        "RBLNControlNetModel",
         "RBLNStableDiffusionImg2ImgPipeline",
         "RBLNStableDiffusionInpaintPipeline",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
@@ -88,6 +93,7 @@ _import_structure = {
         "RBLNStableDiffusionControlNetPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNVQModel",
         "RBLNSD3Transformer2DModel",
         "RBLNStableDiffusion3Img2ImgPipeline",
         "RBLNStableDiffusion3InpaintPipeline",
@@ -101,7 +107,11 @@ if TYPE_CHECKING:
         RBLNAutoencoderKL,
         RBLNControlNetModel,
         RBLNDiffusionMixin,
+        RBLNKandinskyV22InpaintCombinedPipeline,
+        RBLNKandinskyV22InpaintPipeline,
+        RBLNKandinskyV22PriorPipeline,
         RBLNMultiControlNetModel,
+        RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
         RBLNStableDiffusion3Img2ImgPipeline,
         RBLNStableDiffusion3InpaintPipeline,
@@ -117,6 +127,7 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
         RBLNUNet2DConditionModel,
+        RBLNVQModel,
     )
     from .modeling import (
         RBLNBaseModel,
@@ -148,6 +159,7 @@ if TYPE_CHECKING:
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNCLIPVisionModel,
+        RBLNCLIPVisionModelWithProjection,
         RBLNDistilBertForQuestionAnswering,
         RBLNDPTForDepthEstimation,
         RBLNExaoneForCausalLM,
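
Note: the classes added above become importable directly from `optimum.rbln`. The snippet below is a minimal usage sketch for one of the new exports; the checkpoint ID and the `rbln_img_height`/`rbln_img_width` keyword arguments are illustrative assumptions, not taken from this diff.

# Hypothetical usage sketch of a newly exported class (not part of the diff).
from optimum.rbln import RBLNKandinskyV22InpaintCombinedPipeline

# export=True compiles the pipeline's submodules for RBLN NPUs; the model ID
# and image sizes are assumptions for illustration only.
pipe = RBLNKandinskyV22InpaintCombinedPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder-inpaint",
    export=True,
    rbln_img_height=512,
    rbln_img_width=512,
)
pipe.save_pretrained("kandinsky-2-2-inpaint-rbln")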
optimum/rbln/__version__.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.1a4'
-__version_tuple__ = version_tuple = (0, 2, 1)
+__version__ = version = '0.7.2rc0'
+__version_tuple__ = version_tuple = (0, 7, 2)
optimum/rbln/diffusers/__init__.py CHANGED
@@ -24,6 +24,9 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
 
 _import_structure = {
     "pipelines": [
+        "RBLNKandinskyV22InpaintCombinedPipeline",
+        "RBLNKandinskyV22InpaintPipeline",
+        "RBLNKandinskyV22PriorPipeline",
         "RBLNStableDiffusionPipeline",
         "RBLNStableDiffusionXLPipeline",
         "RBLNStableDiffusionImg2ImgPipeline",
@@ -44,6 +47,8 @@ _import_structure = {
         "RBLNUNet2DConditionModel",
         "RBLNControlNetModel",
         "RBLNSD3Transformer2DModel",
+        "RBLNPriorTransformer",
+        "RBLNVQModel",
     ],
     "modeling_diffusers": [
         "RBLNDiffusionMixin",
@@ -55,10 +60,15 @@ if TYPE_CHECKING:
     from .models import (
         RBLNAutoencoderKL,
         RBLNControlNetModel,
+        RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
         RBLNUNet2DConditionModel,
+        RBLNVQModel,
     )
     from .pipelines import (
+        RBLNKandinskyV22InpaintCombinedPipeline,
+        RBLNKandinskyV22InpaintPipeline,
+        RBLNKandinskyV22PriorPipeline,
         RBLNMultiControlNetModel,
         RBLNStableDiffusion3Img2ImgPipeline,
         RBLNStableDiffusion3InpaintPipeline,
optimum/rbln/diffusers/modeling_diffusers.py CHANGED
@@ -23,6 +23,7 @@ from ..modeling import RBLNModel
 from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
 from ..utils.decorator_utils import remove_compile_time_kwargs
 from ..utils.logging import get_logger
+from . import pipelines
 
 
 logger = get_logger(__name__)
@@ -67,6 +68,7 @@ class RBLNDiffusionMixin:
     """
 
     _submodules = []
+    _prefix = {}
 
     @classmethod
     @property
@@ -84,25 +86,50 @@ class RBLNDiffusionMixin:
     ) -> Dict[str, Any]:
         submodule = getattr(model, submodule_name)
         submodule_class_name = submodule.__class__.__name__
+        if isinstance(submodule, torch.nn.Module):
+            if submodule_class_name == "MultiControlNetModel":
+                submodule_class_name = "ControlNetModel"
 
-        if submodule_class_name == "MultiControlNetModel":
-            submodule_class_name = "ControlNetModel"
+            submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
 
-        submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
+            submodule_config = rbln_config.get(submodule_name, {})
+            submodule_config = copy.deepcopy(submodule_config)
 
-        submodule_config = rbln_config.get(submodule_name, {})
-        submodule_config = copy.deepcopy(submodule_config)
+            pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
 
-        pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
+            submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
+            submodule_config.update(
+                {
+                    "img2img_pipeline": cls.img2img_pipeline,
+                    "inpaint_pipeline": cls.inpaint_pipeline,
+                }
+            )
+            submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
+        elif hasattr(pipelines, submodule_class_name):
+            submodule_config = rbln_config.get(submodule_name, {})
+            submodule_config = copy.deepcopy(submodule_config)
+
+            submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"{submodule_class_name}")
+            prefix = cls._prefix.get(submodule_name, "")
+            connected_submodules = cls._connected_classes.get(submodule_name)._submodules
+            for connected_submodule_name in connected_submodules:
+                connected_submodule_config = rbln_config.pop(prefix + connected_submodule_name, {})
+                if connected_submodule_name in submodule_config:
+                    submodule_config[connected_submodule_name].update(connected_submodule_config)
+                else:
+                    submodule_config[connected_submodule_name] = connected_submodule_config
 
-        submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
-        submodule_config.update(
-            {
-                "img2img_pipeline": cls.img2img_pipeline,
-                "inpaint_pipeline": cls.inpaint_pipeline,
-            }
-        )
-        submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
+            submodules = copy.deepcopy(cls._submodules)
+            submodules += [prefix + connected_submodule_name for connected_submodule_name in connected_submodules]
+
+            pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
+            for connected_submodule_name in connected_submodules:
+                submodule_config[connected_submodule_name].update(
+                    {k: v for k, v in pipe_global_config.items() if k not in submodule_config}
+                )
+            rbln_config[submodule_name] = submodule_config
+        else:
+            raise ValueError(f"submodule {submodule_name} isn't supported")
         return submodule_config
 
     @staticmethod
@@ -165,8 +192,26 @@ class RBLNDiffusionMixin:
 
         else:
             # raise error if any of submodules are torch module.
-            model_index_config = None
-            for submodule_name in cls._submodules:
+            model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
+            if cls._load_connected_pipes:
+                submodules = []
+                for submodule in cls._submodules:
+                    submodule_config = rbln_config.pop(submodule, {})
+                    prefix = cls._prefix.get(submodule, "")
+                    connected_submodules = cls._connected_classes.get(submodule)._submodules
+                    for connected_submodule_name in connected_submodules:
+                        connected_submodule_config = submodule_config.pop(connected_submodule_name, {})
+                        if connected_submodule_config:
+                            rbln_config[prefix + connected_submodule_name] = connected_submodule_config
+                        submodules.append(prefix + connected_submodule_name)
+                pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
+                for submodule in submodules:
+                    if submodule in rbln_config:
+                        rbln_config[submodule].update(pipe_global_config)
+            else:
+                submodules = cls._submodules
+
+            for submodule_name in submodules:
                 if isinstance(kwargs.get(submodule_name), torch.nn.Module):
                     raise AssertionError(
                         f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
@@ -181,9 +226,6 @@ class RBLNDiffusionMixin:
                 if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
                     continue
 
-                if model_index_config is None:
-                    model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
-
                 module_name, class_name = model_index_config[submodule_name]
                 if module_name != "optimum.rbln":
                     raise ValueError(
@@ -228,6 +270,7 @@ class RBLNDiffusionMixin:
         passed_submodules: Dict[str, RBLNModel],
         model_save_dir: Optional[PathLike],
         rbln_config: Dict[str, Any],
+        prefix: Optional[str] = "",
     ) -> Dict[str, RBLNModel]:
         compiled_submodules = {}
 
@@ -245,17 +288,54 @@ class RBLNDiffusionMixin:
                     controlnets=submodule,
                     model_save_dir=model_save_dir,
                     controlnet_rbln_config=submodule_rbln_config,
+                    prefix=prefix,
                 )
             elif isinstance(submodule, torch.nn.Module):
                 submodule_cls: RBLNModel = getattr(
                     importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
                 )
+                subfolder = prefix + submodule_name
                 submodule = submodule_cls.from_model(
                     model=submodule,
-                    subfolder=submodule_name,
+                    subfolder=subfolder,
                     model_save_dir=model_save_dir,
                     rbln_config=submodule_rbln_config,
                 )
+            elif hasattr(pipelines, submodule.__class__.__name__):
+                connected_pipe = submodule
+                connected_pipe_model_save_dir = model_save_dir
+                connected_pipe_rbln_config = submodule_rbln_config
+                connected_pipe_cls: RBLNDiffusionMixin = getattr(
+                    importlib.import_module("optimum.rbln"), connected_pipe.__class__.__name__
+                )
+                submodule_dict = {}
+                for name in connected_pipe.config.keys():
+                    if hasattr(connected_pipe, name):
+                        submodule_dict[name] = getattr(connected_pipe, name)
+                connected_pipe = connected_pipe_cls(**submodule_dict)
+                connected_pipe_submodules = {}
+                prefix = cls._prefix.get(submodule_name, "")
+                for name in connected_pipe_cls._submodules:
+                    if prefix + name in passed_submodules:
+                        connected_pipe_submodules[name] = passed_submodules.get(prefix + name)
+
+                connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
+                    model=connected_pipe,
+                    passed_submodules=connected_pipe_submodules,
+                    model_save_dir=model_save_dir,
+                    rbln_config=connected_pipe_rbln_config,
+                    prefix=prefix,
+                )
+                connected_pipe = connected_pipe_cls._construct_pipe(
+                    connected_pipe,
+                    connected_pipe_compiled_submodules,
+                    connected_pipe_model_save_dir,
+                    connected_pipe_rbln_config,
+                )
+
+                for name in connected_pipe_cls._submodules:
+                    compiled_submodules[prefix + name] = getattr(connected_pipe, name)
+                submodule = connected_pipe
             else:
                 raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
 
@@ -268,6 +348,7 @@ class RBLNDiffusionMixin:
         controlnets: "MultiControlNetModel",
         model_save_dir: Optional[PathLike],
         controlnet_rbln_config: Dict[str, Any],
+        prefix: Optional[str] = "",
     ):
         # Compile multiple ControlNet models for a MultiControlNet setup
         from .models.controlnet import RBLNControlNetModel
@@ -276,7 +357,7 @@ class RBLNDiffusionMixin:
         compiled_controlnets = [
             RBLNControlNetModel.from_model(
                 model=controlnet,
-                subfolder="controlnet" if i == 0 else f"controlnet_{i}",
+                subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
                 model_save_dir=model_save_dir,
                 rbln_config=controlnet_rbln_config,
             )
@@ -287,10 +368,21 @@ class RBLNDiffusionMixin:
     @classmethod
     def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
         # Construct finalize pipe setup with compiled submodules and configurations
+        submodule_names = []
+        for submodule_name in cls._submodules:
+            submodule = getattr(model, submodule_name)
+            if hasattr(pipelines, submodule.__class__.__name__):
+                prefix = cls._prefix.get(submodule_name, "")
+                connected_pipe_submodules = submodules[submodule_name].__class__._submodules
+                connected_pipe_submodules = [prefix + name for name in connected_pipe_submodules]
+                submodule_names += connected_pipe_submodules
+                setattr(model, submodule_name, submodules[submodule_name])
+            else:
+                submodule_names.append(submodule_name)
 
         if model_save_dir is not None:
             # To skip saving original pytorch modules
-            for submodule_name in cls._submodules:
+            for submodule_name in submodule_names:
                 delattr(model, submodule_name)
 
             # Direct calling of `save_pretrained` causes config.unet = (None, None).
@@ -300,7 +392,7 @@ class RBLNDiffusionMixin:
             # Causing warning messeages.
 
         update_dict = {}
-        for submodule_name in cls._submodules:
+        for submodule_name in submodule_names:
             # replace submodule
             setattr(model, submodule_name, submodules[submodule_name])
             update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
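
Note: the changes above let a pipeline list another pipeline (a "connected pipe", e.g. the Kandinsky prior) as a submodule, with `cls._prefix` used to namespace the connected pipe's own submodules in the user-facing `rbln_config` and in the saved subfolders. Below is a rough sketch of how such a config might be laid out; the prefix string and submodule names are assumptions for illustration, not taken from this diff.

# Illustrative shape of an rbln_config under the new prefix routing (assumed names).
rbln_config = {
    # pipeline-global keys are merged into every submodule config that does not set them
    "img_height": 512,
    "img_width": 512,
    # a direct submodule of the decoder pipe
    "unet": {"batch_size": 2},
    # a prefixed key, routed to the connected prior pipe's image_encoder submodule
    # (assuming the combined pipeline declares a "prior_" prefix)
    "prior_image_encoder": {"batch_size": 1},
}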
optimum/rbln/diffusers/models/__init__.py CHANGED
@@ -20,20 +20,26 @@ from transformers.utils import _LazyModule
 _import_structure = {
     "autoencoders": [
         "RBLNAutoencoderKL",
+        "RBLNVQModel",
     ],
     "unets": [
         "RBLNUNet2DConditionModel",
     ],
     "controlnet": ["RBLNControlNetModel"],
-    "transformers": ["RBLNSD3Transformer2DModel"],
+    "transformers": [
+        "RBLNPriorTransformer",
+        "RBLNSD3Transformer2DModel",
+    ],
 }
 
 if TYPE_CHECKING:
     from .autoencoders import (
         RBLNAutoencoderKL,
+        RBLNVQModel,
     )
     from .controlnet import RBLNControlNetModel
     from .transformers import (
+        RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
     )
     from .unets import (
optimum/rbln/diffusers/models/autoencoders/__init__.py CHANGED
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .autoencoder_kl import RBLNAutoencoderKL
+from .vq_model import RBLNVQModel
optimum/rbln/diffusers/models/autoencoders/vae.py CHANGED
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List
 
 import torch  # noqa: I001
-from diffusers import AutoencoderKL
+from diffusers import AutoencoderKL, VQModel
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.models.autoencoders.vq_model import VQEncoderOutput
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
 
 from ....utils.logging import get_logger
@@ -72,3 +73,52 @@ class _VAEEncoder(torch.nn.Module):
     def forward(self, x):
         vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
         return vae_out
+
+
+class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
+    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        h = self.forward(x.contiguous())
+        return VQEncoderOutput(latents=h)
+
+
+class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
+    def decode(self, h: torch.Tensor, force_not_quantize: bool = False, shape=None, **kwargs) -> List[torch.Tensor]:
+        if not (force_not_quantize and not self.lookup_from_codebook):
+            raise ValueError(
+                "Currently, the `decode` method of the class `RBLNVQModel` is executed successfully only if `force_not_quantize` is True and `config.lookup_from_codebook` is False"
+            )
+        commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        dec = self.forward(h.contiguous())
+        return dec, commit_loss
+
+
+class _VQEncoder(torch.nn.Module):
+    def __init__(self, vq_model: VQModel):
+        super().__init__()
+        self.vq_model = vq_model
+
+    def encode(self, x: torch.Tensor, return_dict: bool = True):
+        h = self.vq_model.encoder(x)
+        h = self.vq_model.quant_conv(h)
+        return h
+
+    def forward(self, x: torch.Tensor):
+        vq_out = self.encode(x)
+        return vq_out
+
+
+class _VQDecoder(torch.nn.Module):
+    def __init__(self, vq_model: VQModel):
+        super().__init__()
+        self.vq_model = vq_model
+
+    def decode(self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None):
+        quant = h
+        quant2 = self.vq_model.post_quant_conv(quant)
+        quant = quant if self.vq_model.config.norm_type == "spatial" else None
+        dec = self.vq_model.decoder(quant2, quant)
+        return dec
+
+    def forward(self, h: torch.Tensor):
+        vq_out = self.decode(h)
+        return vq_out
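
Note: the `_VQEncoder`/`_VQDecoder` wrappers above split a diffusers `VQModel` into two separately traceable graphs (encoder plus `quant_conv`, and `post_quant_conv` plus decoder). A minimal eager sketch of that split, using a default-configured `VQModel` purely for illustration; shapes and the import path are assumptions:

# Minimal sketch of the encode/decode split; not part of the diff.
import torch
from diffusers import VQModel

from optimum.rbln.diffusers.models.autoencoders.vae import _VQDecoder, _VQEncoder

vq = VQModel()  # default config, illustration only
enc, dec = _VQEncoder(vq), _VQDecoder(vq)

with torch.no_grad():
    latents = enc(torch.randn(1, vq.config.in_channels, 64, 64))  # encoder + quant_conv
    image = dec(latents)                                          # post_quant_conv + decoder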
optimum/rbln/diffusers/models/autoencoders/vq_model.py ADDED
@@ -0,0 +1,159 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import rebel
+import torch
+from diffusers import VQModel
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.autoencoders.vq_model import VQEncoderOutput
+from transformers import PretrainedConfig
+
+from ....modeling import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
+from ...modeling_diffusers import RBLNDiffusionMixin
+from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+logger = get_logger(__name__)
+
+
+class RBLNVQModel(RBLNModel):
+    auto_model_class = VQModel
+    config_name = "config.json"
+    hf_library_name = "diffusers"
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+
+        self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
+        self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[1], main_input_name="z")
+        self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
+        height = self.rbln_config.model_cfg.get("img_height", 512)
+        width = self.rbln_config.model_cfg.get("img_width", 512)
+        self.image_size = [height, width]
+
+    @classmethod
+    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+        encoder_model = _VQEncoder(model)
+        decoder_model = _VQDecoder(model)
+        encoder_model.eval()
+        decoder_model.eval()
+
+        enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+        dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
+
+        return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        batch_size = rbln_config.get("batch_size")
+        if batch_size is None:
+            batch_size = 1
+        img_height = rbln_config.get("img_height")
+        if img_height is None:
+            img_height = 512
+        img_width = rbln_config.get("img_width")
+        if img_width is None:
+            img_width = 512
+
+        rbln_config.update(
+            {
+                "batch_size": batch_size,
+                "img_height": img_height,
+                "img_width": img_width,
+            }
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        batch_size = rbln_kwargs.get("batch_size") or 1
+        height = rbln_kwargs.get("img_height") or 512
+        width = rbln_kwargs.get("img_width") or 512
+
+        if hasattr(model_config, "block_out_channels"):
+            scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+        else:
+            # image processor default value 8 (int)
+            scale_factor = 8
+
+        enc_shape = (height, width)
+        dec_shape = (height // scale_factor, width // scale_factor)
+
+        enc_input_info = [
+            (
+                "x",
+                [batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
+                "float32",
+            )
+        ]
+        dec_input_info = [
+            (
+                "h",
+                [batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
+                "float32",
+            )
+        ]
+
+        enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+        compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=compile_cfgs,
+            rbln_kwargs=rbln_kwargs,
+        )
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
+    ) -> List[rebel.Runtime]:
+        if len(compiled_models) == 1:
+            device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+            return [
+                compiled_models[0].create_runtime(
+                    tensor_type="pt", device=device_val, activate_profiler=activate_profiler
+                )
+            ]
+
+        device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
+        return [
+            compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
+            for compiled_model, device_val in zip(compiled_models, device_vals)
+        ]
+
+    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        posterior = self.encoder.encode(x)
+        return VQEncoderOutput(latents=posterior)
+
+    def decode(self, h: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        dec, commit_loss = self.decoder.decode(h, **kwargs)
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)
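
Note: `RBLNVQModel` compiles the encoder and decoder as two separate runtimes and exposes diffusers-style `encode`/`decode` outputs. A minimal compile-and-run sketch follows; the checkpoint ID, `movq` subfolder, and `rbln_*` keyword arguments are illustrative assumptions, not taken from this diff.

# Hypothetical usage sketch for the new RBLNVQModel (names are assumptions).
import torch

from optimum.rbln import RBLNVQModel

movq = RBLNVQModel.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder",  # Kandinsky's MoVQ is a diffusers VQModel
    subfolder="movq",
    export=True,
    rbln_img_height=512,
    rbln_img_width=512,
)

latents = movq.encode(torch.randn(1, 3, 512, 512)).latents
image = movq.decode(latents, force_not_quantize=True).sample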
optimum/rbln/diffusers/models/transformers/__init__.py CHANGED
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .prior_transformer import RBLNPriorTransformer
 from .transformer_sd3 import RBLNSD3Transformer2DModel