optimum-rbln 0.7.5rc0__py3-none-any.whl → 0.7.5rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.7.5rc0'
21
- __version_tuple__ = version_tuple = (0, 7, 5, 'rc0')
20
+ __version__ = version = '0.7.5rc1'
21
+ __version_tuple__ = version_tuple = (0, 7, 5, 'rc1')
@@ -58,8 +58,7 @@ class _PriorTransformer(torch.nn.Module):
58
58
  class RBLNPriorTransformer(RBLNModel):
59
59
  hf_library_name = "diffusers"
60
60
  auto_model_class = PriorTransformer
61
- output_class = PriorTransformerOutput
62
- output_key = "predicted_image_embedding"
61
+ _output_class = PriorTransformerOutput
63
62
 
64
63
  def __post_init__(self, **kwargs):
65
64
  super().__post_init__(**kwargs)
@@ -61,8 +61,7 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
61
61
  class RBLNSD3Transformer2DModel(RBLNModel):
62
62
  hf_library_name = "diffusers"
63
63
  auto_model_class = SD3Transformer2DModel
64
- output_class = Transformer2DModelOutput
65
- output_key = "sample"
64
+ _output_class = Transformer2DModelOutput
66
65
 
67
66
  def __post_init__(self, **kwargs):
68
67
  super().__post_init__(**kwargs)
@@ -143,8 +143,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
143
143
  hf_library_name = "diffusers"
144
144
  auto_model_class = UNet2DConditionModel
145
145
  _rbln_config_class = RBLNUNet2DConditionModelConfig
146
- output_class = UNet2DConditionOutput
147
- output_key = "sample"
146
+ _output_class = UNet2DConditionOutput
148
147
 
149
148
  def __post_init__(self, **kwargs):
150
149
  super().__post_init__(**kwargs)
optimum/rbln/modeling.py CHANGED
@@ -14,7 +14,7 @@
14
14
 
15
15
  from pathlib import Path
16
16
  from tempfile import TemporaryDirectory
17
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
17
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union, get_args, get_origin, get_type_hints
18
18
 
19
19
  import rebel
20
20
  import torch
@@ -49,8 +49,7 @@ class RBLNModel(RBLNBaseModel):
49
49
  ```
50
50
  """
51
51
 
52
- output_class = None
53
- output_key = "last_hidden_state"
52
+ _output_class = None
54
53
 
55
54
  @classmethod
56
55
  def update_kwargs(cls, kwargs):
@@ -245,16 +244,61 @@ class RBLNModel(RBLNBaseModel):
245
244
  # Format output according to task requirements
246
245
  return self._prepare_output(output, return_dict)
247
246
 
247
+ @classmethod
248
+ def get_hf_output_class(cls):
249
+ """
250
+ Dynamically gets the output class from the corresponding HuggingFace model class.
251
+
252
+ Returns:
253
+ type: The appropriate output class from transformers or diffusers
254
+ """
255
+ if cls._output_class:
256
+ return cls._output_class
257
+
258
+ hf_class = cls.get_hf_class()
259
+ if hf_class is None:
260
+ raise ValueError(f"No HuggingFace model class found for {cls.__name__}")
261
+
262
+ hints = get_type_hints(hf_class.forward) if hasattr(hf_class, "forward") else {}
263
+ ret = hints.get("return")
264
+
265
+ if ret is not None:
266
+ candidates = get_args(ret) if get_origin(ret) is Union else (ret,)
267
+
268
+ for t in candidates:
269
+ if t is type(None): # Skip NoneType in Union
270
+ continue
271
+ mod = getattr(t, "__module__", "")
272
+ if "transformers" in mod or "diffusers" in mod:
273
+ cls._output_class = t
274
+ return t
275
+
276
+ # Fallback to BaseModelOutput
277
+ cls._output_class = BaseModelOutput
278
+ return BaseModelOutput
279
+
248
280
  def _prepare_output(self, output, return_dict):
249
281
  """
250
282
  Prepare model output based on return_dict flag.
251
283
  This method can be overridden by subclasses to provide task-specific output handling.
252
284
  """
285
+ tuple_output = (output,) if not isinstance(output, (tuple, list)) else output
253
286
  if not return_dict:
254
- return (output,) if not isinstance(output, (tuple, list)) else output
287
+ return tuple_output
255
288
  else:
256
- if self.output_class is None:
257
- return BaseModelOutput(last_hidden_state=output)
258
-
259
- # Create output with the appropriate class and key
260
- return self.output_class(**{self.output_key: output})
289
+ output_class = self.get_hf_output_class()
290
+ if hasattr(output_class, "loss"):
291
+ tuple_output = (None,) + tuple_output
292
+
293
+ # Truncate if we have too many outputs, otherwise use as is
294
+ if hasattr(output_class, "__annotations__"):
295
+ num_fields = len(output_class.__annotations__)
296
+ if len(tuple_output) > num_fields:
297
+ tuple_output = tuple_output[:num_fields]
298
+ logger.warning(
299
+ f"Truncating output to {num_fields} fields for {output_class.__name__}. "
300
+ f"Expected {num_fields} fields, but got {len(tuple_output)} fields."
301
+ "This is unexpected. Please report this issue to the developers."
302
+ )
303
+
304
+ return output_class(*tuple_output)
@@ -178,9 +178,27 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
178
178
  return str(model_path)
179
179
 
180
180
  @classmethod
181
- def _load_compiled_models(cls, model_path: str):
181
+ def _load_compiled_models(cls, model_path: str, expected_compiled_model_names: List[str]):
182
182
  compiled_models = Path(model_path).glob("*.rbln")
183
- rbln_compiled_models = {cm.stem: rebel.RBLNCompiledModel(cm) for cm in compiled_models}
183
+ expected_compiled_models = [
184
+ Path(model_path) / f"{compiled_model_name}.rbln" for compiled_model_name in expected_compiled_model_names
185
+ ]
186
+ unexpected_compiled_models = [cm for cm in compiled_models if cm not in expected_compiled_models]
187
+ if unexpected_compiled_models:
188
+ # TODO(jongho): fix after May release. raise error if unexpected compiled models are found
189
+ logger.warning(
190
+ f"Unexpected compiled models found: {[cm.name for cm in unexpected_compiled_models]}. "
191
+ f"Please check the model path: {model_path}"
192
+ )
193
+
194
+ rbln_compiled_models = {}
195
+ for compiled_model in expected_compiled_models:
196
+ if not compiled_model.exists():
197
+ raise FileNotFoundError(
198
+ f"Expected RBLN compiled model '{compiled_model.name}' not found at '{model_path}'. "
199
+ "Please ensure all models specified in `rbln_config` are present."
200
+ )
201
+ rbln_compiled_models[compiled_model.stem] = rebel.RBLNCompiledModel(compiled_model)
184
202
  return rbln_compiled_models
185
203
 
186
204
  @classmethod
@@ -271,7 +289,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
271
289
  )
272
290
  config = PretrainedConfig(**config)
273
291
 
274
- rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
292
+ compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
293
+ rbln_compiled_models = cls._load_compiled_models(model_path_subfolder, compiled_model_names)
275
294
 
276
295
  if subfolder != "":
277
296
  model_save_dir = Path(model_path_subfolder).absolute().parent
@@ -36,11 +36,7 @@ from transformers import (
36
36
  )
37
37
  from transformers.modeling_outputs import (
38
38
  BaseModelOutput,
39
- DepthEstimatorOutput,
40
- ImageClassifierOutput,
41
- MaskedLMOutput,
42
39
  QuestionAnsweringModelOutput,
43
- SequenceClassifierOutput,
44
40
  )
45
41
 
46
42
  from ..configuration_utils import RBLNCompileConfig
@@ -63,8 +59,6 @@ class _RBLNTransformerEncoder(RBLNModel):
63
59
  auto_model_class = AutoModel
64
60
  rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
65
61
  rbln_dtype = "int64"
66
- output_class = BaseModelOutput
67
- output_key = "last_hidden_state"
68
62
 
69
63
  @classmethod
70
64
  def _update_rbln_config(
@@ -149,7 +143,6 @@ class _RBLNImageModel(RBLNModel):
149
143
  auto_model_class = AutoModel
150
144
  main_input_name = "pixel_values"
151
145
  output_class = BaseModelOutput
152
- output_key = "last_hidden_state"
153
146
 
154
147
  @classmethod
155
148
  def _update_rbln_config(
@@ -223,15 +216,11 @@ class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
223
216
  class RBLNModelForSequenceClassification(_RBLNTransformerEncoder):
224
217
  auto_model_class = AutoModelForSequenceClassification
225
218
  rbln_model_input_names = ["input_ids", "attention_mask"]
226
- output_class = SequenceClassifierOutput
227
- output_key = "logits"
228
219
 
229
220
 
230
221
  class RBLNModelForMaskedLM(_RBLNTransformerEncoder):
231
222
  auto_model_class = AutoModelForMaskedLM
232
223
  rbln_model_input_names = ["input_ids", "attention_mask"]
233
- output_class = MaskedLMOutput
234
- output_key = "logits"
235
224
 
236
225
 
237
226
  class RBLNModelForTextEncoding(_RBLNTransformerEncoder):
@@ -243,20 +232,14 @@ class RBLNTransformerEncoderForFeatureExtraction(_RBLNTransformerEncoder):
243
232
  # TODO: RBLNModel is also for feature extraction.
244
233
  auto_model_class = AutoModel
245
234
  rbln_model_input_names = ["input_ids", "attention_mask"]
246
- output_class = BaseModelOutput
247
- output_key = "last_hidden_state"
248
235
 
249
236
 
250
237
  class RBLNModelForImageClassification(_RBLNImageModel):
251
238
  auto_model_class = AutoModelForImageClassification
252
- output_class = ImageClassifierOutput
253
- output_key = "logits"
254
239
 
255
240
 
256
241
  class RBLNModelForDepthEstimation(_RBLNImageModel):
257
242
  auto_model_class = AutoModelForDepthEstimation
258
- output_class = DepthEstimatorOutput
259
- output_key = "predicted_depth"
260
243
 
261
244
 
262
245
  class RBLNModelForAudioClassification(RBLNModel):
@@ -273,8 +256,6 @@ class RBLNModelForAudioClassification(RBLNModel):
273
256
  """
274
257
 
275
258
  auto_model_class = AutoModelForAudioClassification
276
- output_class = SequenceClassifierOutput
277
- output_key = "logits"
278
259
 
279
260
  @classmethod
280
261
  def _update_rbln_config(
@@ -15,7 +15,6 @@
15
15
 
16
16
  import torch
17
17
  from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
18
- from transformers.modeling_outputs import CausalLMOutput
19
18
 
20
19
  from ...modeling_generic import RBLNModelForMaskedLM
21
20
  from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
@@ -46,8 +45,6 @@ class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
46
45
  main_input_name = "input_values"
47
46
  auto_model_class = AutoModelForMaskedLM
48
47
  rbln_dtype = "float32"
49
- output_class = CausalLMOutput
50
- output_key = "logits"
51
48
 
52
49
  @classmethod
53
50
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: optimum-rbln
3
- Version: 0.7.5rc0
3
+ Version: 0.7.5rc1
4
4
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
5
5
  Project-URL: Homepage, https://rebellions.ai
6
6
  Project-URL: Documentation, https://docs.rbln.ai
@@ -1,8 +1,8 @@
1
1
  optimum/rbln/__init__.py,sha256=oAnsJSMrPYwBGEttUt3CMXTIESVNe15ftTWRTShwhZI,14386
2
- optimum/rbln/__version__.py,sha256=34rdkaF19JfWW5k_S6Q9ZJaeOOAsCoPR3_vN57l-b28,521
2
+ optimum/rbln/__version__.py,sha256=6UGbTME6xZZ1ojJoRpul_clf4TsgGIZHt3214_8maxM,521
3
3
  optimum/rbln/configuration_utils.py,sha256=gvAjRFEGw5rnSoH0IoyuLrE4fkxtk3DN1pikqrN_Rpk,31277
4
- optimum/rbln/modeling.py,sha256=4Xwi3ovWDHOOqxUDH_ZgsgTuea8Kyg25D9s81zVYpr0,9669
5
- optimum/rbln/modeling_base.py,sha256=iQKw2IORu1cN6sOK0xeBVrhatt-ZPeinT_v6l2FnGRw,24173
4
+ optimum/rbln/modeling.py,sha256=CWYpOLQOu1RUQrHvoX3FoidiP2XltDzC9gWIzznUTFo,11455
5
+ optimum/rbln/modeling_base.py,sha256=HQgscr5jpUEtuXU1ACJHSLIntX-kq6Ef0SQ_W2-rp5A,25341
6
6
  optimum/rbln/diffusers/__init__.py,sha256=XL6oKPHbPCV6IVCw3fu0-M9mD2KO_x6unx5kJdAtpVY,6180
7
7
  optimum/rbln/diffusers/modeling_diffusers.py,sha256=bPyP5RMbOFLb2DfEAuLVp7hTuQWJvWid7El72wGmFrY,19535
8
8
  optimum/rbln/diffusers/configurations/__init__.py,sha256=Sk_sQVTuTl01RVgYViWknQSLmulxKaISS0w-oPdNoBQ,1164
@@ -26,10 +26,10 @@ optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py,sha256=qjReFNmuQEjn
26
26
  optimum/rbln/diffusers/models/autoencoders/vae.py,sha256=_fyFco2697uT1zo_P_fGML-_zqZw2sUQp3tRRjA5pg4,4172
27
27
  optimum/rbln/diffusers/models/autoencoders/vq_model.py,sha256=DC8Nee8_BabGhagJgpCUDhA-oaTpZMg-lCVzXJ6dNEw,6134
28
28
  optimum/rbln/diffusers/models/transformers/__init__.py,sha256=V8rSR7WzHs-i8Cwb_MNxhY2NFbwPgxu24vGtkwl-6tk,706
29
- optimum/rbln/diffusers/models/transformers/prior_transformer.py,sha256=d7CYmm88lozepqXjmrFr4qsQ-lRE_10wQRwnenMSflU,4989
30
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py,sha256=ywWiRAYQ0wfKMMZBrJ9C34GBwIu92F5QXSG8qi7Cr6s,6579
29
+ optimum/rbln/diffusers/models/transformers/prior_transformer.py,sha256=XaIICLeMdGyqm9B3f2A3vqh1haJpqfT3GJ3ZM0DKcaY,4945
30
+ optimum/rbln/diffusers/models/transformers/transformer_sd3.py,sha256=H1dsDOnAK4Dp0ixCVIt_4_4KJ5ZcTygfG7sFFdpOvrI,6554
31
31
  optimum/rbln/diffusers/models/unets/__init__.py,sha256=MaICuK9CWjgzejXy8y2NDrphuEq1rkzanF8u45k6O5I,655
32
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py,sha256=dyrez3vS0_YSxTVwcjlSijDJhe6XchdsLsCFv74HiFQ,15555
32
+ optimum/rbln/diffusers/models/unets/unet_2d_condition.py,sha256=3dzqJQPiklkgoyxRHysOE7q9hrhaT4K0_SNiCflFvLg,15530
33
33
  optimum/rbln/diffusers/pipelines/__init__.py,sha256=5KLZ5LrpMzBya2e_3_PvEoPwG24U8JMexfw_ygZREKc,3140
34
34
  optimum/rbln/diffusers/pipelines/controlnet/__init__.py,sha256=n1Ef22TSeax-kENi_d8K6wGGHSNEo9QkUeygELHgcao,983
35
35
  optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py,sha256=Cv__E0Boc6TSOIv8TdXVE821zIiPG4MVI_lnaGSqquk,4102
@@ -65,7 +65,7 @@ optimum/rbln/transformers/__init__.py,sha256=LW6abfb0W0jHziE8dIEDBeyb4Cj-aq8dUld
65
65
  optimum/rbln/transformers/configuration_alias.py,sha256=qFVfg6ohsR7a6b-CBgxjBUPDrk9MyiJwtO8AQah_RTU,1505
66
66
  optimum/rbln/transformers/configuration_generic.py,sha256=XIiZ1-5p1CMHhG7Sr2qR4SLYKcYw9aph7eGlga3Opx0,5056
67
67
  optimum/rbln/transformers/modeling_alias.py,sha256=yx7FnZQWAnrWzivaO5hI7T6i-fyLzt2tMIXG2oDNbPo,1657
68
- optimum/rbln/transformers/modeling_generic.py,sha256=nT_lytAILkYtwBVJKxXg0dxmh0UpjGYO6zOdLoMs1uU,12891
68
+ optimum/rbln/transformers/modeling_generic.py,sha256=L5ndJJzKhXa4de1YAA8uxNzMKWOHsAHPoJrANxWYWjE,12265
69
69
  optimum/rbln/transformers/modeling_rope_utils.py,sha256=3zwkhYUyTZhxCJUSmwCc88iiY1TppRWEY9ShwUqNB2k,14293
70
70
  optimum/rbln/transformers/models/__init__.py,sha256=qNh_d7bBKxhxBbUImXJ66n0Vo0NW1m7tMIU5M2ZxGmw,8510
71
71
  optimum/rbln/transformers/models/auto/__init__.py,sha256=34Xghf1ogG4u-jhBMlj134nHdgnR3JEHSeZTPuy3MpY,1071
@@ -158,7 +158,7 @@ optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_t
158
158
  optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py,sha256=XJDjQGbWXUq4ZimNojlcbm3mTDpxUMCl6tkFSzfYFl4,13769
159
159
  optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=dzXqyf_uiI45hPJGbnF1v780Izi2TigsbAo3hxFmhy0,709
160
160
  optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py,sha256=hXsN_pc_gb_xcQdXXnvpp-o0dk5lNepXnt9O5HB-3g4,771
161
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=Sw44ZJVQpXfBiq34oyKxSY2SslCeF3QT_yBysWiTyHY,2060
161
+ optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=Lmm39NhvJIQtCkBa5BynkJYNqWOF7GaWsV5qYX-4L94,1943
162
162
  optimum/rbln/transformers/models/whisper/__init__.py,sha256=ErquiUlYycSYPsDcq9IwwmbZXoYLn1MVZ8VikWY5gQo,792
163
163
  optimum/rbln/transformers/models/whisper/configuration_whisper.py,sha256=-Su7pbkg3gkYTf-ECRJyxkpD3JtUJX4y5Mfml8tJJBI,2612
164
164
  optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
@@ -178,7 +178,7 @@ optimum/rbln/utils/model_utils.py,sha256=V2kFpUe2aqVzLwbpztD8JOVFQqRHncvIWwJbgnU
178
178
  optimum/rbln/utils/runtime_utils.py,sha256=LoKNK3AQNV_BSScstIZWjICkJf265MnUgy360BOocVI,5454
179
179
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
180
180
  optimum/rbln/utils/submodule.py,sha256=ZfI7e3YzbjbbBW4Yjfucj8NygEsukfIkaJi3PtwHrhc,5105
181
- optimum_rbln-0.7.5rc0.dist-info/METADATA,sha256=aXeccsNinGR5xXxBOKIMhxfeHyF-wQE5DxWtOrt2WyI,5300
182
- optimum_rbln-0.7.5rc0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
183
- optimum_rbln-0.7.5rc0.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
184
- optimum_rbln-0.7.5rc0.dist-info/RECORD,,
181
+ optimum_rbln-0.7.5rc1.dist-info/METADATA,sha256=RUPCGW8cEzu6extEsTB9xYDgOb8hAqgEKG0tG3K5feA,5300
182
+ optimum_rbln-0.7.5rc1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
183
+ optimum_rbln-0.7.5rc1.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
184
+ optimum_rbln-0.7.5rc1.dist-info/RECORD,,