optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py
@@ -0,0 +1,143 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNT5EncoderModelConfig
+ from ..models import RBLNAutoencoderKLConfig, RBLNSD3Transformer2DModelConfig
+
+
+ class _RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
+     submodules = ["transformer", "text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
+     _vae_uses_encoder = False
+
+     def __init__(
+         self,
+         transformer: Optional[RBLNSD3Transformer2DModelConfig] = None,
+         text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+         text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+         text_encoder_3: Optional[RBLNT5EncoderModelConfig] = None,
+         vae: Optional[RBLNAutoencoderKLConfig] = None,
+         *,
+         max_seq_len: Optional[int] = None,
+         sample_size: Optional[Tuple[int, int]] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         batch_size: Optional[int] = None,
+         img_height: Optional[int] = None,
+         img_width: Optional[int] = None,
+         guidance_scale: Optional[float] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             transformer (Optional[RBLNSD3Transformer2DModelConfig]): Configuration for the transformer model component.
+                 Initialized as RBLNSD3Transformer2DModelConfig if not provided.
+             text_encoder (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the primary text encoder.
+                 Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+             text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder.
+                 Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+             text_encoder_3 (Optional[RBLNT5EncoderModelConfig]): Configuration for the tertiary text encoder.
+                 Initialized as RBLNT5EncoderModelConfig if not provided.
+             vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
+                 Initialized as RBLNAutoencoderKLConfig if not provided.
+             max_seq_len (Optional[int]): Maximum sequence length for text inputs. Defaults to 256.
+             sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the transformer model.
+             image_size (Optional[Tuple[int, int]]): Dimensions for the generated images.
+                 Cannot be used together with img_height/img_width.
+             batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+             img_height (Optional[int]): Height of the generated images.
+             img_width (Optional[int]): Width of the generated images.
+             guidance_scale (Optional[float]): Scale for classifier-free guidance.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If both image_size and img_height/img_width are provided.
+
+         Note:
+             When guidance_scale > 1.0, the transformer batch size is automatically doubled to
+             accommodate classifier-free guidance.
+         """
+         super().__init__(**kwargs)
+         if image_size is not None and (img_height is not None or img_width is not None):
+             raise ValueError("image_size and img_height/img_width cannot both be provided")
+
+         if img_height is not None and img_width is not None:
+             image_size = (img_height, img_width)
+
+         max_seq_len = max_seq_len or 256
+
+         self.text_encoder = self.init_submodule_config(
+             RBLNCLIPTextModelWithProjectionConfig, text_encoder, batch_size=batch_size
+         )
+         self.text_encoder_2 = self.init_submodule_config(
+             RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+         )
+         self.text_encoder_3 = self.init_submodule_config(
+             RBLNT5EncoderModelConfig,
+             text_encoder_3,
+             batch_size=batch_size,
+             max_seq_len=max_seq_len,
+         )
+         self.transformer = self.init_submodule_config(
+             RBLNSD3Transformer2DModelConfig,
+             transformer,
+             sample_size=sample_size,
+         )
+         self.vae = self.init_submodule_config(
+             RBLNAutoencoderKLConfig,
+             vae,
+             batch_size=batch_size,
+             uses_encoder=self.__class__._vae_uses_encoder,
+             sample_size=image_size,
+         )
+
+         # Get default guidance scale from original class to set Transformer batch size
+         if guidance_scale is None:
+             guidance_scale = self.get_default_values_for_original_cls("__call__", ["guidance_scale"])["guidance_scale"]
+
+         if not self.transformer.batch_size_is_specified:
+             do_classifier_free_guidance = guidance_scale > 1.0
+             if do_classifier_free_guidance:
+                 self.transformer.batch_size = self.text_encoder.batch_size * 2
+             else:
+                 self.transformer.batch_size = self.text_encoder.batch_size
+
+     @property
+     def max_seq_len(self):
+         return self.text_encoder_3.max_seq_len
+
+     @property
+     def batch_size(self):
+         return self.vae.batch_size
+
+     @property
+     def sample_size(self):
+         return self.transformer.sample_size
+
+     @property
+     def image_size(self):
+         return self.vae.sample_size
+
+
+ class RBLNStableDiffusion3PipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+     _vae_uses_encoder = False
+
+
+ class RBLNStableDiffusion3Img2ImgPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+     _vae_uses_encoder = True
+
+
+ class RBLNStableDiffusion3InpaintPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+     _vae_uses_encoder = True
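
For orientation, here is a minimal, hedged usage sketch of the SD3 pipeline config added above. It assumes these config classes are exported from `optimum.rbln` (as the `get_rbln_config_class` lookup later in this diff implies); the sizes and guidance value are illustrative, not defaults.

```python
# Sketch only: exercises the constructor introduced in 0.7.4 as shown in the diff above.
# Assumes RBLNStableDiffusion3PipelineConfig is exported from optimum.rbln.
from optimum.rbln import RBLNStableDiffusion3PipelineConfig

config = RBLNStableDiffusion3PipelineConfig(
    batch_size=1,        # propagated to the text encoders and the VAE
    img_height=1024,     # img_height/img_width are folded into image_size
    img_width=1024,
    max_seq_len=256,     # forwarded to the T5 text_encoder_3 config
    guidance_scale=7.0,  # > 1.0, so the transformer batch size is doubled for CFG
)

assert config.image_size == (1024, 1024)  # surfaced from vae.sample_size
assert config.max_seq_len == 256          # surfaced from text_encoder_3.max_seq_len
assert config.transformer.batch_size == config.text_encoder.batch_size * 2
```

Passing `img_height`/`img_width` is just a convenience spelling of `image_size`, which in turn becomes the VAE `sample_size` and is surfaced back through the pipeline-level `image_size` property.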
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py
@@ -0,0 +1,124 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....transformers import RBLNCLIPTextModelConfig, RBLNCLIPTextModelWithProjectionConfig
+ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
+
+
+ class _RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
+     submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
+     _vae_uses_encoder = False
+
+     def __init__(
+         self,
+         text_encoder: Optional[RBLNCLIPTextModelConfig] = None,
+         text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+         unet: Optional[RBLNUNet2DConditionModelConfig] = None,
+         vae: Optional[RBLNAutoencoderKLConfig] = None,
+         *,
+         batch_size: Optional[int] = None,
+         img_height: Optional[int] = None,
+         img_width: Optional[int] = None,
+         sample_size: Optional[Tuple[int, int]] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         guidance_scale: Optional[float] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             text_encoder (Optional[RBLNCLIPTextModelConfig]): Configuration for the primary text encoder component.
+                 Initialized as RBLNCLIPTextModelConfig if not provided.
+             text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder component.
+                 Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+             unet (Optional[RBLNUNet2DConditionModelConfig]): Configuration for the UNet model component.
+                 Initialized as RBLNUNet2DConditionModelConfig if not provided.
+             vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
+                 Initialized as RBLNAutoencoderKLConfig if not provided.
+             batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+             img_height (Optional[int]): Height of the generated images.
+             img_width (Optional[int]): Width of the generated images.
+             sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the UNet model.
+             image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
+                 Cannot be used together with img_height/img_width.
+             guidance_scale (Optional[float]): Scale for classifier-free guidance.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If both image_size and img_height/img_width are provided.
+
+         Note:
+             When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+             accommodate classifier-free guidance.
+         """
+         super().__init__(**kwargs)
+         if image_size is not None and (img_height is not None or img_width is not None):
+             raise ValueError("image_size and img_height/img_width cannot both be provided")
+
+         if img_height is not None and img_width is not None:
+             image_size = (img_height, img_width)
+
+         self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
+         self.text_encoder_2 = self.init_submodule_config(
+             RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+         )
+         self.unet = self.init_submodule_config(
+             RBLNUNet2DConditionModelConfig,
+             unet,
+             sample_size=sample_size,
+         )
+         self.vae = self.init_submodule_config(
+             RBLNAutoencoderKLConfig,
+             vae,
+             batch_size=batch_size,
+             uses_encoder=self.__class__._vae_uses_encoder,
+             sample_size=image_size,  # image size is equal to sample size in vae
+         )
+
+         # Get default guidance scale from original class to set UNet batch size
+         if guidance_scale is None:
+             guidance_scale = self.get_default_values_for_original_cls("__call__", ["guidance_scale"])["guidance_scale"]
+
+         if not self.unet.batch_size_is_specified:
+             do_classifier_free_guidance = guidance_scale > 1.0
+             if do_classifier_free_guidance:
+                 self.unet.batch_size = self.text_encoder.batch_size * 2
+             else:
+                 self.unet.batch_size = self.text_encoder.batch_size
+
+     @property
+     def batch_size(self):
+         return self.vae.batch_size
+
+     @property
+     def sample_size(self):
+         return self.unet.sample_size
+
+     @property
+     def image_size(self):
+         return self.vae.sample_size
+
+
+ class RBLNStableDiffusionXLPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+     _vae_uses_encoder = False
+
+
+ class RBLNStableDiffusionXLImg2ImgPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+     _vae_uses_encoder = True
+
+
+ class RBLNStableDiffusionXLInpaintPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+     _vae_uses_encoder = True
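
The SDXL variants follow the same pattern. Below is a hedged sketch of a per-submodule override; it assumes these classes are exported from `optimum.rbln` and that `RBLNUNet2DConditionModelConfig` (from the configuration files listed above) accepts a `batch_size` argument like the other submodule configs shown in this diff. The values are illustrative.

```python
# Sketch only: pipeline-level settings vs. an explicit submodule override.
from optimum.rbln import (
    RBLNStableDiffusionXLImg2ImgPipelineConfig,
    RBLNUNet2DConditionModelConfig,  # batch_size argument assumed, as for the other configs
)

config = RBLNStableDiffusionXLImg2ImgPipelineConfig(
    # A UNet config passed explicitly goes through init_submodule_config unchanged;
    # its pre-specified batch size is expected to skip the automatic CFG doubling
    # (see the batch_size_is_specified check above).
    unet=RBLNUNet2DConditionModelConfig(batch_size=2),
    batch_size=1,           # applied to the text encoders and the VAE
    image_size=(768, 768),  # mutually exclusive with img_height/img_width
)

# The Img2Img and Inpaint variants set _vae_uses_encoder = True, so the VAE is
# compiled with its encoder in addition to the decoder.
```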
optimum/rbln/diffusers/modeling_diffusers.py
@@ -15,12 +15,12 @@
  import copy
  import importlib
  from os import PathLike
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

  import torch

+ from ..configuration_utils import ContextRblnConfig, RBLNModelConfig
  from ..modeling import RBLNModel
- from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
  from ..utils.decorator_utils import remove_compile_time_kwargs
  from ..utils.logging import get_logger

@@ -31,6 +31,10 @@ if TYPE_CHECKING:
      from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel


+ class RBLNDiffusionMixinConfig(RBLNModelConfig):
+     pass
+
+
  class RBLNDiffusionMixin:
      """
      RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
@@ -41,17 +45,11 @@ class RBLNDiffusionMixin:

      1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
      2. Define the required _submodules class variable listing the components to be compiled.
-     3. If needed, implement get_default_rbln_config for custom configuration of submodules.

      Example:
      ```python
      class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
          _submodules = ["text_encoder", "unet", "vae"]
-
-         @classmethod
-         def get_default_rbln_config(cls, model, submodule_name, rbln_config):
-             # Configuration for other submodules...
-             pass
      ```

      Class Variables:
@@ -69,43 +67,8 @@
      _connected_classes = {}
      _submodules = []
      _prefix = {}
-
-     @classmethod
-     def is_img2img_pipeline(cls):
-         return "Img2Img" in cls.__name__
-
-     @classmethod
-     def is_inpaint_pipeline(cls):
-         return "Inpaint" in cls.__name__
-
-     @classmethod
-     def get_submodule_rbln_config(
-         cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
-     ) -> Dict[str, Any]:
-         submodule = getattr(model, submodule_name)
-         submodule_class_name = submodule.__class__.__name__
-         if isinstance(submodule, torch.nn.Module):
-             if submodule_class_name == "MultiControlNetModel":
-                 submodule_class_name = "ControlNetModel"
-
-             submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
-
-             submodule_config = rbln_config.get(submodule_name, {})
-             submodule_config = copy.deepcopy(submodule_config)
-
-             pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
-
-             submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
-             submodule_config.update(
-                 {
-                     "img2img_pipeline": cls.is_img2img_pipeline(),
-                     "inpaint_pipeline": cls.is_inpaint_pipeline(),
-                 }
-             )
-             submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
-         else:
-             raise ValueError(f"submodule {submodule_name} isn't supported")
-         return submodule_config
+     _rbln_config_class = None
+     _hf_class = None

      @staticmethod
      def _maybe_apply_and_fuse_lora(
@@ -146,7 +109,30 @@ class RBLNDiffusionMixin:
          return model

      @classmethod
-     @use_rbln_config
+     def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
+         """
+         Lazily loads and caches the corresponding RBLN model config class.
+         """
+         if cls._rbln_config_class is None:
+             rbln_config_class_name = cls.__name__ + "Config"
+             library = importlib.import_module("optimum.rbln")
+             cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
+             if cls._rbln_config_class is None:
+                 raise ValueError(
+                     f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
+                     "Please report it to the developers."
+                 )
+         return cls._rbln_config_class
+
+     @classmethod
+     def get_hf_class(cls):
+         if cls._hf_class is None:
+             hf_cls_name = cls.__name__[4:]
+             library = importlib.import_module("diffusers")
+             cls._hf_class = getattr(library, hf_cls_name, None)
+         return cls._hf_class
+
+     @classmethod
      def from_pretrained(
          cls,
          model_id: str,
@@ -158,7 +144,49 @@ class RBLNDiffusionMixin:
          lora_weights_names: Optional[Union[str, List[str]]] = None,
          lora_scales: Optional[Union[float, List[float]]] = None,
          **kwargs,
-     ) -> RBLNModel:
+     ) -> "RBLNDiffusionMixin":
+         """
+         Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
+
+         This method has two distinct operating modes:
+         - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
+         - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
+
+         It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
+
+         Args:
+             model_id (`str`):
+                 The model ID or path to the pretrained model to load. Can be either:
+                 - A model ID from the HuggingFace Hub
+                 - A local path to a saved model directory
+             export (`bool`, *optional*, defaults to `False`):
+                 If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
+                 If False, loads an already compiled RBLN model from `model_id` without recompilation.
+             model_save_dir (`os.PathLike`, *optional*):
+                 Directory to save the compiled model artifacts. Only used when `export=True`.
+                 If not provided and `export=True`, a temporary directory is used.
+             rbln_config (`Dict[str, Any]`, *optional*, defaults to `{}`):
+                 Configuration options for RBLN compilation. Can include settings for specific submodules
+                 such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
+                 pipeline being compiled.
+             lora_ids (`str` or `List[str]`, *optional*):
+                 LoRA adapter ID(s) to load and apply before compilation. LoRA weights are fused
+                 into the model weights during compilation. Only used when `export=True`.
+             lora_weights_names (`str` or `List[str]`, *optional*):
+                 Names of specific LoRA weight files to load, corresponding to lora_ids. Only used when `export=True`.
+             lora_scales (`float` or `List[float]`, *optional*):
+                 Scaling factor(s) to apply to the LoRA adapter(s). Only used when `export=True`.
+             **kwargs:
+                 Additional arguments to pass to the underlying diffusion pipeline constructor or the
+                 RBLN compilation process. These may include parameters specific to individual submodules
+                 or the particular diffusion pipeline being used.
+
+         Returns:
+             `RBLNDiffusionMixin`: A compiled or loaded diffusion pipeline that can be used for inference on RBLN NPU.
+                 The returned object is an instance of the class that called this method, inheriting from RBLNDiffusionMixin.
+         """
+         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+
          if export:
              # keep submodules if user passed any of them.
              passed_submodules = {
@@ -168,22 +196,12 @@ class RBLNDiffusionMixin:
          else:
              # raise error if any of submodules are torch module.
              model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
-             rbln_config = cls._flatten_rbln_config(rbln_config)
              for submodule_name in cls._submodules:
                  if isinstance(kwargs.get(submodule_name), torch.nn.Module):
                      raise AssertionError(
                          f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
                      )

-                 submodule_config = rbln_config.get(submodule_name, {})
-
-                 for key, value in rbln_config.items():
-                     if key in RUNTIME_KEYWORDS and key not in submodule_config:
-                         submodule_config[key] = value
-
-                 if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
-                     continue
-
                  module_name, class_name = model_index_config[submodule_name]
                  if module_name != "optimum.rbln":
                      raise ValueError(
@@ -192,19 +210,19 @@ class RBLNDiffusionMixin:
                          "Expected 'optimum.rbln'. Please check the model_index.json configuration."
                      )

-                 submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
-
+                 submodule_cls: Type[RBLNModel] = getattr(importlib.import_module("optimum.rbln"), class_name)
+                 submodule_config = getattr(rbln_config, submodule_name)
                  submodule = submodule_cls.from_pretrained(
                      model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
                  )
                  kwargs[submodule_name] = submodule

          with ContextRblnConfig(
-             device=rbln_config.get("device"),
-             device_map=rbln_config.get("device_map"),
-             create_runtimes=rbln_config.get("create_runtimes"),
-             optimize_host_mem=rbln_config.get("optimize_host_memory"),
-             activate_profiler=rbln_config.get("activate_profiler"),
+             device=rbln_config.device,
+             device_map=rbln_config.device_map,
+             create_runtimes=rbln_config.create_runtimes,
+             optimize_host_mem=rbln_config.optimize_host_memory,
+             activate_profiler=rbln_config.activate_profiler,
          ):
              model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)

@@ -224,78 +242,27 @@ class RBLNDiffusionMixin:
          compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
          return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)

-     @classmethod
-     def _prepare_rbln_config(
-         cls,
-         rbln_config,
-     ) -> Dict[str, Any]:
-         prepared_config = {}
-         for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
-             connected_pipe_config = rbln_config.pop(connected_pipe_name, {})
-             prefix = cls._prefix.get(connected_pipe_name, "")
-             guidance_scale = rbln_config.pop(f"{prefix}guidance_scale", None)
-             if "guidance_scale" not in connected_pipe_config and guidance_scale is not None:
-                 connected_pipe_config["guidance_scale"] = guidance_scale
-             for submodule_name in connected_pipe_cls._submodules:
-                 submodule_config = rbln_config.pop(prefix + submodule_name, {})
-                 if submodule_name not in connected_pipe_config:
-                     connected_pipe_config[submodule_name] = {}
-                 connected_pipe_config[submodule_name].update(
-                     {k: v for k, v in submodule_config.items() if k not in connected_pipe_config[submodule_name]}
-                 )
-             prepared_config[connected_pipe_name] = connected_pipe_config
-         prepared_config.update(rbln_config)
-         return prepared_config
-
-     @classmethod
-     def _flatten_rbln_config(
-         cls,
-         rbln_config,
-     ) -> Dict[str, Any]:
-         prepared_config = cls._prepare_rbln_config(rbln_config)
-         flattened_config = {}
-         pipe_global_config = {k: v for k, v in prepared_config.items() if k not in cls._connected_classes.keys()}
-         for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
-             connected_pipe_config = prepared_config.pop(connected_pipe_name)
-             prefix = cls._prefix.get(connected_pipe_name, "")
-             connected_pipe_global_config = {
-                 k: v for k, v in connected_pipe_config.items() if k not in connected_pipe_cls._submodules
-             }
-             for submodule_name in connected_pipe_cls._submodules:
-                 flattened_config[prefix + submodule_name] = connected_pipe_config[submodule_name]
-                 flattened_config[prefix + submodule_name].update(
-                     {
-                         k: v
-                         for k, v in connected_pipe_global_config.items()
-                         if k not in flattened_config[prefix + submodule_name]
-                     }
-                 )
-         flattened_config.update(pipe_global_config)
-         return flattened_config
-
      @classmethod
      def _compile_pipelines(
          cls,
          model: torch.nn.Module,
          passed_submodules: Dict[str, RBLNModel],
          model_save_dir: Optional[PathLike],
-         rbln_config: Dict[str, Any],
+         rbln_config: "RBLNDiffusionMixinConfig",
      ) -> Dict[str, RBLNModel]:
          compiled_submodules = {}
-
-         rbln_config = cls._prepare_rbln_config(rbln_config)
-         pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._connected_classes.keys()}
          for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
              connected_pipe_submodules = {}
              prefix = cls._prefix.get(connected_pipe_name, "")
              for submodule_name in connected_pipe_cls._submodules:
                  connected_pipe_submodules[submodule_name] = passed_submodules.get(prefix + submodule_name, None)
              connected_pipe = getattr(model, connected_pipe_name)
-             connected_pipe_config = {}
-             connected_pipe_config.update(pipe_global_config)
-             connected_pipe_config.update(rbln_config[connected_pipe_name])
              connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
-                 connected_pipe, connected_pipe_submodules, model_save_dir, connected_pipe_config, prefix
+                 connected_pipe,
+                 connected_pipe_submodules,
+                 model_save_dir,
+                 getattr(rbln_config, connected_pipe_name),
+                 prefix,
              )
              for submodule_name, compiled_submodule in connected_pipe_compiled_submodules.items():
                  compiled_submodules[prefix + submodule_name] = compiled_submodule
@@ -307,14 +274,19 @@ class RBLNDiffusionMixin:
          model: torch.nn.Module,
          passed_submodules: Dict[str, RBLNModel],
          model_save_dir: Optional[PathLike],
-         rbln_config: Dict[str, Any],
+         rbln_config: RBLNDiffusionMixinConfig,
          prefix: Optional[str] = "",
      ) -> Dict[str, RBLNModel]:
          compiled_submodules = {}

          for submodule_name in cls._submodules:
              submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
-             submodule_rbln_config = cls.get_submodule_rbln_config(model, submodule_name, rbln_config)
+
+             if getattr(rbln_config, submodule_name, None) is None:
+                 raise ValueError(f"RBLN config for submodule {submodule_name} is not provided.")
+
+             submodule_rbln_cls: Type[RBLNModel] = getattr(rbln_config, submodule_name).rbln_model_cls
+             rbln_config = submodule_rbln_cls.update_rbln_config_using_pipe(model, rbln_config, submodule_name)

              if submodule is None:
                  raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
@@ -325,7 +297,7 @@ class RBLNDiffusionMixin:
                  submodule = cls._compile_multicontrolnet(
                      controlnets=submodule,
                      model_save_dir=model_save_dir,
-                     controlnet_rbln_config=submodule_rbln_config,
+                     controlnet_rbln_config=getattr(rbln_config, submodule_name),
                      prefix=prefix,
                  )
              elif isinstance(submodule, torch.nn.Module):
@@ -337,7 +309,7 @@ class RBLNDiffusionMixin:
                      model=submodule,
                      subfolder=subfolder,
                      model_save_dir=model_save_dir,
-                     rbln_config=submodule_rbln_config,
+                     rbln_config=getattr(rbln_config, submodule_name),
                  )
              else:
                  raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
@@ -350,22 +322,24 @@ class RBLNDiffusionMixin:
          cls,
          controlnets: "MultiControlNetModel",
          model_save_dir: Optional[PathLike],
-         controlnet_rbln_config: Dict[str, Any],
+         controlnet_rbln_config: RBLNModelConfig,
          prefix: Optional[str] = "",
      ):
          # Compile multiple ControlNet models for a MultiControlNet setup
          from .models.controlnet import RBLNControlNetModel
          from .pipelines.controlnet import RBLNMultiControlNetModel

-         compiled_controlnets = [
-             RBLNControlNetModel.from_model(
-                 model=controlnet,
-                 subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
-                 model_save_dir=model_save_dir,
-                 rbln_config=controlnet_rbln_config,
+         compiled_controlnets = []
+         for i, controlnet in enumerate(controlnets.nets):
+             _controlnet_rbln_config = copy.deepcopy(controlnet_rbln_config)
+             compiled_controlnets.append(
+                 RBLNControlNetModel.from_model(
+                     model=controlnet,
+                     subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
+                     model_save_dir=model_save_dir,
+                     rbln_config=_controlnet_rbln_config,
+                 )
              )
-             for i, controlnet in enumerate(controlnets.nets)
-         ]
          return RBLNMultiControlNetModel(compiled_controlnets)

      @classmethod
@@ -412,7 +386,7 @@ class RBLNDiffusionMixin:
          # overwrite to replace incorrect config
          model.save_config(model_save_dir)

-         if rbln_config.get("optimize_host_memory") is False:
+         if rbln_config.optimize_host_memory is False:
              # Keep compiled_model objs to further analysis. -> TODO: remove soon...
              model.compiled_models = []
              for name in cls._submodules:
@@ -441,9 +415,9 @@ class RBLNDiffusionMixin:
              kwargs["height"] = compiled_image_size[0]
              kwargs["width"] = compiled_image_size[1]

-             compiled_num_frames = self.unet.rbln_config.model_cfg.get("num_frames", None)
+             compiled_num_frames = self.unet.rbln_config.num_frames
              if compiled_num_frames is not None:
-                 kwargs["num_frames"] = self.unet.rbln_config.model_cfg.get("num_frames")
+                 kwargs["num_frames"] = compiled_num_frames
              return kwargs
          ```
      """