optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. optimum/rbln/__init__.py +12 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -6
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +12 -8
  38. optimum/rbln/transformers/configuration_generic.py +0 -27
  39. optimum/rbln/transformers/modeling_attention_utils.py +242 -109
  40. optimum/rbln/transformers/modeling_generic.py +2 -61
  41. optimum/rbln/transformers/modeling_outputs.py +1 -0
  42. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  43. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  44. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  45. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  46. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  47. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  48. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  49. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  50. optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
  51. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  52. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
  53. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  54. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  55. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
  56. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  57. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
  59. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  60. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  61. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  62. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  63. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  64. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  67. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  68. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
  69. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  70. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  71. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  72. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  73. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  74. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  75. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  76. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
  77. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
  78. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  79. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  80. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  81. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  82. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  83. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  84. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  85. optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
  86. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  87. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  88. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  89. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  90. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  91. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  92. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  93. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  94. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  95. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  96. optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
  97. optimum/rbln/utils/deprecation.py +213 -0
  98. optimum/rbln/utils/hub.py +14 -3
  99. optimum/rbln/utils/import_utils.py +7 -1
  100. optimum/rbln/utils/runtime_utils.py +32 -0
  101. optimum/rbln/utils/submodule.py +3 -1
  102. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
  103. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
  104. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
  105. optimum/rbln/utils/depreacate_utils.py +0 -16
  106. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  107. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py CHANGED
@@ -13,13 +13,21 @@
 # limitations under the License.
 
 
+from typing import TYPE_CHECKING, Optional, Union
+
 import torch
-from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
+from transformers import AutoModelForCTC, Wav2Vec2Config, Wav2Vec2ForCTC
+from transformers.modeling_outputs import CausalLMOutput
 
-from ...modeling_generic import RBLNModelForMaskedLM
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
 from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
 
 
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
 class _Wav2Vec2(torch.nn.Module):
     def __init__(self, model: "Wav2Vec2ForCTC"):
         super().__init__()
@@ -30,13 +38,10 @@ class _Wav2Vec2(torch.nn.Module):
         return self.model.lm_head(output[0])
 
 
-class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
+class RBLNWav2Vec2ForCTC(RBLNModel):
     """
     Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
 
-    This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
-    library implements for all its model.
-
     It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
 
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
@@ -44,9 +49,56 @@ class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
     """
 
     main_input_name = "input_values"
-    auto_model_class = AutoModelForMaskedLM
+    auto_model_class = AutoModelForCTC
     rbln_dtype = "float32"
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
         return _Wav2Vec2(model).eval()
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "Wav2Vec2Config" = None,
+        rbln_config: Optional[RBLNWav2Vec2ForCTCConfig] = None,
+    ) -> RBLNWav2Vec2ForCTCConfig:
+        if rbln_config.max_seq_len is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    rbln_config.max_seq_len = tokenizer.model_max_length
+                    break
+            if rbln_config.max_seq_len is None:
+                raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        rbln_compile_config = RBLNCompileConfig(
+            input_info=[
+                (
+                    "input_values",
+                    [
+                        rbln_config.batch_size,
+                        rbln_config.max_seq_len,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self, input_values: torch.Tensor, return_dict: Optional[bool] = None, **kwargs
+    ) -> Union[CausalLMOutput, tuple]:
+        """
+        Forward pass for the RBLN-optimized Wav2Vec2 model for Connectionist Temporal Classification (CTC).
+
+        Args:
+            input_values (torch.FloatTensor of shape (batch_size, sequence_length)): Float values of the raw input speech waveform. Values can be obtained by loading a .flac or .wav audio file into an array of type List[float] or numpy.ndarray, e.g. via the soundfile library (pip install soundfile). To prepare the array into input_values, the AutoProcessor should be used for padding and conversion into a tensor of type torch.FloatTensor.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CausalLMOutput object.
+        """
+        return super().forward(input_values=input_values, return_dict=return_dict, **kwargs)
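A minimal usage sketch of the class above, assuming the `rbln_`-prefixed keyword convention referenced by the `_update_rbln_config` error message; the checkpoint id, batch size, and padded waveform length are placeholder values, not taken from the diff:

    # Hedged sketch: compiling and running RBLNWav2Vec2ForCTC.
    import torch
    from optimum.rbln import RBLNWav2Vec2ForCTC

    model = RBLNWav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-base-960h",  # example checkpoint
        export=True,                    # compile from the original HF weights
        rbln_batch_size=1,
        rbln_max_seq_len=160000,        # 10 s of 16 kHz audio; inputs are padded to this length
    )

    input_values = torch.zeros(1, 160000, dtype=torch.float32)  # placeholder waveform
    logits = model(input_values).logits  # CausalLMOutput, per the new forward() docstring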
optimum/rbln/transformers/models/whisper/generation_whisper.py CHANGED
@@ -31,29 +31,63 @@ Generation utilities for Whisper.
 Modified from `transformers.models.whisper.generation_whisper.py`
 """
 
+from typing import Any, Dict, Optional, Union
+
 import torch
 import transformers
 from packaging import version
 from transformers import GenerationMixin
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.modeling_outputs import ModelOutput
 from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
 
 
 class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
-    def generate(self, *args, generation_config=None, **kwargs):
-        num_beams = kwargs.get(
-            "num_beams",
-            generation_config.num_beams
-            if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
-            else 1,
-        )
-        if num_beams > 1:
-            raise ValueError(
-                f"Beam search is not supported in RBLNWhisperGenerationMixin. "
-                f"Received num_beams={num_beams}, but only num_beams=1 is allowed. "
-                f"Please set num_beams=1 for greedy search or adjust your configuration."
-            )
+    def generate(
+        self,
+        input_features: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        return_segments: Optional[bool] = None,
+        return_timestamps: Optional[bool] = None,
+        return_token_timestamps: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
+        """
+        The generate function is used in its standard form, as in the HuggingFace transformers library. Users can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
+
+        Args:
+            input_features (torch.Tensor, optional): The input features to the model.
+            attention_mask (torch.Tensor, optional): Attention mask that must be passed when doing long-form transcription with a batch size > 1.
+            generation_config (GenerationConfig, optional): The generation configuration used as the base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, with the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)'s default values.
+            return_segments (bool, optional): Whether to return segments.
+            return_timestamps (bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
+            return_token_timestamps (bool, optional): Whether to return token timestamps.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function.
+
+        Returns:
+            Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
+        """
+        num_beams = kwargs.get("num_beams")
+        if num_beams is not None and num_beams != 1:
+            raise ValueError(
+                f"Beam search is not supported in RBLNWhisperGenerationMixin. "
+                f"Received num_beams={num_beams}, but only num_beams=1 is allowed. "
+                "Please set num_beams=1 for greedy search or adjust your configuration."
+            )
 
-        return super().generate(*args, **kwargs)
+        return super().generate(
+            input_features,
+            attention_mask=attention_mask,
+            generation_config=generation_config,
+            return_segments=return_segments,
+            return_timestamps=return_timestamps,
+            return_token_timestamps=return_token_timestamps,
+            **kwargs,
+        )
 
 
     def _postprocess_outputs(
         self,
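A hedged sketch of calling the new keyword-only generate() signature; the checkpoint id and silent audio are placeholders, and only num_beams=1 is accepted:

    # Sketch: greedy decoding with the RBLN Whisper generate() shown above.
    import numpy as np
    from transformers import WhisperProcessor
    from optimum.rbln import RBLNWhisperForConditionalGeneration

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = RBLNWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", export=True)

    audio = np.zeros(16000 * 5, dtype=np.float32)  # placeholder: 5 s of silence at 16 kHz
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    tokens = model.generate(inputs.input_features)  # num_beams > 1 would raise ValueError
    print(processor.batch_decode(tokens, skip_special_tokens=True))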
optimum/rbln/transformers/models/whisper/modeling_whisper.py CHANGED
@@ -203,7 +203,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         raise NotImplementedError
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
         return WhisperWrapper(
             model,
             use_attention_mask=rbln_config.use_attention_mask,
@@ -213,7 +213,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
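Note that `wrap_model_if_needed` is renamed to the underscore-prefixed `_wrap_model_if_needed` here and in the Wav2Vec2 hunk above, marking it as a private compile hook. A sketch of what that means for downstream subclasses (the subclass and wrapper names are hypothetical):

    # Sketch: subclasses overriding the old public name would silently stop being
    # called by get_compiled_model; the private hook must be overridden instead.
    import torch

    class MyWrapper(torch.nn.Module):  # hypothetical compile-time wrapper
        def __init__(self, model):
            super().__init__()
            self.model = model

    class RBLNMyWhisperVariant(RBLNWhisperForConditionalGeneration):
        @classmethod
        def _wrap_model_if_needed(cls, model, rbln_config):
            return MyWrapper(model).eval()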
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py CHANGED
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
+
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, SequenceClassifierOutput
+
 from ...modeling_generic import RBLNModelForSequenceClassification, RBLNTransformerEncoderForFeatureExtraction
 
 
@@ -20,6 +25,30 @@ class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
     XLM-RoBERTa base model optimized for RBLN NPU.
     """
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple]:
+        """
+        Forward pass for the RBLN-optimized XLM-RoBERTa base model.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate different portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPoolingAndCrossAttentions object.
+        """
+
+        if token_type_ids is not None:
+            kwargs.setdefault("token_type_ids", token_type_ids)
+
+        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
 
 class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
@@ -27,3 +56,27 @@ class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[SequenceClassifierOutput, tuple]:
+        """
+        Forward pass for the RBLN-optimized XLM-RoBERTa model for sequence classification.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
+        """
+
+        if token_type_ids is not None:
+            kwargs.setdefault("token_type_ids", token_type_ids)
+
+        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
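A hedged usage sketch of the newly documented forward(); the checkpoint id and rbln_max_seq_len are example values, and token_type_ids is omitted since rbln_model_input_names lists only input_ids and attention_mask:

    # Sketch: calling the documented forward() of the sequence-classification class.
    from transformers import AutoTokenizer
    from optimum.rbln import RBLNXLMRobertaForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    model = RBLNXLMRobertaForSequenceClassification.from_pretrained(
        "xlm-roberta-base", export=True, rbln_max_seq_len=512
    )

    enc = tokenizer("Hello world", padding="max_length", max_length=512, return_tensors="pt")
    out = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask)
    print(out.logits.shape)  # SequenceClassifierOutput, per the forward() docstring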
optimum/rbln/transformers/utils/rbln_quantization.py CHANGED
@@ -123,6 +123,15 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
         if self.RBLN_QUANT_BITS_ENV in os.environ:
             os.environ.pop(self.RBLN_QUANT_BITS_ENV)
 
+    @property
+    def nbits_per_param(self) -> int:
+        if self.weights in ["int4", "fp4"]:
+            return 4
+        elif self.weights in ["int8", "fp8"]:
+            return 8
+        else:
+            raise ValueError(f"Invalid weights: {self.weights}")
+
 
 class QuantizedLayerFactory:
     def __init__(self, quantization_config: RBLNQuantizationConfig):
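The new nbits_per_param property maps the weight dtype to its bit width. One obvious use is a back-of-envelope weight-memory estimate; the helper and the 7B parameter count below are illustrative, not part of the diff:

    # Sketch: weight-memory estimate built on nbits_per_param.
    def estimate_weight_bytes(num_params: int, nbits_per_param: int) -> int:
        return num_params * nbits_per_param // 8

    print(estimate_weight_bytes(7_000_000_000, 4) / 2**30)  # int4/fp4 -> ~3.26 GiB
    print(estimate_weight_bytes(7_000_000_000, 8) / 2**30)  # int8/fp8 -> ~6.52 GiB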
optimum/rbln/utils/deprecation.py ADDED
@@ -0,0 +1,213 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# **********************************************************************************
+# * NOTE: This file has been modified from its original version in                 *
+# * the Hugging Face transformers library.                                         *
+# * Original source:                                                               *
+# * https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/utils/deprecation.py
+# **********************************************************************************
+
+import inspect
+from enum import Enum
+from functools import wraps
+from typing import Callable, Optional
+
+import packaging.version
+
+from ..__version__ import __version__
+from .logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def warn_deprecated_npu(npu: Optional[str] = None):
+    import rebel
+
+    npu = npu or rebel.get_npu_name()
+    if npu == "RBLN-CA02":
+        logger.warning_once(
+            "Support for the RBLN-CA02 device is provided only up to optimum-rbln v0.8.0 and has reached end of life.",
+        )
+
+
+class Action(Enum):
+    NONE = "none"
+    NOTIFY = "notify"
+    RAISE = "raise"
+
+
+# Scenario table for the deprecation strategy (example).
+# Assume the current version is v0.9.6 and the deprecation version is v0.10.0.
+# | Type                | v0.9.6 (as_is) | v0.9.6 (to_be) | v0.9.6 Patch                   | v0.9.7 Action                                                   | v0.10.0+ Action                                      |
+# |---------------------|----------------|----------------|--------------------------------|-----------------------------------------------------------------|------------------------------------------------------|
+# | Modify (key name)   | a: bool        | a': bool       | Add a', keep a                 | 1. Only 'a' provided: replace a -> a' & future warning          | In v0.10.0, raise error once, then remove decorator. |
+# |                     |                |                |                                | 2. Both 'a' & 'a'' provided: ignore 'a' value & future warning  |                                                      |
+# | Modify (value type) | b: bool        | b: str         | b: Union[str, bool]            | 'bool' value provided for 'b': replace with the corresponding   | In v0.10.0, raise error once, then remove decorator. |
+# |                     |                |                |                                | 'str' & future warning                                          |                                                      |
+# | Deletion            | c              | -              | Delete c or keep c (flexible)  | Ignore c & future warning                                       | In v0.10.0, raise error once, then remove decorator. |
+# | Addition            | -              | d              | Add d, set default value for d | No action needed, as a default value is set                     | Keep default value                                   |
+
+
+def deprecate_kwarg(
+    old_name: str,
+    version: str,
+    new_name: Optional[str] = None,
+    deprecated_type: Optional[type] = None,
+    value_replacer: Optional[Callable] = None,
+    raise_if_greater_or_equal_version: bool = True,
+    raise_if_both_names: bool = False,
+    additional_message: Optional[str] = None,
+):
+    """
+    Function or method decorator to notify users about deprecated keyword arguments, replacing them with a new name if specified,
+    or handling deprecated value types.
+
+    This decorator allows you to:
+    - Notify users when a keyword argument name is deprecated (scenarios 'a', 'c').
+    - Notify users when a specific value type for an argument is deprecated (scenario 'b').
+    - Automatically replace deprecated keyword arguments with new ones.
+    - Automatically replace deprecated values with new ones using a replacer function.
+    - Raise an error if deprecated arguments are used, depending on the specified conditions.
+
+    By default, the decorator notifies the user about the deprecated argument while `optimum.rbln.__version__` is lower than the
+    `version` specified in the decorator; once the current version is greater than or equal to it, a `ValueError` is raised unless
+    `raise_if_greater_or_equal_version=False` is set.
+
+    Parameters:
+        old_name (`str`):
+            Name of the deprecated keyword argument, or of the argument with a deprecated value type.
+        version (`str`):
+            The version in which the keyword argument or value type was (or will be) deprecated.
+        new_name (`Optional[str]`, *optional*):
+            The new name for the deprecated keyword argument. If specified, the deprecated keyword argument will be replaced with this new name (scenario 'a').
+        deprecated_type (`type`, *optional*):
+            The deprecated type for the keyword argument specified by `old_name` (scenario 'b').
+            If this is set, `new_name` should typically be `None`.
+        value_replacer (`Callable`, *optional*):
+            A function that takes the old (deprecated-type) value and returns a new value (scenario 'b').
+            Used in conjunction with `deprecated_type`. If provided, the value will be converted automatically.
+        raise_if_greater_or_equal_version (`bool`, *optional*, defaults to `True`):
+            Whether to raise a `ValueError` if the current `optimum.rbln` version is greater than or equal to the deprecated version.
+        raise_if_both_names (`bool`, *optional*, defaults to `False`):
+            Whether to raise a `ValueError` if both the deprecated and the new keyword arguments are set (only for scenario 'a').
+        additional_message (`Optional[str]`, *optional*):
+            An additional message to append to the default deprecation message.
+
+    Raises:
+        ValueError:
+            If `raise_if_greater_or_equal_version` is True and the current version is greater than or equal to the deprecated version,
+            or if `raise_if_both_names` is True and both the old and the new keyword arguments are provided.
+
+    Returns:
+        Callable:
+            A wrapped function that handles the deprecated keyword arguments according to the specified parameters.
+    """
+
+    deprecated_version = packaging.version.parse(version)
+    current_version = packaging.version.parse(__version__)
+    is_greater_or_equal_version = current_version >= deprecated_version
+
+    if is_greater_or_equal_version:
+        version_message = f"and removed starting from version {version}"
+    else:
+        version_message = f"and will be removed in version {version}"
+
+    def wrapper(func):
+        # Required for a better warning message
+        sig = inspect.signature(func)
+        function_named_args = set(sig.parameters.keys())
+        is_instance_method = "self" in function_named_args
+        is_class_method = "cls" in function_named_args
+
+        @wraps(func)
+        def wrapped_func(*args, **kwargs):
+            # Get class + function name (just for a better warning message)
+            func_name = func.__name__
+            if is_instance_method:
+                func_name = f"{args[0].__class__.__name__}.{func_name}"
+            elif is_class_method:
+                func_name = f"{args[0].__name__}.{func_name}"
+
+            minimum_action = Action.NONE
+            message = None
+
+            # Scenario A: rename (e.g., a -> a')
+            if new_name is not None:
+                if old_name in kwargs and new_name in kwargs:
+                    minimum_action = Action.RAISE if raise_if_both_names else Action.NOTIFY
+                    message = f"Both `{old_name}` and `{new_name}` are set for `{func_name}`. Using `{new_name}={kwargs[new_name]}` and ignoring deprecated `{old_name}={kwargs[old_name]}`."
+                    kwargs.pop(old_name)
+
+                elif old_name in kwargs and new_name not in kwargs:
+                    minimum_action = Action.NOTIFY
+                    message = (
+                        f"`{old_name}` is deprecated {version_message} for `{func_name}`. Use `{new_name}` instead."
+                    )
+                    kwargs[new_name] = kwargs.pop(old_name)
+
+            # Scenario B: value type change (e.g., b: bool -> str)
+            if deprecated_type is not None:
+                key_to_check = old_name if new_name is None else new_name  # For scenarios A + B mixed
+                if key_to_check in kwargs and isinstance(kwargs[key_to_check], deprecated_type):
+                    minimum_action = Action.NOTIFY
+                    old_value = kwargs[key_to_check]
+                    message = f"Using type `{deprecated_type.__name__}` for argument `{key_to_check}` in `{func_name}` is deprecated {version_message}."
+
+                    if value_replacer:
+                        try:
+                            new_value = value_replacer(old_value)
+                            kwargs[key_to_check] = new_value
+                            message += f" Value `{old_value}` has been automatically replaced with `{new_value}`."
+                        except Exception as e:
+                            logger.error(f"Error during deprecated value replacement for {key_to_check}: {e}")
+                            message += f" Automatic replacement failed: {e}. Passing original value."
+                    else:
+                        raise ValueError(
+                            f"value_replacer should be provided when deprecated_type is set for {key_to_check} in {func_name}"
+                        )
+
+            # Scenario C: deletion (e.g., c)
+            if old_name in kwargs and new_name is None and deprecated_type is None:
+                minimum_action = Action.NOTIFY
+                message = f"`{old_name}` is deprecated {version_message} for `{func_name}`."
+                kwargs.pop(old_name)
+
+            if message is not None and additional_message is not None:
+                message = f"{message} {additional_message}"
+
+            # escalate if the argument is ALREADY deprecated (current version >= deprecation version)
+            if is_greater_or_equal_version:
+                # change NOTIFY -> RAISE in case we want to raise an error for already-deprecated arguments
+                if raise_if_greater_or_equal_version and minimum_action != Action.NONE:
+                    minimum_action = Action.RAISE
+
+            # raise an error or notify the user
+            if minimum_action == Action.RAISE:
+                raise ValueError(message)
+            elif minimum_action == Action.NOTIFY:
+                # DeprecationWarning is ignored by default, so log the message as a warning instead
+                logger.warning(message, stacklevel=2)
+
+            return func(*args, **kwargs)
+
+        return wrapped_func
+
+    return wrapper
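A usage sketch of the decorator above, covering the rename scenario ('a' -> 'a'') from the table; the class and argument names are hypothetical:

    # Sketch: renaming a keyword argument with deprecate_kwarg.
    from optimum.rbln.utils.deprecation import deprecate_kwarg

    class ExampleConfig:
        @deprecate_kwarg("max_length", version="0.10.0", new_name="max_seq_len")
        def __init__(self, max_seq_len: int = 512, **kwargs):
            self.max_seq_len = max_seq_len

    cfg = ExampleConfig(max_length=1024)  # warns and forwards the value to max_seq_len
    assert cfg.max_seq_len == 1024        # once optimum-rbln >= 0.10.0, this call raises instead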
optimum/rbln/utils/hub.py CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 from pathlib import Path
 from typing import List, Optional, Union
 
@@ -67,15 +68,25 @@ def validate_files(
     location: str,
 ):
     """Validate the presence and count of required files."""
-    if len(files) == 0:
-        raise FileNotFoundError(f"Could not find any rbln model file in {location}")
-
     if len(config_files) == 0:
         raise FileNotFoundError(f"Could not find `rbln_config.json` file in {location}")
 
     if len(config_files) > 1:
         raise FileExistsError(f"Multiple rbln_config.json files found in {location}. This is not expected.")
 
+    try:
+        with open(config_files[0], "r") as f:
+            config_data = json.load(f)
+        compile_cfgs = config_data.get("_compile_cfgs", [])
+        if len(compile_cfgs) == 0:
+            # If compile_cfgs is empty, we don't need .rbln files
+            return
+    except (json.JSONDecodeError, KeyError, OSError):
+        pass
+
+    if len(files) == 0:
+        raise FileNotFoundError(f"Could not find any rbln model file in {location}")
+
 
 def _get_huggingface_token(token: Union[bool, str]) -> str:
     if isinstance(token, str):
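The relaxed validation means a repository whose rbln_config.json declares an empty `_compile_cfgs` list is now accepted without any .rbln artifacts. A sketch exercising that path (the parameter names follow the hunk above; the temp-directory layout is illustrative):

    # Sketch: empty _compile_cfgs -> validate_files returns without requiring .rbln files.
    import json
    import tempfile
    from pathlib import Path

    from optimum.rbln.utils.hub import validate_files

    with tempfile.TemporaryDirectory() as tmp:
        config_path = Path(tmp) / "rbln_config.json"
        config_path.write_text(json.dumps({"_compile_cfgs": []}))
        validate_files(files=[], config_files=[config_path], location=tmp)  # no longer raises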
optimum/rbln/utils/import_utils.py CHANGED
@@ -142,7 +142,11 @@ def check_version_compats() -> None:
         try:
             dep_version = importlib.metadata.version(compat.package_name)
         except importlib.metadata.PackageNotFoundError:
-            warnings.warn(f"optimum-rbln requires {compat.package_name} to be installed.", ImportWarning)
+            warnings.warn(
+                f"optimum-rbln requires {compat.package_name} to be installed.",
+                ImportWarning,
+                stacklevel=2,
+            )
             continue
         # For versions 0.7.2 and above, don't show warning for rebel-compiler if base versions match
 
@@ -160,6 +164,7 @@ def check_version_compats() -> None:
                     f"For optimal performance and compatibility, please ensure both packages share the same major and minor version numbers. "
                     "Please refer to our SDK release notes at https://docs.rbln.ai/about_atom/release_note.html",
                     ImportWarning,
+                    stacklevel=2,
                 )
         else:
             if not Version(compat.min_version) <= Version(dep_version) < Version(compat.max_version):
@@ -167,4 +172,5 @@ def check_version_compats() -> None:
                     f"optimum-rbln v{my_version} is compatible to {compat.package_name} v{compat.min_version} to v{compat.max_version}. (you are currently using v{dep_version})\n"
                     "Please refer to our SDK release notes at https://docs.rbln.ai/about_atom/release_note.html",
                     ImportWarning,
+                    stacklevel=2,
                 )
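The compatibility check uses a half-open version interval, as in the last context line above; a small worked example of the predicate (the version strings are examples):

    # Sketch of the half-open range check: [min_version, max_version)
    from packaging.version import Version

    min_version, max_version, dep_version = "0.35.0", "0.36.0", "0.35.2"
    ok = Version(min_version) <= Version(dep_version) < Version(max_version)
    print(ok)  # True: 0.35.2 lies inside [0.35.0, 0.36.0)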
optimum/rbln/utils/runtime_utils.py CHANGED
@@ -20,6 +20,38 @@ import rebel
 import torch
 
 
+def get_available_dram(npu: Optional[str] = None) -> int:
+    """
+    Get the available DRAM size of the specified NPU.
+
+    Args:
+        npu: Optional[str], default=None
+            The NPU to get the available DRAM size for.
+            If None, the function will attempt to retrieve it through `ensure_valid_npu()`.
+
+    Returns:
+        int
+            The available DRAM size in bytes.
+    """
+    if npu is None:
+        if not rebel.npu_is_available(0):
+            raise RuntimeError("No NPU is available to get available DRAM size.")
+
+        npu = rebel.get_npu_name(0)
+
+    if npu.startswith("RBLN-CR"):
+        # TODO(jongho): Assuming 4 chiplets.
+        DRAM_NBYTES = 144 * 2**30
+        SYS_DRAM_NBYTES = 4 * 2**30
+    elif npu.startswith("RBLN-CA"):
+        DRAM_NBYTES = 16 * 2**30
+        SYS_DRAM_NBYTES = 288 * 2**20
+    else:
+        raise ValueError(f"Unknown npu name: {npu}")
+
+    return DRAM_NBYTES - SYS_DRAM_NBYTES
+
+
 def normalize_npu(npu: str) -> str:
     """Normalize the NPU string by removing the form factor."""
     match = re.match(r"(RBLN-CA|RBLN-CR)(\d+)", npu)
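Working through the constants in get_available_dram (values taken from the hunk above):

    # RBLN-CA: 16 GiB device DRAM minus 288 MiB reserved for the system.
    ca_available = 16 * 2**30 - 288 * 2**20
    print(ca_available)  # 16_877_879_296 bytes (~15.72 GiB)

    # RBLN-CR: 144 GiB (assuming 4 chiplets, per the TODO) minus 4 GiB system DRAM.
    cr_available = 144 * 2**30 - 4 * 2**30
    print(cr_available)  # 150_323_855_360 bytes (140 GiB)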
optimum/rbln/utils/submodule.py CHANGED
@@ -37,7 +37,9 @@ class SubModulesMixin:
 
     _rbln_submodules: List[Dict[str, Any]] = []
 
-    def __init__(self, *, rbln_submodules: List["RBLNModel"] = [], **kwargs) -> None:
+    def __init__(self, *, rbln_submodules: Optional[List["RBLNModel"]] = None, **kwargs) -> None:
+        if rbln_submodules is None:
+            rbln_submodules = []
         for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
             setattr(self, submodule_meta["name"], submodule)
{optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.9.3rc0
+Version: 0.9.4a2
 Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
@@ -24,7 +24,7 @@ Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: <3.14,>=3.9
 Requires-Dist: accelerate>=1.0.1
-Requires-Dist: diffusers==0.35.1
+Requires-Dist: diffusers==0.35.2
 Requires-Dist: packaging>=24.1
 Requires-Dist: torch==2.8.0
 Requires-Dist: torchaudio<=2.8.0