optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +8 -1
  68. optimum/rbln/utils/logging.py +38 -1
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/METADATA +0 -106
  79. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  80. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
optimum/rbln/__init__.py CHANGED
@@ -30,24 +30,13 @@ from .utils import check_version_compats
30
30
 
31
31
 
32
32
  _import_structure = {
33
- "modeling_alias": [
34
- "RBLNASTForAudioClassification",
35
- "RBLNBertForQuestionAnswering",
36
- "RBLNDistilBertForQuestionAnswering",
37
- "RBLNResNetForImageClassification",
38
- "RBLNXLMRobertaForSequenceClassification",
39
- "RBLNRobertaForSequenceClassification",
40
- "RBLNRobertaForMaskedLM",
41
- "RBLNViTForImageClassification",
42
- ],
43
33
  "modeling": [
44
34
  "RBLNBaseModel",
45
35
  "RBLNModel",
46
- "RBLNModelForQuestionAnswering",
47
- "RBLNModelForAudioClassification",
48
- "RBLNModelForImageClassification",
49
- "RBLNModelForSequenceClassification",
50
- "RBLNModelForMaskedLM",
36
+ ],
37
+ "modeling_config": [
38
+ "RBLNCompileConfig",
39
+ "RBLNConfig",
51
40
  ],
52
41
  "transformers": [
53
42
  "RBLNAutoModel",
@@ -83,6 +72,14 @@ _import_structure = {
83
72
  "RBLNMistralForCausalLM",
84
73
  "RBLNWhisperForConditionalGeneration",
85
74
  "RBLNXLMRobertaModel",
75
+ "RBLNASTForAudioClassification",
76
+ "RBLNBertForQuestionAnswering",
77
+ "RBLNDistilBertForQuestionAnswering",
78
+ "RBLNResNetForImageClassification",
79
+ "RBLNXLMRobertaForSequenceClassification",
80
+ "RBLNRobertaForSequenceClassification",
81
+ "RBLNRobertaForMaskedLM",
82
+ "RBLNViTForImageClassification",
86
83
  ],
87
84
  "diffusers": [
88
85
  "RBLNStableDiffusionPipeline",
@@ -103,15 +100,15 @@ _import_structure = {
103
100
  "RBLNStableDiffusion3Img2ImgPipeline",
104
101
  "RBLNStableDiffusion3InpaintPipeline",
105
102
  "RBLNStableDiffusion3Pipeline",
103
+ "RBLNDiffusionMixin",
106
104
  ],
107
- "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
108
- "modeling_diffusers": ["RBLNDiffusionMixin"],
109
105
  }
110
106
 
111
107
  if TYPE_CHECKING:
112
108
  from .diffusers import (
113
109
  RBLNAutoencoderKL,
114
110
  RBLNControlNetModel,
111
+ RBLNDiffusionMixin,
115
112
  RBLNMultiControlNetModel,
116
113
  RBLNSD3Transformer2DModel,
117
114
  RBLNStableDiffusion3Img2ImgPipeline,
@@ -132,25 +129,13 @@ if TYPE_CHECKING:
132
129
  from .modeling import (
133
130
  RBLNBaseModel,
134
131
  RBLNModel,
135
- RBLNModelForAudioClassification,
136
- RBLNModelForImageClassification,
137
- RBLNModelForMaskedLM,
138
- RBLNModelForQuestionAnswering,
139
- RBLNModelForSequenceClassification,
140
132
  )
141
- from .modeling_alias import (
142
- RBLNASTForAudioClassification,
143
- RBLNBertForQuestionAnswering,
144
- RBLNResNetForImageClassification,
145
- RBLNRobertaForMaskedLM,
146
- RBLNRobertaForSequenceClassification,
147
- RBLNT5ForConditionalGeneration,
148
- RBLNViTForImageClassification,
149
- RBLNXLMRobertaForSequenceClassification,
133
+ from .modeling_config import (
134
+ RBLNCompileConfig,
135
+ RBLNConfig,
150
136
  )
151
- from .modeling_config import RBLNCompileConfig, RBLNConfig
152
- from .modeling_diffusers import RBLNDiffusionMixin
153
137
  from .transformers import (
138
+ RBLNASTForAudioClassification,
154
139
  RBLNAutoModel,
155
140
  RBLNAutoModelForAudioClassification,
156
141
  RBLNAutoModelForCausalLM,
@@ -165,10 +150,12 @@ if TYPE_CHECKING:
165
150
  RBLNAutoModelForVision2Seq,
166
151
  RBLNBartForConditionalGeneration,
167
152
  RBLNBartModel,
153
+ RBLNBertForQuestionAnswering,
168
154
  RBLNBertModel,
169
155
  RBLNCLIPTextModel,
170
156
  RBLNCLIPTextModelWithProjection,
171
157
  RBLNCLIPVisionModel,
158
+ RBLNDistilBertForQuestionAnswering,
172
159
  RBLNDPTForDepthEstimation,
173
160
  RBLNExaoneForCausalLM,
174
161
  RBLNGemmaForCausalLM,
@@ -179,12 +166,18 @@ if TYPE_CHECKING:
179
166
  RBLNMistralForCausalLM,
180
167
  RBLNPhiForCausalLM,
181
168
  RBLNQwen2ForCausalLM,
169
+ RBLNResNetForImageClassification,
170
+ RBLNRobertaForMaskedLM,
171
+ RBLNRobertaForSequenceClassification,
182
172
  RBLNT5EncoderModel,
183
173
  RBLNT5ForConditionalGeneration,
174
+ RBLNViTForImageClassification,
184
175
  RBLNWav2Vec2ForCTC,
185
176
  RBLNWhisperForConditionalGeneration,
177
+ RBLNXLMRobertaForSequenceClassification,
186
178
  RBLNXLMRobertaModel,
187
179
  )
180
+
188
181
  else:
189
182
  import sys
190
183
 
optimum/rbln/__version__.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.15'
16
- __version_tuple__ = version_tuple = (0, 1, 15)
15
+ __version__ = version = '0.2.1a0'
16
+ __version_tuple__ = version_tuple = (0, 2, 1)
optimum/rbln/diffusers/__init__.py CHANGED
@@ -54,9 +54,13 @@ _import_structure = {
54
54
  "RBLNControlNetModel",
55
55
  "RBLNSD3Transformer2DModel",
56
56
  ],
57
+ "modeling_diffusers": [
58
+ "RBLNDiffusionMixin",
59
+ ],
57
60
  }
58
61
 
59
62
  if TYPE_CHECKING:
63
+ from .modeling_diffusers import RBLNDiffusionMixin
60
64
  from .models import (
61
65
  RBLNAutoencoderKL,
62
66
  RBLNControlNetModel,
optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} CHANGED
@@ -20,6 +20,7 @@
20
20
  # are the intellectual property of Rebellions Inc. and may not be
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
+
23
24
  import copy
24
25
  import importlib
25
26
  from os import PathLike
@@ -27,10 +28,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
27
28
 
28
29
  import torch
29
30
 
30
- from .modeling import RBLNModel
31
- from .modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
32
- from .utils.decorator_utils import remove_compile_time_kwargs
31
+ from ..modeling import RBLNModel
32
+ from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
33
+ from ..utils.decorator_utils import remove_compile_time_kwargs
34
+ from ..utils.logging import get_logger
35
+
33
36
 
37
+ logger = get_logger(__name__)
34
38
 
35
39
  if TYPE_CHECKING:
36
40
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
@@ -177,30 +181,39 @@ class RBLNDiffusionMixin:
177
181
  f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
178
182
  )
179
183
 
180
- # Load submodule outside if runtime kwargs(e.g. device) is specified.
181
- if submodule_config := rbln_config.get(submodule_name):
182
- if any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
183
- if model_index_config is None:
184
- model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
185
-
186
- module_name, class_name = model_index_config[submodule_name]
187
- if module_name != "optimum.rbln":
188
- raise ValueError(
189
- f"Invalid module_name '{module_name}' found in model_index.json for "
190
- f"submodule '{submodule_name}'. "
191
- "Expected 'optimum.rbln'. Please check the model_index.json configuration."
192
- )
193
- submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
194
- submodule = submodule_cls.from_pretrained(
195
- model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
196
- )
197
- kwargs[submodule_name] = submodule
184
+ submodule_config = rbln_config.get(submodule_name, {})
185
+
186
+ for key, value in rbln_config.items():
187
+ if key in RUNTIME_KEYWORDS and key not in submodule_config:
188
+ submodule_config[key] = value
189
+
190
+ if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
191
+ continue
192
+
193
+ if model_index_config is None:
194
+ model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
195
+
196
+ module_name, class_name = model_index_config[submodule_name]
197
+ if module_name != "optimum.rbln":
198
+ raise ValueError(
199
+ f"Invalid module_name '{module_name}' found in model_index.json for "
200
+ f"submodule '{submodule_name}'. "
201
+ "Expected 'optimum.rbln'. Please check the model_index.json configuration."
202
+ )
203
+
204
+ submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
205
+
206
+ submodule = submodule_cls.from_pretrained(
207
+ model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
208
+ )
209
+ kwargs[submodule_name] = submodule
198
210
 
199
211
  with ContextRblnConfig(
200
212
  device=rbln_config.get("device"),
201
213
  device_map=rbln_config.get("device_map"),
202
214
  create_runtimes=rbln_config.get("create_runtimes"),
203
215
  optimize_host_mem=rbln_config.get("optimize_host_memory"),
216
+ activate_profiler=rbln_config.get("activate_profiler"),
204
217
  ):
205
218
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
206
219
 
@@ -266,8 +279,8 @@ class RBLNDiffusionMixin:
266
279
  controlnet_rbln_config: Dict[str, Any],
267
280
  ):
268
281
  # Compile multiple ControlNet models for a MultiControlNet setup
269
- from .diffusers.models.controlnet import RBLNControlNetModel
270
- from .diffusers.pipelines.controlnet import RBLNMultiControlNetModel
282
+ from .models.controlnet import RBLNControlNetModel
283
+ from .pipelines.controlnet import RBLNMultiControlNetModel
271
284
 
272
285
  compiled_controlnets = [
273
286
  RBLNControlNetModel.from_model(
@@ -278,7 +291,7 @@ class RBLNDiffusionMixin:
278
291
  )
279
292
  for i, controlnet in enumerate(controlnets.nets)
280
293
  ]
281
- return RBLNMultiControlNetModel(compiled_controlnets, config=controlnets.nets[0].config)
294
+ return RBLNMultiControlNetModel(compiled_controlnets)
282
295
 
283
296
  @classmethod
284
297
  def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
@@ -324,6 +337,35 @@ class RBLNDiffusionMixin:
324
337
 
325
338
  return model
326
339
 
340
+ def get_compiled_image_size(self):
341
+ if hasattr(self, "vae"):
342
+ compiled_image_size = self.vae.image_size
343
+ else:
344
+ compiled_image_size = None
345
+ return compiled_image_size
346
+
347
+ def handle_additional_kwargs(self, **kwargs):
348
+ """
349
+ Function to handle additional compile-time parameters during inference.
350
+
351
+ If the additional variable is determined by another module, this method should be overrided.
352
+
353
+ Example:
354
+ ```python
355
+ if hasattr(self, "movq"):
356
+ compiled_image_size = self.movq.image_size
357
+ kwargs["height"] = compiled_image_size[0]
358
+ kwargs["width"] = compiled_image_size[1]
359
+
360
+ compiled_num_frames = self.unet.rbln_config.model_cfg.get("num_frames", None)
361
+ if compiled_num_frames is not None:
362
+ kwargs["num_frames"] = self.unet.rbln_config.model_cfg.get("num_frames")
363
+ return kwargs
364
+ ```
365
+ """
366
+ return kwargs
367
+
327
368
  @remove_compile_time_kwargs
328
369
  def __call__(self, *args, **kwargs):
370
+ kwargs = self.handle_additional_kwargs(**kwargs)
329
371
  return super().__call__(*args, **kwargs)
optimum/rbln/diffusers/models/__init__.py CHANGED
@@ -20,6 +20,7 @@
20
20
  # are the intellectual property of Rebellions Inc. and may not be
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
+
23
24
  from typing import TYPE_CHECKING
24
25
 
25
26
  from transformers.utils import _LazyModule
@@ -35,6 +36,7 @@ _import_structure = {
35
36
  "controlnet": ["RBLNControlNetModel"],
36
37
  "transformers": ["RBLNSD3Transformer2DModel"],
37
38
  }
39
+
38
40
  if TYPE_CHECKING:
39
41
  from .autoencoders import (
40
42
  RBLNAutoencoderKL,
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py CHANGED
@@ -22,7 +22,7 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import logging
25
- from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
26
26
 
27
27
  import rebel
28
28
  import torch # noqa: I001
@@ -32,7 +32,7 @@ from transformers import PretrainedConfig
32
32
 
33
33
  from ....modeling import RBLNModel
34
34
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
35
- from ....modeling_diffusers import RBLNDiffusionMixin
35
+ from ...modeling_diffusers import RBLNDiffusionMixin
36
36
  from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
37
37
 
38
38
 
@@ -88,22 +88,35 @@ class RBLNAutoencoderKL(RBLNModel):
88
88
  @classmethod
89
89
  def get_vae_sample_size(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
90
90
  image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
91
+ noise_module = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
92
+ vae_scale_factor = (
93
+ pipe.vae_scale_factor
94
+ if hasattr(pipe, "vae_scale_factor")
95
+ else 2 ** (len(pipe.vae.config.block_out_channels) - 1)
96
+ )
97
+
98
+ if noise_module is None:
99
+ raise AttributeError(
100
+ "Cannot find noise processing or predicting module attributes. ex. U-Net, Transformer, ..."
101
+ )
102
+
91
103
  if (image_size[0] is None) != (image_size[1] is None):
92
104
  raise ValueError("Both image height and image width must be given or not given")
105
+
93
106
  elif image_size[0] is None and image_size[1] is None:
94
107
  if rbln_config["img2img_pipeline"]:
95
- sample_size = pipe.vae.config.sample_size
108
+ sample_size = noise_module.config.sample_size
96
109
  elif rbln_config["inpaint_pipeline"]:
97
- sample_size = pipe.unet.config.sample_size * pipe.vae_scale_factor
110
+ sample_size = noise_module.config.sample_size * vae_scale_factor
98
111
  else:
99
112
  # In case of text2img, sample size of vae decoder is determined by unet.
100
- unet_sample_size = pipe.unet.config.sample_size
101
- if isinstance(unet_sample_size, int):
102
- sample_size = unet_sample_size * pipe.vae_scale_factor
113
+ noise_module_sample_size = noise_module.config.sample_size
114
+ if isinstance(noise_module_sample_size, int):
115
+ sample_size = noise_module_sample_size * vae_scale_factor
103
116
  else:
104
117
  sample_size = (
105
- unet_sample_size[0] * pipe.vae_scale_factor,
106
- unet_sample_size[1] * pipe.vae_scale_factor,
118
+ noise_module_sample_size[0] * vae_scale_factor,
119
+ noise_module_sample_size[1] * vae_scale_factor,
107
120
  )
108
121
  else:
109
122
  sample_size = (image_size[0], image_size[1])
@@ -192,15 +205,28 @@ class RBLNAutoencoderKL(RBLNModel):
192
205
 
193
206
  @classmethod
194
207
  def _create_runtimes(
195
- cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
208
+ cls,
209
+ compiled_models: List[rebel.RBLNCompiledModel],
210
+ rbln_device_map: Dict[str, int],
211
+ activate_profiler: Optional[bool] = None,
196
212
  ) -> List[rebel.Runtime]:
197
213
  if len(compiled_models) == 1:
214
+ if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
215
+ cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
216
+
198
217
  device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
199
- return [compiled_models[0].create_runtime(tensor_type="pt", device=device_val)]
218
+ return [
219
+ compiled_models[0].create_runtime(
220
+ tensor_type="pt", device=device_val, activate_profiler=activate_profiler
221
+ )
222
+ ]
223
+
224
+ if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
225
+ cls._raise_missing_compiled_file_error(["encoder", "decoder"])
200
226
 
201
227
  device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
202
228
  return [
203
- compiled_model.create_runtime(tensor_type="pt", device=device_val)
229
+ compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
204
230
  for compiled_model, device_val in zip(compiled_models, device_vals)
205
231
  ]
206
232
 
optimum/rbln/diffusers/models/autoencoders/vae.py CHANGED
@@ -21,7 +21,6 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
-
25
24
  import logging
26
25
  from typing import TYPE_CHECKING
27
26
 
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
31
31
 
32
32
  from ...modeling import RBLNModel
33
33
  from ...modeling_config import RBLNCompileConfig, RBLNConfig
34
- from ...modeling_diffusers import RBLNDiffusionMixin
34
+ from ..modeling_diffusers import RBLNDiffusionMixin
35
35
 
36
36
 
37
37
  if TYPE_CHECKING:
optimum/rbln/diffusers/models/transformers/transformer_sd3.py CHANGED
@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
31
31
 
32
32
  from ....modeling import RBLNModel
33
33
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
34
- from ....modeling_diffusers import RBLNDiffusionMixin
34
+ from ...modeling_diffusers import RBLNDiffusionMixin
35
35
 
36
36
 
37
37
  if TYPE_CHECKING:
optimum/rbln/diffusers/models/unets/unet_2d_condition.py CHANGED
@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
31
31
 
32
32
  from ....modeling import RBLNModel
33
33
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
34
- from ....modeling_diffusers import RBLNDiffusionMixin
34
+ from ...modeling_diffusers import RBLNDiffusionMixin
35
35
 
36
36
 
37
37
  if TYPE_CHECKING:
@@ -265,15 +265,13 @@ class RBLNUNet2DConditionModel(RBLNModel):
265
265
  ]
266
266
  input_info.append(("mid_block_additional_residual", shape, "float32"))
267
267
 
268
- rbln_compile_config = RBLNCompileConfig(input_info=input_info)
269
-
270
268
  if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
271
269
  rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
272
270
  rbln_in_features = model_config.projection_class_embeddings_input_dim
273
- rbln_compile_config.input_info.append(
274
- ("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32")
275
- )
276
- rbln_compile_config.input_info.append(("time_ids", [batch_size, 6], "float32"))
271
+ input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
272
+ input_info.append(("time_ids", [batch_size, 6], "float32"))
273
+
274
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
277
275
 
278
276
  rbln_config = RBLNConfig(
279
277
  rbln_cls=cls.__name__,
optimum/rbln/diffusers/pipelines/__init__.py CHANGED
@@ -20,6 +20,7 @@
20
20
  # are the intellectual property of Rebellions Inc. and may not be
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
+
23
24
  from typing import TYPE_CHECKING
24
25
 
25
26
  from transformers.utils import _LazyModule
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py CHANGED
@@ -30,7 +30,6 @@ import torch
30
30
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
31
31
 
32
32
  from ....modeling import RBLNModel
33
- from ....modeling_config import RBLNConfig
34
33
  from ...models.controlnet import RBLNControlNetModel
35
34
 
36
35
 
@@ -47,7 +46,6 @@ class RBLNMultiControlNetModel(RBLNModel):
47
46
  def __init__(
48
47
  self,
49
48
  models: List[RBLNControlNetModel],
50
- **kwargs,
51
49
  ):
52
50
  self.nets = models
53
51
  self.dtype = torch.float32
@@ -67,19 +65,22 @@ class RBLNMultiControlNetModel(RBLNModel):
67
65
  ) -> RBLNModel:
68
66
  idx = 0
69
67
  controlnets = []
70
- model_path_to_load = model_id
68
+ subfolder_name = kwargs.pop("subfolder", None)
69
+ if subfolder_name is not None:
70
+ model_path_to_load = model_id + "/" + subfolder_name
71
+ else:
72
+ model_path_to_load = model_id
73
+
74
+ base_model_path_to_load = model_path_to_load
71
75
 
72
76
  while os.path.isdir(model_path_to_load):
73
77
  controlnet = RBLNControlNetModel.from_pretrained(model_path_to_load, export=False, **kwargs)
74
78
  controlnets.append(controlnet)
75
- rbln_config = RBLNConfig.load(model_path_to_load)
76
79
  idx += 1
77
- model_path_to_load = model_id + f"_{idx}"
80
+ model_path_to_load = base_model_path_to_load + f"_{idx}"
78
81
 
79
82
  return cls(
80
83
  controlnets,
81
- rbln_config=rbln_config,
82
- **kwargs,
83
84
  )
84
85
 
85
86
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  # Copyright 2024 Rebellions Inc.
2
16
 
3
17
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +34,6 @@
20
34
  # are the intellectual property of Rebellions Inc. and may not be
21
35
  # copied, modified, or distributed without prior written permission
22
36
  # from Rebellions Inc.
23
- """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
24
37
 
25
38
  from typing import Any, Callable, Dict, List, Optional, Union
26
39
 
@@ -33,8 +46,8 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
46
  from diffusers.utils import deprecate, logging
34
47
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
35
48
 
36
- from ....modeling_diffusers import RBLNDiffusionMixin
37
49
  from ....utils.decorator_utils import remove_compile_time_kwargs
50
+ from ...modeling_diffusers import RBLNDiffusionMixin
38
51
  from ...models import RBLNControlNetModel
39
52
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
40
53
 
@@ -46,6 +59,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
46
59
  original_class = StableDiffusionControlNetPipeline
47
60
  _submodules = ["text_encoder", "unet", "vae", "controlnet"]
48
61
 
62
+ # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet.py
49
63
  def check_inputs(
50
64
  self,
51
65
  prompt,
@@ -209,6 +223,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
209
223
  f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
210
224
  )
211
225
 
226
+ # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet.py
212
227
  @torch.no_grad()
213
228
  @remove_compile_time_kwargs
214
229
  def __call__(
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  # Copyright 2024 Rebellions Inc.
2
16
 
3
17
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +34,6 @@
20
34
  # are the intellectual property of Rebellions Inc. and may not be
21
35
  # copied, modified, or distributed without prior written permission
22
36
  # from Rebellions Inc.
23
- """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
24
37
 
25
38
  from typing import Any, Callable, Dict, List, Optional, Union
26
39
 
@@ -32,8 +45,8 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
32
45
  from diffusers.utils import deprecate, logging
33
46
  from diffusers.utils.torch_utils import is_compiled_module
34
47
 
35
- from ....modeling_diffusers import RBLNDiffusionMixin
36
48
  from ....utils.decorator_utils import remove_compile_time_kwargs
49
+ from ...modeling_diffusers import RBLNDiffusionMixin
37
50
  from ...models import RBLNControlNetModel
38
51
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
39
52
 
@@ -45,6 +58,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
45
58
  original_class = StableDiffusionControlNetImg2ImgPipeline
46
59
  _submodules = ["text_encoder", "unet", "vae", "controlnet"]
47
60
 
61
+ # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet_img2img.py
48
62
  def check_inputs(
49
63
  self,
50
64
  prompt,
@@ -202,6 +216,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
202
216
  f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
203
217
  )
204
218
 
219
+ # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet_img2img.py
205
220
  @torch.no_grad()
206
221
  @remove_compile_time_kwargs
207
222
  def __call__(
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  # Copyright 2024 Rebellions Inc.
2
16
 
3
17
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +34,6 @@
20
34
  # are the intellectual property of Rebellions Inc. and may not be
21
35
  # copied, modified, or distributed without prior written permission
22
36
  # from Rebellions Inc.
23
- """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
24
37
 
25
38
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
26
39
 
@@ -32,8 +45,8 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffus
32
45
  from diffusers.utils import deprecate, logging
33
46
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
34
47
 
35
- from ....modeling_diffusers import RBLNDiffusionMixin
36
48
  from ....utils.decorator_utils import remove_compile_time_kwargs
49
+ from ...modeling_diffusers import RBLNDiffusionMixin
37
50
  from ...models import RBLNControlNetModel
38
51
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
39
52
 
@@ -45,6 +58,7 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
45
58
  original_class = StableDiffusionXLControlNetPipeline
46
59
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
47
60
 
61
+ # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.py
48
62
  def check_inputs(
49
63
  self,
50
64
  prompt,
@@ -234,6 +248,7 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
234
248
  f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
235
249
  )
236
250
 
251
+ # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.py
237
252
  @torch.no_grad()
238
253
  @remove_compile_time_kwargs
239
254
  def __call__(