optimum-rbln 0.8.3a1__py3-none-any.whl → 0.8.3a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. See the package's registry page for more details.

optimum/rbln/__init__.py CHANGED
@@ -169,6 +169,9 @@ _import_structure = {
169
169
  "RBLNAutoencoderKLConfig",
170
170
  "RBLNAutoencoderKLCosmos",
171
171
  "RBLNAutoencoderKLCosmosConfig",
172
+ "RBLNAutoPipelineForImage2Image",
173
+ "RBLNAutoPipelineForInpainting",
174
+ "RBLNAutoPipelineForText2Image",
172
175
  "RBLNControlNetModel",
173
176
  "RBLNControlNetModelConfig",
174
177
  "RBLNCosmosTextToWorldPipeline",
@@ -238,6 +241,9 @@ if TYPE_CHECKING:
238
241
  RBLNAutoencoderKLConfig,
239
242
  RBLNAutoencoderKLCosmos,
240
243
  RBLNAutoencoderKLCosmosConfig,
244
+ RBLNAutoPipelineForImage2Image,
245
+ RBLNAutoPipelineForInpainting,
246
+ RBLNAutoPipelineForText2Image,
241
247
  RBLNControlNetModel,
242
248
  RBLNControlNetModelConfig,
243
249
  RBLNCosmosSafetyChecker,
@@ -1,7 +1,14 @@
1
1
  # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
3
 
4
- __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
5
12
 
6
13
  TYPE_CHECKING = False
7
14
  if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
9
16
  from typing import Union
10
17
 
11
18
  VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
12
20
  else:
13
21
  VERSION_TUPLE = object
22
+ COMMIT_ID = object
14
23
 
15
24
  version: str
16
25
  __version__: str
17
26
  __version_tuple__: VERSION_TUPLE
18
27
  version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
19
30
 
20
- __version__ = version = '0.8.3a1'
21
- __version_tuple__ = version_tuple = (0, 8, 3, 'a1')
31
+ __version__ = version = '0.8.3a2'
32
+ __version_tuple__ = version_tuple = (0, 8, 3, 'a2')
33
+
34
+ __commit_id__ = commit_id = None
@@ -59,6 +59,9 @@ _import_structure = {
59
59
  "RBLNVQModelConfig",
60
60
  ],
61
61
  "pipelines": [
62
+ "RBLNAutoPipelineForImage2Image",
63
+ "RBLNAutoPipelineForInpainting",
64
+ "RBLNAutoPipelineForText2Image",
62
65
  "RBLNCosmosTextToWorldPipeline",
63
66
  "RBLNCosmosVideoToWorldPipeline",
64
67
  "RBLNCosmosSafetyChecker",
@@ -144,6 +147,9 @@ if TYPE_CHECKING:
144
147
  RBLNVQModel,
145
148
  )
146
149
  from .pipelines import (
150
+ RBLNAutoPipelineForImage2Image,
151
+ RBLNAutoPipelineForInpainting,
152
+ RBLNAutoPipelineForText2Image,
147
153
  RBLNCosmosSafetyChecker,
148
154
  RBLNCosmosTextToWorldPipeline,
149
155
  RBLNCosmosVideoToWorldPipeline,
@@ -18,6 +18,11 @@ from transformers.utils import _LazyModule
18
18
 
19
19
 
20
20
  _import_structure = {
21
+ "auto_pipeline": [
22
+ "RBLNAutoPipelineForImage2Image",
23
+ "RBLNAutoPipelineForInpainting",
24
+ "RBLNAutoPipelineForText2Image",
25
+ ],
21
26
  "controlnet": [
22
27
  "RBLNMultiControlNetModel",
23
28
  "RBLNStableDiffusionControlNetImg2ImgPipeline",
@@ -56,6 +61,11 @@ _import_structure = {
56
61
  ],
57
62
  }
58
63
  if TYPE_CHECKING:
64
+ from .auto_pipeline import (
65
+ RBLNAutoPipelineForImage2Image,
66
+ RBLNAutoPipelineForInpainting,
67
+ RBLNAutoPipelineForText2Image,
68
+ )
59
69
  from .controlnet import (
60
70
  RBLNMultiControlNetModel,
61
71
  RBLNStableDiffusionControlNetImg2ImgPipeline,
@@ -0,0 +1,237 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import importlib
17
+ from typing import Type
18
+
19
+ from diffusers.models.controlnets import ControlNetUnionModel
20
+ from diffusers.pipelines.auto_pipeline import (
21
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
22
+ AUTO_INPAINT_PIPELINES_MAPPING,
23
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
24
+ AutoPipelineForImage2Image,
25
+ AutoPipelineForInpainting,
26
+ AutoPipelineForText2Image,
27
+ _get_task_class,
28
+ )
29
+ from huggingface_hub.utils import validate_hf_hub_args
30
+
31
+ from optimum.rbln.modeling_base import RBLNBaseModel
32
+ from optimum.rbln.utils.model_utils import (
33
+ MODEL_MAPPING,
34
+ convert_hf_to_rbln_model_name,
35
+ convert_rbln_to_hf_model_name,
36
+ get_rbln_model_cls,
37
+ )
38
+
39
+
40
class RBLNAutoPipelineBase:
    """Shared dispatch logic for the ``RBLNAutoPipelineFor*`` classes.

    Given a checkpoint, this mixin works out which ``RBLN*`` pipeline class should load it
    and delegates ``from_pretrained`` / ``from_model`` to that class. Concrete subclasses in
    this module combine it with the matching diffusers ``AutoPipelineFor*`` class (which
    provides ``load_config``) and fill in the two mapping attributes below.
    """

    # Diffusers task mapping (task key -> pipeline class); set by each concrete subclass.
    _model_mapping = None
    # Same mapping with every pipeline class replaced by its ``__name__`` (task key -> str).
    _model_mapping_names = None

    @classmethod
    def get_rbln_cls(cls, pretrained_model_name_or_path, export=True, **kwargs):
        """
        Resolve the ``RBLN*`` pipeline class that should load ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path (str): Model ID on the Hub or path to a local directory.
            export (bool): If True, the diffusers pipeline class is inferred from the checkpoint
                config through this auto class's task mapping and converted to its RBLN name.
                If False, the ``_class_name`` stored in the checkpoint config is used directly
                (presumably an already-compiled RBLN checkpoint) and validated against the mapping.
            **kwargs: Forwarded to the class-inference helpers (hub options, ``controlnet``,
                ``enable_pag``, ...).

        Returns:
            type: The resolved RBLN pipeline class.

        Raises:
            ValueError: If (with ``export=False``) the stored architecture is not handled by this auto class.
            AttributeError: If the resolved class name is not available in ``optimum.rbln``.
        """
        if export:
            hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
            rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
        else:
            rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
            # `_model_mapping_names` is keyed by diffusers task keys (e.g. "stable-diffusion");
            # the pipeline class names are its *values*, so membership must be tested there.
            if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
                raise ValueError(
                    f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
                    "Please use the `from_pretrained()` method of the appropriate class to load this model, "
                    f"or directly use `{rbln_class_name}.from_pretrained()`."
                )

        try:
            rbln_cls = get_rbln_model_cls(rbln_class_name)
        except AttributeError as e:
            raise AttributeError(
                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
                "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
            ) from e

        return rbln_cls

    @classmethod
    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path, **kwargs):
        """
        Read the ``_class_name`` entry from the checkpoint's pipeline config (``model_index.json``).

        Args:
            pretrained_model_name_or_path (str): Model ID on the Hub or path to a local directory.
            **kwargs: The hub download options ``cache_dir``, ``force_download``, ``proxies``,
                ``token``, ``local_files_only`` and ``revision`` are forwarded to ``load_config``
                when present; all other keys are ignored.

        Returns:
            str: The class name recorded in the config.

        Raises:
            ValueError: If the config has no ``_class_name`` field.
        """
        # Forward the same hub options `infer_hf_model_class` uses, so private repos, pinned
        # revisions and custom caches also work on the non-export path.
        hub_kwargs = {
            key: kwargs[key]
            for key in ("cache_dir", "force_download", "proxies", "token", "local_files_only", "revision")
            if key in kwargs
        }
        model_index_config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)

        if "_class_name" not in model_index_config:
            raise ValueError(
                "The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
                "Please use the `from_pretrained()` method of the appropriate class to load this model."
            )

        return model_index_config["_class_name"]

    @classmethod
    def infer_hf_model_class(
        cls,
        pretrained_model_or_path,
        cache_dir=None,
        force_download=False,
        proxies=None,
        token=None,
        local_files_only=False,
        revision=None,
        **kwargs,
    ):
        """
        Infer the diffusers pipeline class for a checkpoint from this auto class's task mapping.

        Args:
            pretrained_model_or_path (str): Model ID on the Hub or path to a local directory.
            cache_dir, force_download, proxies, token, local_files_only, revision: Hub options
                forwarded to ``load_config``.
            **kwargs: Passed to ``get_pipeline_key_name`` (e.g. ``controlnet``, ``enable_pag``).

        Returns:
            type: The diffusers pipeline class returned by ``_get_task_class`` for ``_model_mapping``.
        """
        config = cls.load_config(
            pretrained_model_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            local_files_only=local_files_only,
            revision=revision,
        )
        pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)

        pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)

        return pipeline_cls

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """
        Derive the pipeline class name to look up in ``_model_mapping`` (text-to-image rules).

        The config's ``_class_name`` suffix (``ControlPipeline`` or ``Pipeline``) is replaced by
        ``ControlNetUnionPipeline`` / ``ControlNetPipeline`` when a ``controlnet`` kwarg is given,
        and then by ``PAGPipeline`` when ``enable_pag`` is truthy.

        Args:
            config (dict): Pipeline config containing ``_class_name``.
            **kwargs: May contain ``controlnet`` and/or ``enable_pag``.

        Returns:
            str: The adjusted pipeline class name.
        """
        orig_class_name = config["_class_name"]
        if "ControlPipeline" in orig_class_name:
            to_replace = "ControlPipeline"
        else:
            to_replace = "Pipeline"

        if "controlnet" in kwargs:
            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
            else:
                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
        if "enable_pag" in kwargs:
            # This pop only affects the local kwargs copy, not the caller's dict.
            enable_pag = kwargs.pop("enable_pag")
            if enable_pag:
                orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")

        return orig_class_name

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, model_id, **kwargs):
        """
        Load ``model_id`` with the RBLN pipeline class resolved by ``get_rbln_cls``.

        Args:
            model_id (str): Model ID on the Hub or path to a local directory.
            **kwargs: Passed both to ``get_rbln_cls`` (including ``export``) and unchanged to the
                resolved class's ``from_pretrained``.

        Returns:
            The object returned by the resolved class's ``from_pretrained``.
        """
        rbln_cls = cls.get_rbln_cls(model_id, **kwargs)
        # NOTE(review): `enable_pag` is only popped from a local copy inside
        # `get_pipeline_key_name`, so it is still forwarded here — confirm the target
        # `from_pretrained` accepts or ignores it.
        return rbln_cls.from_pretrained(model_id, **kwargs)

    @classmethod
    def from_model(cls, model, **kwargs):
        """
        Build an RBLN pipeline from an already-instantiated pipeline object.

        The RBLN class is looked up as ``"RBLN" + type(model).__name__`` and its ``from_model``
        is called with ``model`` and ``kwargs``.
        """
        rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
        return rbln_cls.from_model(model, **kwargs)

    @staticmethod
    def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
        """
        Register a new RBLN model class.

        Args:
            rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
            exist_ok (bool): Whether to allow registering an already registered model.

        Raises:
            ValueError: If ``rbln_cls`` is not an ``RBLNBaseModel`` subclass, or if its name is
                already in ``MODEL_MAPPING`` / exported by ``optimum.rbln`` and ``exist_ok`` is False.
        """
        if not issubclass(rbln_cls, RBLNBaseModel):
            raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")

        # A class of the same name shipped by optimum.rbln itself also counts as "registered".
        native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
        if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
            if not exist_ok:
                raise ValueError(f"Model for {rbln_cls.__name__} already registered.")

        MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
165
+
166
+
167
class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
    """Auto class resolving the RBLN text-to-image pipeline for a checkpoint.

    Pipeline class-name derivation uses the inherited ``get_pipeline_key_name``.
    """

    _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
    # Task key -> pipeline class name, used to validate compiled checkpoints.
    _model_mapping_names = {
        task_key: pipeline_cls.__name__ for task_key, pipeline_cls in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()
    }
170
+
171
+
172
class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
    """Auto class resolving the RBLN image-to-image pipeline for a checkpoint."""

    _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
    # Task key -> pipeline class name, used to validate compiled checkpoints.
    _model_mapping_names = {
        task_key: pipeline_cls.__name__ for task_key, pipeline_cls in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """
        Return the image-to-image pipeline class name to look up in ``_model_mapping``.

        The checkpoint's ``_class_name`` can be:
          - ``*Pipeline`` (regular text-to-image checkpoint),
          - ``*ControlPipeline`` (Flux tools specific checkpoint), or
          - ``*Img2ImgPipeline`` (refiner checkpoint).
        The matching suffix gets a ``ControlNet`` / ``ControlNetUnion`` prefix when a
        ``controlnet`` kwarg is supplied and a ``PAG`` prefix when ``enable_pag`` is truthy.
        A ``ControlPipeline`` suffix is finally rewritten to ``ControlImg2ImgPipeline``.
        """
        class_name = config["_class_name"]

        if "Img2Img" in class_name:
            suffix = "Img2ImgPipeline"
        elif "ControlPipeline" in class_name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            is_union = isinstance(kwargs["controlnet"], ControlNetUnionModel)
            controlnet_tag = "ControlNetUnion" if is_union else "ControlNet"
            class_name = class_name.replace(suffix, controlnet_tag + suffix)

        # Popping only touches this method's local kwargs copy.
        if kwargs.pop("enable_pag", False):
            class_name = class_name.replace(suffix, "PAG" + suffix)

        if suffix == "ControlPipeline":
            class_name = class_name.replace(suffix, "ControlImg2ImgPipeline")

        return class_name
204
+
205
+
206
class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
    """Auto class resolving the RBLN inpainting pipeline for a checkpoint."""

    _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
    # Task key -> pipeline class name, used to validate compiled checkpoints.
    _model_mapping_names = {
        task_key: pipeline_cls.__name__ for task_key, pipeline_cls in AUTO_INPAINT_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """
        Return the inpainting pipeline class name to look up in ``_model_mapping``.

        The checkpoint's ``_class_name`` can be:
          - ``*InpaintPipeline`` (inpaint-specific checkpoint),
          - ``*ControlPipeline`` (Flux tools specific checkpoint), or
          - ``*Pipeline`` (regular text-to-image checkpoint).
        The matching suffix gets a ``ControlNet`` / ``ControlNetUnion`` prefix when a
        ``controlnet`` kwarg is supplied and a ``PAG`` prefix when ``enable_pag`` is truthy.
        A ``ControlPipeline`` suffix is finally rewritten to ``ControlInpaintPipeline``.
        """
        class_name = config["_class_name"]

        if "Inpaint" in class_name:
            suffix = "InpaintPipeline"
        elif "ControlPipeline" in class_name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            is_union = isinstance(kwargs["controlnet"], ControlNetUnionModel)
            controlnet_tag = "ControlNetUnion" if is_union else "ControlNet"
            class_name = class_name.replace(suffix, controlnet_tag + suffix)

        # Popping only touches this method's local kwargs copy.
        if kwargs.pop("enable_pag", False):
            class_name = class_name.replace(suffix, "PAG" + suffix)

        if suffix == "ControlPipeline":
            class_name = class_name.replace(suffix, "ControlInpaintPipeline")

        return class_name
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: optimum-rbln
3
- Version: 0.8.3a1
3
+ Version: 0.8.3a2
4
4
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
5
5
  Project-URL: Homepage, https://rebellions.ai
6
6
  Project-URL: Documentation, https://docs.rbln.ai
@@ -1,9 +1,9 @@
1
- optimum/rbln/__init__.py,sha256=i3WWddZ0okF5dQN3B_2wfM7NTnQ37lwdD2udzjMRGH8,17140
2
- optimum/rbln/__version__.py,sha256=moPf49BSKprAMyeg6-M2OWNQW2vr0prDpm8YlgXAXOY,519
1
+ optimum/rbln/__init__.py,sha256=YhaBhcyu6BgoJrprUogLGAmiBaHayvg6Tjm6PpfJETw,17382
2
+ optimum/rbln/__version__.py,sha256=LoGi14U0L2os-fSHKgBIGeByegJLodfXKteGMBVsCEc,712
3
3
  optimum/rbln/configuration_utils.py,sha256=xneqnRWSUVROqpzbTrBACex42-L9zwo3eSjfHjFuhv4,33072
4
4
  optimum/rbln/modeling.py,sha256=0CMQnpVvW9evNrTFHM2XFbNpRY1HkbFzYJ5sRyYFq0o,14293
5
5
  optimum/rbln/modeling_base.py,sha256=gHfqIO6lKT8smkUthUuRHnbITpxHpnDeBPT8iTeasCk,24575
6
- optimum/rbln/diffusers/__init__.py,sha256=cvyJaFRU1sP1WeRjWrxMOm-5vr0c4X-TD8eqQ21XIgc,6990
6
+ optimum/rbln/diffusers/__init__.py,sha256=1tgU_xWA42BmInqu9bBz_5R_E9TGhhK3mI06YlaiTLg,7232
7
7
  optimum/rbln/diffusers/modeling_diffusers.py,sha256=TAuMb7PSMjNwK7mh5ItE_CtAEgYeZKI27XkFFmxjHlQ,19902
8
8
  optimum/rbln/diffusers/configurations/__init__.py,sha256=vMRnPY4s-Uju43xP038D2EA18X_mhy2YfsZVpSU-VoA,1322
9
9
  optimum/rbln/diffusers/configurations/models/__init__.py,sha256=7q95gtgDzCeIBogGw8SLQoHT4Wch7vpLJVF2UQovuoo,567
@@ -35,7 +35,8 @@ optimum/rbln/diffusers/models/transformers/transformer_cosmos.py,sha256=UQ_R7RVJ
35
35
  optimum/rbln/diffusers/models/transformers/transformer_sd3.py,sha256=yF7sS0QvawowpV9hR5GeT8DaE8CCp3mj1njHHd9cKTc,6630
36
36
  optimum/rbln/diffusers/models/unets/__init__.py,sha256=MaICuK9CWjgzejXy8y2NDrphuEq1rkzanF8u45k6O5I,655
37
37
  optimum/rbln/diffusers/models/unets/unet_2d_condition.py,sha256=v3WS9EGKROE_QClXrxC7rmRko1BspAvAbeIfh83LK88,15832
38
- optimum/rbln/diffusers/pipelines/__init__.py,sha256=Ft1i48HP3wVi5t7PpIPNhL-bcxpLfwyZ5kuaTECAx1A,3392
38
+ optimum/rbln/diffusers/pipelines/__init__.py,sha256=r8mu21102cKXdkG1II9tpfpUS6wuyren2oK9y_MptZY,3703
39
+ optimum/rbln/diffusers/pipelines/auto_pipeline.py,sha256=oGZWXfj82w695D2NiYUitgoWiwP2Z4PlgA3q6hoOKww,9502
39
40
  optimum/rbln/diffusers/pipelines/controlnet/__init__.py,sha256=n1Ef22TSeax-kENi_d8K6wGGHSNEo9QkUeygELHgcao,983
40
41
  optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py,sha256=3S9dogIHW8Bqg5kIlCudhCQG-4g3FcdOPEWhBOf7CJA,4059
41
42
  optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py,sha256=G96bh4D9Cu-w4F9gZBQF6wNzhJQv9kvI34ZFsuEDjSw,35714
@@ -226,7 +227,7 @@ optimum/rbln/utils/model_utils.py,sha256=4k5879Kh75m3x_vS4-qOGfqsOiAvc2kdNFFfvsF
226
227
  optimum/rbln/utils/runtime_utils.py,sha256=R6uXDbeJP03-FWdd4vthNe2D4aCra5n12E3WB1ifiGM,7933
227
228
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
228
229
  optimum/rbln/utils/submodule.py,sha256=w5mgPgncI740gVKMu3S-69DGNdUSI0bTZxegQGcZ98Y,5011
229
- optimum_rbln-0.8.3a1.dist-info/METADATA,sha256=InRecmrzQW8U2sU6cH4nI94K2WQ_Y7WX_u7qZJMxXao,5299
230
- optimum_rbln-0.8.3a1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
231
- optimum_rbln-0.8.3a1.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
232
- optimum_rbln-0.8.3a1.dist-info/RECORD,,
230
+ optimum_rbln-0.8.3a2.dist-info/METADATA,sha256=KAOx0J5beZebrxsAf9AsklRO43eTWaw222WX1iInnpk,5299
231
+ optimum_rbln-0.8.3a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
232
+ optimum_rbln-0.8.3a2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
233
+ optimum_rbln-0.8.3a2.dist-info/RECORD,,