optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,143 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple
16
+
17
+ from ....configuration_utils import RBLNModelConfig
18
+ from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNT5EncoderModelConfig
19
+ from ....utils.logging import get_logger
20
+ from ..models import RBLNAutoencoderKLConfig, RBLNSD3Transformer2DModelConfig
21
+
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
class _RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
    """
    Base RBLN configuration shared by the Stable Diffusion 3 pipeline configs.

    One sub-configuration is built per entry in ``submodules``. Concrete subclasses only
    override ``_vae_uses_encoder``, which is forwarded to the VAE sub-configuration as
    ``uses_encoder``.
    """

    submodules = ["transformer", "text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
    _vae_uses_encoder = False

    def __init__(
        self,
        transformer: Optional[RBLNSD3Transformer2DModelConfig] = None,
        text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
        text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
        text_encoder_3: Optional[RBLNT5EncoderModelConfig] = None,
        vae: Optional[RBLNAutoencoderKLConfig] = None,
        *,
        max_seq_len: Optional[int] = None,
        sample_size: Optional[Tuple[int, int]] = None,
        image_size: Optional[Tuple[int, int]] = None,
        batch_size: Optional[int] = None,
        img_height: Optional[int] = None,
        img_width: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        **kwargs,
    ):
        """
        Args:
            transformer (Optional[RBLNSD3Transformer2DModelConfig]): Configuration for the transformer model component.
                Initialized as RBLNSD3Transformer2DModelConfig if not provided.
            text_encoder (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the primary text encoder.
                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
            text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder.
                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
            text_encoder_3 (Optional[RBLNT5EncoderModelConfig]): Configuration for the tertiary text encoder.
                Initialized as RBLNT5EncoderModelConfig if not provided.
            vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
                Initialized as RBLNAutoencoderKLConfig if not provided.
            max_seq_len (Optional[int]): Maximum sequence length passed to the T5 text encoder
                sub-configuration (``text_encoder_3``). Falls back to 256 when not provided (or falsy).
            sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the transformer model.
            image_size (Optional[Tuple[int, int]]): Dimensions for the generated images; forwarded to the
                VAE sub-configuration as its ``sample_size``. Cannot be used together with img_height/img_width.
            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
            img_height (Optional[int]): Height of the generated images. Must be given together with img_width.
            img_width (Optional[int]): Width of the generated images. Must be given together with img_height.
            guidance_scale (Optional[float]): Scale for classifier-free guidance. Deprecated parameter.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If both image_size and img_height/img_width are provided, or if only one of
                img_height/img_width is provided.

        Note:
            When guidance_scale > 1.0, the transformer batch size is set to twice the text encoder
            batch size to accommodate classifier-free guidance. This overrides any batch size set on
            the transformer sub-configuration.
        """
        super().__init__(**kwargs)
        if image_size is not None and (img_height is not None or img_width is not None):
            raise ValueError("image_size and img_height/img_width cannot both be provided")

        # Height and width only make sense as a pair; previously a lone value was silently dropped.
        if (img_height is None) != (img_width is None):
            raise ValueError("img_height and img_width must be provided together")

        if img_height is not None and img_width is not None:
            image_size = (img_height, img_width)

        max_seq_len = max_seq_len or 256

        self.text_encoder = self.init_submodule_config(
            RBLNCLIPTextModelWithProjectionConfig, text_encoder, batch_size=batch_size
        )
        self.text_encoder_2 = self.init_submodule_config(
            RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
        )
        self.text_encoder_3 = self.init_submodule_config(
            RBLNT5EncoderModelConfig,
            text_encoder_3,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
        )
        self.transformer = self.init_submodule_config(
            RBLNSD3Transformer2DModelConfig,
            transformer,
            batch_size=batch_size,
            sample_size=sample_size,
        )
        self.vae = self.init_submodule_config(
            RBLNAutoencoderKLConfig,
            vae,
            batch_size=batch_size,
            uses_encoder=self.__class__._vae_uses_encoder,
            sample_size=image_size,  # the VAE's sample size is the output image size
        )

        if guidance_scale is not None:
            logger.warning("Specifying `guidance_scale` is deprecated. It will be removed in a future version.")
            do_classifier_free_guidance = guidance_scale > 1.0
            if do_classifier_free_guidance:
                # Conditional and unconditional inputs are batched together for CFG.
                self.transformer.batch_size = self.text_encoder.batch_size * 2

    @property
    def max_seq_len(self):
        """Maximum sequence length configured on the T5 text encoder (``text_encoder_3``)."""
        return self.text_encoder_3.max_seq_len

    @property
    def batch_size(self):
        """Pipeline batch size, as recorded on the VAE sub-configuration."""
        return self.vae.batch_size

    @property
    def sample_size(self):
        """Spatial sample size configured on the transformer sub-configuration."""
        return self.transformer.sample_size

    @property
    def image_size(self):
        """Output image size, stored as the VAE sub-configuration's ``sample_size``."""
        return self.vae.sample_size
132
+
133
+
134
class RBLNStableDiffusion3PipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
    """RBLN config for the SD3 text-to-image pipeline; the VAE sub-config gets ``uses_encoder=False``."""

    _vae_uses_encoder = False


class RBLNStableDiffusion3Img2ImgPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
    """RBLN config for the SD3 image-to-image pipeline; the VAE sub-config gets ``uses_encoder=True``."""

    _vae_uses_encoder = True


class RBLNStableDiffusion3InpaintPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
    """RBLN config for the SD3 inpainting pipeline; the VAE sub-config gets ``uses_encoder=True``."""

    _vae_uses_encoder = True
@@ -0,0 +1,124 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple
16
+
17
+ from ....configuration_utils import RBLNModelConfig
18
+ from ....transformers import RBLNCLIPTextModelConfig, RBLNCLIPTextModelWithProjectionConfig
19
+ from ....utils.logging import get_logger
20
+ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
21
+
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
class _RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
    """
    Base RBLN configuration shared by the Stable Diffusion XL pipeline configs.

    One sub-configuration is built per entry in ``submodules``. Concrete subclasses only
    override ``_vae_uses_encoder``, which is forwarded to the VAE sub-configuration as
    ``uses_encoder``.
    """

    submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
    _vae_uses_encoder = False

    def __init__(
        self,
        text_encoder: Optional[RBLNCLIPTextModelConfig] = None,
        text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
        unet: Optional[RBLNUNet2DConditionModelConfig] = None,
        vae: Optional[RBLNAutoencoderKLConfig] = None,
        *,
        batch_size: Optional[int] = None,
        img_height: Optional[int] = None,
        img_width: Optional[int] = None,
        sample_size: Optional[Tuple[int, int]] = None,
        image_size: Optional[Tuple[int, int]] = None,
        guidance_scale: Optional[float] = None,
        **kwargs,
    ):
        """
        Args:
            text_encoder (Optional[RBLNCLIPTextModelConfig]): Configuration for the primary text encoder component.
                Initialized as RBLNCLIPTextModelConfig if not provided.
            text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder component.
                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
            unet (Optional[RBLNUNet2DConditionModelConfig]): Configuration for the UNet model component.
                Initialized as RBLNUNet2DConditionModelConfig if not provided.
            vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
                Initialized as RBLNAutoencoderKLConfig if not provided.
            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
            img_height (Optional[int]): Height of the generated images. Must be given together with img_width.
            img_width (Optional[int]): Width of the generated images. Must be given together with img_height.
            sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the UNet model.
            image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions; forwarded to
                the VAE sub-configuration as its ``sample_size``. Cannot be used together with img_height/img_width.
            guidance_scale (Optional[float]): Scale for classifier-free guidance. Deprecated parameter.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If both image_size and img_height/img_width are provided, or if only one of
                img_height/img_width is provided.

        Note:
            When guidance_scale > 1.0, the UNet batch size is set to twice the text encoder batch size
            to accommodate classifier-free guidance. This overrides any batch size set on the UNet
            sub-configuration.
        """
        super().__init__(**kwargs)
        if image_size is not None and (img_height is not None or img_width is not None):
            raise ValueError("image_size and img_height/img_width cannot both be provided")

        # Height and width only make sense as a pair; previously a lone value was silently dropped.
        if (img_height is None) != (img_width is None):
            raise ValueError("img_height and img_width must be provided together")

        if img_height is not None and img_width is not None:
            image_size = (img_height, img_width)

        self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
        self.text_encoder_2 = self.init_submodule_config(
            RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
        )
        self.unet = self.init_submodule_config(
            RBLNUNet2DConditionModelConfig,
            unet,
            batch_size=batch_size,
            sample_size=sample_size,
        )
        self.vae = self.init_submodule_config(
            RBLNAutoencoderKLConfig,
            vae,
            batch_size=batch_size,
            uses_encoder=self.__class__._vae_uses_encoder,
            sample_size=image_size,  # image size is equal to sample size in vae
        )

        if guidance_scale is not None:
            logger.warning("Specifying `guidance_scale` is deprecated. It will be removed in a future version.")
            do_classifier_free_guidance = guidance_scale > 1.0
            if do_classifier_free_guidance:
                # Conditional and unconditional inputs are batched together for CFG.
                self.unet.batch_size = self.text_encoder.batch_size * 2

    @property
    def batch_size(self):
        """Pipeline batch size, as recorded on the VAE sub-configuration."""
        return self.vae.batch_size

    @property
    def sample_size(self):
        """Spatial sample size configured on the UNet sub-configuration."""
        return self.unet.sample_size

    @property
    def image_size(self):
        """Output image size, stored as the VAE sub-configuration's ``sample_size``."""
        return self.vae.sample_size
113
+
114
+
115
class RBLNStableDiffusionXLPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
    """RBLN config for the SDXL text-to-image pipeline; the VAE sub-config gets ``uses_encoder=False``."""

    _vae_uses_encoder = False


class RBLNStableDiffusionXLImg2ImgPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
    """RBLN config for the SDXL image-to-image pipeline; the VAE sub-config gets ``uses_encoder=True``."""

    _vae_uses_encoder = True


class RBLNStableDiffusionXLInpaintPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
    """RBLN config for the SDXL inpainting pipeline; the VAE sub-config gets ``uses_encoder=True``."""

    _vae_uses_encoder = True
@@ -15,12 +15,14 @@
15
15
  import copy
16
16
  import importlib
17
17
  from os import PathLike
18
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
18
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
19
19
 
20
20
  import torch
21
21
 
22
+ from ..configuration_utils import ContextRblnConfig, RBLNModelConfig
22
23
  from ..modeling import RBLNModel
23
- from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
24
+
25
+ # from ..transformers import RBLNCLIPTextModelConfig
24
26
  from ..utils.decorator_utils import remove_compile_time_kwargs
25
27
  from ..utils.logging import get_logger
26
28
 
@@ -31,6 +33,10 @@ if TYPE_CHECKING:
31
33
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
32
34
 
33
35
 
36
class RBLNDiffusionMixinConfig(RBLNModelConfig):
    """
    Marker config type used in the ``rbln_config`` annotations of ``RBLNDiffusionMixin``.

    Adds nothing on top of ``RBLNModelConfig``; concrete pipeline configs provide the
    per-submodule attributes the mixin reads via ``getattr(rbln_config, submodule_name)``.
    """
38
+
39
+
34
40
  class RBLNDiffusionMixin:
35
41
  """
36
42
  RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
@@ -69,6 +75,7 @@ class RBLNDiffusionMixin:
69
75
  _connected_classes = {}
70
76
  _submodules = []
71
77
  _prefix = {}
78
+ _rbln_config_class = None
72
79
 
73
80
  @classmethod
74
81
  def is_img2img_pipeline(cls):
@@ -78,35 +85,6 @@ class RBLNDiffusionMixin:
78
85
  def is_inpaint_pipeline(cls):
79
86
  return "Inpaint" in cls.__name__
80
87
 
81
- @classmethod
82
- def get_submodule_rbln_config(
83
- cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
84
- ) -> Dict[str, Any]:
85
- submodule = getattr(model, submodule_name)
86
- submodule_class_name = submodule.__class__.__name__
87
- if isinstance(submodule, torch.nn.Module):
88
- if submodule_class_name == "MultiControlNetModel":
89
- submodule_class_name = "ControlNetModel"
90
-
91
- submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
92
-
93
- submodule_config = rbln_config.get(submodule_name, {})
94
- submodule_config = copy.deepcopy(submodule_config)
95
-
96
- pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
97
-
98
- submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
99
- submodule_config.update(
100
- {
101
- "img2img_pipeline": cls.is_img2img_pipeline(),
102
- "inpaint_pipeline": cls.is_inpaint_pipeline(),
103
- }
104
- )
105
- submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
106
- else:
107
- raise ValueError(f"submodule {submodule_name} isn't supported")
108
- return submodule_config
109
-
110
88
  @staticmethod
111
89
  def _maybe_apply_and_fuse_lora(
112
90
  model: torch.nn.Module,
@@ -146,7 +124,22 @@ class RBLNDiffusionMixin:
146
124
  return model
147
125
 
148
126
  @classmethod
149
- @use_rbln_config
127
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
    """
    Lazily resolve and cache the RBLN config class paired with this pipeline class.

    The config class is looked up by name (``<cls.__name__>Config``) on the ``optimum.rbln``
    package. The cache is read from this class's own ``__dict__`` rather than through normal
    attribute lookup. Otherwise a subclass would silently reuse a config class that was resolved
    and cached for one of its parents. The inherited value is used only as a fallback when no
    name-matched config class exists, which preserves the previous behaviour in that case.

    Raises:
        ValueError: If no config class can be resolved for this pipeline class.
    """
    rbln_config_class = cls.__dict__.get("_rbln_config_class")
    if rbln_config_class is None:
        rbln_config_class_name = cls.__name__ + "Config"
        library = importlib.import_module("optimum.rbln")
        rbln_config_class = getattr(library, rbln_config_class_name, None)
        if rbln_config_class is None:
            # No name-matched config: fall back to whatever a parent class already resolved.
            rbln_config_class = getattr(cls, "_rbln_config_class", None)
        if rbln_config_class is None:
            raise ValueError(
                f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
                "Please report it to the developers."
            )
        cls._rbln_config_class = rbln_config_class
    return rbln_config_class
141
+
142
+ @classmethod
150
143
  def from_pretrained(
151
144
  cls,
152
145
  model_id: str,
@@ -159,6 +152,8 @@ class RBLNDiffusionMixin:
159
152
  lora_scales: Optional[Union[float, List[float]]] = None,
160
153
  **kwargs,
161
154
  ) -> RBLNModel:
155
+ rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
156
+
162
157
  if export:
163
158
  # keep submodules if user passed any of them.
164
159
  passed_submodules = {
@@ -168,22 +163,12 @@ class RBLNDiffusionMixin:
168
163
  else:
169
164
  # raise error if any of submodules are torch module.
170
165
  model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
171
- rbln_config = cls._flatten_rbln_config(rbln_config)
172
166
  for submodule_name in cls._submodules:
173
167
  if isinstance(kwargs.get(submodule_name), torch.nn.Module):
174
168
  raise AssertionError(
175
169
  f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
176
170
  )
177
171
 
178
- submodule_config = rbln_config.get(submodule_name, {})
179
-
180
- for key, value in rbln_config.items():
181
- if key in RUNTIME_KEYWORDS and key not in submodule_config:
182
- submodule_config[key] = value
183
-
184
- if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
185
- continue
186
-
187
172
  module_name, class_name = model_index_config[submodule_name]
188
173
  if module_name != "optimum.rbln":
189
174
  raise ValueError(
@@ -192,19 +177,19 @@ class RBLNDiffusionMixin:
192
177
  "Expected 'optimum.rbln'. Please check the model_index.json configuration."
193
178
  )
194
179
 
195
- submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
196
-
180
+ submodule_cls: Type[RBLNModel] = getattr(importlib.import_module("optimum.rbln"), class_name)
181
+ submodule_config = getattr(rbln_config, submodule_name)
197
182
  submodule = submodule_cls.from_pretrained(
198
183
  model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
199
184
  )
200
185
  kwargs[submodule_name] = submodule
201
186
 
202
187
  with ContextRblnConfig(
203
- device=rbln_config.get("device"),
204
- device_map=rbln_config.get("device_map"),
205
- create_runtimes=rbln_config.get("create_runtimes"),
206
- optimize_host_mem=rbln_config.get("optimize_host_memory"),
207
- activate_profiler=rbln_config.get("activate_profiler"),
188
+ device=rbln_config.device,
189
+ device_map=rbln_config.device_map,
190
+ create_runtimes=rbln_config.create_runtimes,
191
+ optimize_host_mem=rbln_config.optimize_host_memory,
192
+ activate_profiler=rbln_config.activate_profiler,
208
193
  ):
209
194
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
210
195
 
@@ -224,78 +209,27 @@ class RBLNDiffusionMixin:
224
209
  compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
225
210
  return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)
226
211
 
227
- @classmethod
228
- def _prepare_rbln_config(
229
- cls,
230
- rbln_config,
231
- ) -> Dict[str, Any]:
232
- prepared_config = {}
233
- for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
234
- connected_pipe_config = rbln_config.pop(connected_pipe_name, {})
235
- prefix = cls._prefix.get(connected_pipe_name, "")
236
- guidance_scale = rbln_config.pop(f"{prefix}guidance_scale", None)
237
- if "guidance_scale" not in connected_pipe_config and guidance_scale is not None:
238
- connected_pipe_config["guidance_scale"] = guidance_scale
239
- for submodule_name in connected_pipe_cls._submodules:
240
- submodule_config = rbln_config.pop(prefix + submodule_name, {})
241
- if submodule_name not in connected_pipe_config:
242
- connected_pipe_config[submodule_name] = {}
243
- connected_pipe_config[submodule_name].update(
244
- {k: v for k, v in submodule_config.items() if k not in connected_pipe_config[submodule_name]}
245
- )
246
- prepared_config[connected_pipe_name] = connected_pipe_config
247
- prepared_config.update(rbln_config)
248
- return prepared_config
249
-
250
- @classmethod
251
- def _flatten_rbln_config(
252
- cls,
253
- rbln_config,
254
- ) -> Dict[str, Any]:
255
- prepared_config = cls._prepare_rbln_config(rbln_config)
256
- flattened_config = {}
257
- pipe_global_config = {k: v for k, v in prepared_config.items() if k not in cls._connected_classes.keys()}
258
- for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
259
- connected_pipe_config = prepared_config.pop(connected_pipe_name)
260
- prefix = cls._prefix.get(connected_pipe_name, "")
261
- connected_pipe_global_config = {
262
- k: v for k, v in connected_pipe_config.items() if k not in connected_pipe_cls._submodules
263
- }
264
- for submodule_name in connected_pipe_cls._submodules:
265
- flattened_config[prefix + submodule_name] = connected_pipe_config[submodule_name]
266
- flattened_config[prefix + submodule_name].update(
267
- {
268
- k: v
269
- for k, v in connected_pipe_global_config.items()
270
- if k not in flattened_config[prefix + submodule_name]
271
- }
272
- )
273
- flattened_config.update(pipe_global_config)
274
- return flattened_config
275
-
276
212
  @classmethod
277
213
  def _compile_pipelines(
278
214
  cls,
279
215
  model: torch.nn.Module,
280
216
  passed_submodules: Dict[str, RBLNModel],
281
217
  model_save_dir: Optional[PathLike],
282
- rbln_config: Dict[str, Any],
218
+ rbln_config: "RBLNDiffusionMixinConfig",
283
219
  ) -> Dict[str, RBLNModel]:
284
220
  compiled_submodules = {}
285
-
286
- rbln_config = cls._prepare_rbln_config(rbln_config)
287
- pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._connected_classes.keys()}
288
221
  for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
289
222
  connected_pipe_submodules = {}
290
223
  prefix = cls._prefix.get(connected_pipe_name, "")
291
224
  for submodule_name in connected_pipe_cls._submodules:
292
225
  connected_pipe_submodules[submodule_name] = passed_submodules.get(prefix + submodule_name, None)
293
226
  connected_pipe = getattr(model, connected_pipe_name)
294
- connected_pipe_config = {}
295
- connected_pipe_config.update(pipe_global_config)
296
- connected_pipe_config.update(rbln_config[connected_pipe_name])
297
227
  connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
298
- connected_pipe, connected_pipe_submodules, model_save_dir, connected_pipe_config, prefix
228
+ connected_pipe,
229
+ connected_pipe_submodules,
230
+ model_save_dir,
231
+ getattr(rbln_config, connected_pipe_name),
232
+ prefix,
299
233
  )
300
234
  for submodule_name, compiled_submodule in connected_pipe_compiled_submodules.items():
301
235
  compiled_submodules[prefix + submodule_name] = compiled_submodule
@@ -307,14 +241,19 @@ class RBLNDiffusionMixin:
307
241
  model: torch.nn.Module,
308
242
  passed_submodules: Dict[str, RBLNModel],
309
243
  model_save_dir: Optional[PathLike],
310
- rbln_config: Dict[str, Any],
244
+ rbln_config: RBLNDiffusionMixinConfig,
311
245
  prefix: Optional[str] = "",
312
246
  ) -> Dict[str, RBLNModel]:
313
247
  compiled_submodules = {}
314
248
 
315
249
  for submodule_name in cls._submodules:
316
250
  submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
317
- submodule_rbln_config = cls.get_submodule_rbln_config(model, submodule_name, rbln_config)
251
+
252
+ if getattr(rbln_config, submodule_name, None) is None:
253
+ raise ValueError(f"RBLN config for submodule {submodule_name} is not provided.")
254
+
255
+ submodule_rbln_cls: Type[RBLNModel] = getattr(rbln_config, submodule_name).rbln_model_cls
256
+ rbln_config = submodule_rbln_cls.update_rbln_config_using_pipe(model, rbln_config, submodule_name)
318
257
 
319
258
  if submodule is None:
320
259
  raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
@@ -325,7 +264,7 @@ class RBLNDiffusionMixin:
325
264
  submodule = cls._compile_multicontrolnet(
326
265
  controlnets=submodule,
327
266
  model_save_dir=model_save_dir,
328
- controlnet_rbln_config=submodule_rbln_config,
267
+ controlnet_rbln_config=getattr(rbln_config, submodule_name),
329
268
  prefix=prefix,
330
269
  )
331
270
  elif isinstance(submodule, torch.nn.Module):
@@ -337,7 +276,7 @@ class RBLNDiffusionMixin:
337
276
  model=submodule,
338
277
  subfolder=subfolder,
339
278
  model_save_dir=model_save_dir,
340
- rbln_config=submodule_rbln_config,
279
+ rbln_config=getattr(rbln_config, submodule_name),
341
280
  )
342
281
  else:
343
282
  raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
@@ -350,22 +289,24 @@ class RBLNDiffusionMixin:
350
289
  cls,
351
290
  controlnets: "MultiControlNetModel",
352
291
  model_save_dir: Optional[PathLike],
353
- controlnet_rbln_config: Dict[str, Any],
292
+ controlnet_rbln_config: RBLNModelConfig,
354
293
  prefix: Optional[str] = "",
355
294
  ):
356
295
  # Compile multiple ControlNet models for a MultiControlNet setup
357
296
  from .models.controlnet import RBLNControlNetModel
358
297
  from .pipelines.controlnet import RBLNMultiControlNetModel
359
298
 
360
- compiled_controlnets = [
361
- RBLNControlNetModel.from_model(
362
- model=controlnet,
363
- subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
364
- model_save_dir=model_save_dir,
365
- rbln_config=controlnet_rbln_config,
299
+ compiled_controlnets = []
300
+ for i, controlnet in enumerate(controlnets.nets):
301
+ _controlnet_rbln_config = copy.deepcopy(controlnet_rbln_config)
302
+ compiled_controlnets.append(
303
+ RBLNControlNetModel.from_model(
304
+ model=controlnet,
305
+ subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
306
+ model_save_dir=model_save_dir,
307
+ rbln_config=_controlnet_rbln_config,
308
+ )
366
309
  )
367
- for i, controlnet in enumerate(controlnets.nets)
368
- ]
369
310
  return RBLNMultiControlNetModel(compiled_controlnets)
370
311
 
371
312
  @classmethod
@@ -412,7 +353,7 @@ class RBLNDiffusionMixin:
412
353
  # overwrite to replace incorrect config
413
354
  model.save_config(model_save_dir)
414
355
 
415
- if rbln_config.get("optimize_host_memory") is False:
356
+ if rbln_config.optimize_host_memory is False:
416
357
  # Keep compiled_model objs to further analysis. -> TODO: remove soon...
417
358
  model.compiled_models = []
418
359
  for name in cls._submodules:
@@ -441,9 +382,9 @@ class RBLNDiffusionMixin:
441
382
  kwargs["height"] = compiled_image_size[0]
442
383
  kwargs["width"] = compiled_image_size[1]
443
384
 
444
- compiled_num_frames = self.unet.rbln_config.model_cfg.get("num_frames", None)
385
+ compiled_num_frames = self.unet.rbln_config.num_frames
445
386
  if compiled_num_frames is not None:
446
- kwargs["num_frames"] = self.unet.rbln_config.model_cfg.get("num_frames")
387
+ kwargs["num_frames"] = compiled_num_frames
447
388
  return kwargs
448
389
  ```
449
390
  """