optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -12,6 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from typing import Optional, Union
16
+
17
+ import torch
18
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, SequenceClassifierOutput
19
+
15
20
  from ...modeling_generic import RBLNModelForSequenceClassification, RBLNTransformerEncoderForFeatureExtraction
16
21
 
17
22
 
@@ -20,6 +25,30 @@ class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
20
25
  XLM-RoBERTa base model optimized for RBLN NPU.
21
26
  """
22
27
 
28
+ def forward(
29
+ self,
30
+ input_ids: Optional[torch.Tensor] = None,
31
+ attention_mask: Optional[torch.Tensor] = None,
32
+ token_type_ids: Optional[torch.Tensor] = None,
33
+ **kwargs,
34
+ ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple]:
35
+ """
36
+ Forward pass for the RBLN-optimized XLM-RoBERTa base model.
37
+
38
+ Args:
39
+ input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
40
+ attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
41
+ token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate different portions of the inputs.
42
+
43
+ Returns:
44
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPoolingAndCrossAttentions object.
45
+ """
46
+
47
+ if token_type_ids is not None:
48
+ kwargs.setdefault("token_type_ids", token_type_ids)
49
+
50
+ return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
51
+
23
52
 
24
53
  class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
25
54
  """
@@ -27,3 +56,27 @@ class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification
27
56
  """
28
57
 
29
58
  rbln_model_input_names = ["input_ids", "attention_mask"]
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: Optional[torch.LongTensor] = None,
63
+ attention_mask: Optional[torch.FloatTensor] = None,
64
+ token_type_ids: Optional[torch.LongTensor] = None,
65
+ **kwargs,
66
+ ) -> Union[SequenceClassifierOutput, tuple]:
67
+ """
68
+ Forward pass for the RBLN-optimized XLM-RoBERTa model for sequence classification.
69
+
70
+ Args:
71
+ input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
72
+ attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
73
+ token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
74
+
75
+ Returns:
76
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
77
+ """
78
+
79
+ if token_type_ids is not None:
80
+ kwargs.setdefault("token_type_ids", token_type_ids)
81
+
82
+ return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
@@ -123,6 +123,15 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
123
123
  if self.RBLN_QUANT_BITS_ENV in os.environ:
124
124
  os.environ.pop(self.RBLN_QUANT_BITS_ENV)
125
125
 
126
+ @property
127
+ def nbits_per_param(self) -> int:
128
+ if self.weights in ["int4", "fp4"]:
129
+ return 4
130
+ elif self.weights in ["int8", "fp8"]:
131
+ return 8
132
+ else:
133
+ raise ValueError(f"Invalid weights: {self.weights}")
134
+
126
135
 
127
136
  class QuantizedLayerFactory:
128
137
  def __init__(self, quantization_config: RBLNQuantizationConfig):
@@ -212,11 +221,12 @@ def load_weight_files(
212
221
  cache_dir: Optional[str] = None,
213
222
  force_download: bool = False,
214
223
  local_files_only: bool = False,
224
+ exception_keywords: Optional[List[str]] = None,
215
225
  ) -> list[str]:
216
226
  """
217
227
  Discover and download safetensors files for the given model id.
218
228
  """
219
-
229
+ exception_keywords = exception_keywords or []
220
230
  if os.path.isdir(model_id):
221
231
  safetensor_files = glob.glob(f"{model_id}/*.safetensors")
222
232
  else:
@@ -228,17 +238,24 @@ def load_weight_files(
228
238
 
229
239
  for file in repo_files:
230
240
  if file.endswith(".safetensors"):
231
- # Download the safetensors file
232
- downloaded_file = hf_hub_download(
233
- repo_id=model_id,
234
- filename=file,
235
- revision=revision,
236
- token=use_auth_token,
237
- cache_dir=cache_dir,
238
- force_download=force_download,
239
- local_files_only=local_files_only,
240
- )
241
- safetensor_files.append(downloaded_file)
241
+ exculde = False
242
+ for except_key in exception_keywords:
243
+ if except_key in file:
244
+ exculde = True
245
+ break
246
+
247
+ if not exculde:
248
+ # Download the safetensors file
249
+ downloaded_file = hf_hub_download(
250
+ repo_id=model_id,
251
+ filename=file,
252
+ revision=revision,
253
+ token=use_auth_token,
254
+ cache_dir=cache_dir,
255
+ force_download=force_download,
256
+ local_files_only=local_files_only,
257
+ )
258
+ safetensor_files.append(downloaded_file)
242
259
  except Exception as e:
243
260
  logger.error(f"Failed to download safetensors files from Hugging Face Hub: {e}")
244
261
  raise e
@@ -0,0 +1,213 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ # Copyright 2025 Rebellions Inc. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # **********************************************************************************
17
+ # * NOTE: This file has been modified from its original version in *
18
+ # * the Hugging Face transformers library. *
19
+ # * Original source: *
20
+ # * https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/utils/deprecation.py
21
+ # **********************************************************************************
22
+
23
+ import inspect
24
+ from enum import Enum
25
+ from functools import wraps
26
+ from typing import Callable, Optional
27
+
28
+ import packaging.version
29
+
30
+ from ..__version__ import __version__
31
+ from .logging import get_logger
32
+
33
+
34
+ logger = get_logger(__name__)
35
+
36
+
37
def warn_deprecated_npu(npu: Optional[str] = None):
    """Emit a one-time warning when the target NPU is an end-of-life device.

    Args:
        npu: NPU name to check; when falsy, the name is queried from the
            installed `rebel` runtime.
    """
    import rebel

    if not npu:
        npu = rebel.get_npu_name()

    # RBLN-CA02 support ended with optimum-rbln v0.8.0.
    if npu == "RBLN-CA02":
        logger.warning_once(
            "Support for the RBLN-CA02 device is provided only up to optimum-rbln v0.8.0 and has reached end of life.",
        )
45
+
46
+
47
class Action(Enum):
    """Escalation level applied when a deprecated argument is encountered."""

    # NONE: nothing detected; NOTIFY: log a warning; RAISE: raise ValueError.
    NONE = "none"
    NOTIFY = "notify"
    RAISE = "raise"
51
+
52
+
53
+ # Scenario Table for Deprecation Strategy Example
54
+ # Assume that current version is v0.9.6 and the deprecated version is v0.10.0
55
+ # |--------------------|----------------|----------------|---------------------------------------------|--------------------------------------------------------------------------------------|----------------------------------------------------------------------|
56
+ # | Type | v0.9.6 (as_is) | v0.9.6 (to_be) | v0.9.6 Patch | v0.9.7 Action | v0.10.0+ Action |
57
+ # |--------------------|----------------|----------------|---------------------------------------------|--------------------------------------------------------------------------------------|----------------------------------------------------------------------|
58
+ # | Modify (Key Name) | a: bool | a': bool | Add a', Keep a | 1. Only 'a' provided: replace a -> a' & future warning | In v0.10.0, raise error once, then remove decorator. |
59
+ # | | | | | 2. Both 'a' & 'a'' provided: ignore 'a' value & future warning | |
60
+ # |--------------------|----------------|----------------|---------------------------------------------|--------------------------------------------------------------------------------------|----------------------------------------------------------------------|
61
+ # | Modify (Value Type)| b: bool | b: str | b: Union[str, bool] | 'bool' value provided for 'b': replace with corresponding 'str' & future warning | In v0.10.0, raise error once, then remove decorator. |
62
+ # | | | | | | |
63
+ # |--------------------|----------------|----------------|---------------------------------------------|--------------------------------------------------------------------------------------|----------------------------------------------------------------------|
64
+ # | Deletion | c | - | Delete c or Keep c (flexible) | ignore c & future warning | In v0.10.0, raise error once, then remove decorator. |
65
+ # | | | | | | |
66
+ # |--------------------|----------------|----------------|---------------------------------------------|--------------------------------------------------------------------------------------|----------------------------------------------------------------------|
67
+ # | Addition | - | d | Add d, set default_value for d | No action needed as default value is set | Keep default value |
68
+ # |--------------------|----------------|----------------|---------------------------------------------|--------------------------------------------------------------------------------------|----------------------------------------------------------------------|
69
+
70
+
71
def deprecate_kwarg(
    old_name: str,
    version: str,
    new_name: Optional[str] = None,
    deprecated_type: Optional[type] = None,
    value_replacer: Optional[Callable] = None,
    raise_if_greater_or_equal_version: bool = True,
    raise_if_both_names: bool = False,
    additional_message: Optional[str] = None,
):
    """
    Function or method decorator to notify users about deprecated keyword arguments, replacing them with a new name if specified,
    or handling deprecated value types.

    This decorator allows you to:
    - Notify users when a keyword argument name is deprecated (Scenario 'a', 'c').
    - Notify users when a specific value type for an argument is deprecated (Scenario 'b').
    - Automatically replace deprecated keyword arguments with new ones.
    - Automatically replace deprecated values with new ones using a replacer function.
    - Raise an error if deprecated arguments are used, depending on the specified conditions.

    By default, the decorator notifies the user about a deprecated argument while `optimum.rbln.__version__` < the specified
    `version`, and raises a `ValueError` for it once the current version reaches `version`
    (controlled by `raise_if_greater_or_equal_version`).

    Parameters:
        old_name (`str`):
            Name of the deprecated keyword argument, or the argument with a deprecated value type.
        version (`str`):
            The version in which the keyword argument or value type was (or will be) deprecated.
        new_name (`Optional[str]`, *optional*):
            The new name for the deprecated keyword argument. If specified, the deprecated keyword argument will be replaced with this new name (Scenario 'a').
        deprecated_type (`type`, *optional*):
            The deprecated type for the keyword argument specified by `old_name` (Scenario 'b').
            If this is set, `new_name` should typically be `None`.
        value_replacer (`Callable`, *optional*):
            A function that takes the old (deprecated type) value and returns a new value (Scenario 'b').
            Used in conjunction with `deprecated_type`. If provided, the value will be automatically converted.
        raise_if_greater_or_equal_version (`bool`, *optional*, defaults to `True`):
            Whether to raise `ValueError` when a deprecated argument is used and the current `optimum.rbln` version is
            greater or equal to the deprecated version.
        raise_if_both_names (`bool`, *optional*, defaults to `False`):
            Whether to raise `ValueError` if both deprecated and new keyword arguments are set (only for Scenario 'a').
        additional_message (`Optional[str]`, *optional*):
            An additional message to append to the default deprecation message.

    Raises:
        ValueError:
            If a deprecated argument is used while raise_if_greater_or_equal_version is True and the current version is
            greater than or equal to the deprecated version, or if raise_if_both_names is True and both old and new
            keyword arguments are provided.

    Returns:
        Callable:
            A wrapped function that handles the deprecated keyword arguments according to the specified parameters.
    """

    deprecated_version = packaging.version.parse(version)
    current_version = packaging.version.parse(__version__)
    is_greater_or_equal_version = current_version >= deprecated_version

    if is_greater_or_equal_version:
        version_message = f"and removed starting from version {version}"
    else:
        version_message = f"and will be removed in version {version}"

    def wrapper(func):
        # Inspected once at decoration time, only to build a nicer
        # `Class.method` name for warning/error messages.
        sig = inspect.signature(func)
        function_named_args = set(sig.parameters.keys())
        is_instance_method = "self" in function_named_args
        is_class_method = "cls" in function_named_args

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # Get class + function name (just for better warning message)
            func_name = func.__name__
            if is_instance_method:
                func_name = f"{args[0].__class__.__name__}.{func_name}"
            elif is_class_method:
                func_name = f"{args[0].__name__}.{func_name}"

            minimum_action = Action.NONE
            message = None

            # Scenario A: Rename (e.g., a -> a')
            if new_name is not None:
                if old_name in kwargs and new_name in kwargs:
                    minimum_action = Action.RAISE if raise_if_both_names else Action.NOTIFY
                    message = f"Both `{old_name}` and `{new_name}` are set for `{func_name}`. Using `{new_name}={kwargs[new_name]}` and ignoring deprecated `{old_name}={kwargs[old_name]}`."
                    kwargs.pop(old_name)

                elif old_name in kwargs and new_name not in kwargs:
                    minimum_action = Action.NOTIFY
                    message = (
                        f"`{old_name}` is deprecated {version_message} for `{func_name}`. Use `{new_name}` instead."
                    )
                    kwargs[new_name] = kwargs.pop(old_name)

            # Scenario B: Value Type Change (e.g., b: bool -> str)
            if deprecated_type is not None:
                key_to_check = old_name if new_name is None else new_name  # For Scenario A + B Mixed
                if key_to_check in kwargs and isinstance(kwargs[key_to_check], deprecated_type):
                    minimum_action = Action.NOTIFY
                    old_value = kwargs[key_to_check]
                    message = f"Using type `{deprecated_type.__name__}` for argument `{key_to_check}` in `{func_name}` is deprecated {version_message}."

                    if value_replacer:
                        try:
                            new_value = value_replacer(old_value)
                            kwargs[key_to_check] = new_value
                            message += f" Value `{old_value}` has been automatically replaced with `{new_value}`."
                        except Exception as e:
                            logger.error(f"Error during deprecated value replacement for {key_to_check}: {e}")
                            message += f" Automatic replacement failed: {e}. Passing original value."
                    else:
                        raise ValueError(
                            f"value_replacer should be provided when deprecated_type is set for {key_to_check} in {func_name}"
                        )

            # Scenario C: Deletion (e.g., c)
            if old_name in kwargs and new_name is None and deprecated_type is None:
                minimum_action = Action.NOTIFY
                message = f"`{old_name}` is deprecated {version_message} for `{func_name}`."
                kwargs.pop(old_name)

            if message is not None and additional_message is not None:
                message = f"{message} {additional_message}"

            # Escalate NOTIFY -> RAISE once the argument is ALREADY deprecated
            # (current version >= deprecated version).
            # BUG FIX: the escalation must be gated on `minimum_action is
            # Action.NOTIFY`. Previously it was unconditional, so after the
            # deprecation version was reached EVERY call to the wrapped
            # function raised `ValueError(None)`, even when no deprecated
            # argument was supplied at all.
            if (
                minimum_action is Action.NOTIFY
                and is_greater_or_equal_version
                and raise_if_greater_or_equal_version
            ):
                minimum_action = Action.RAISE

            # raise error or notify user
            if minimum_action == Action.RAISE:
                raise ValueError(message)
            elif minimum_action == Action.NOTIFY:
                # DeprecationWarning is ignored by default, so we use FutureWarning instead
                logger.warning(message, stacklevel=2)

            return func(*args, **kwargs)

        return wrapped_func

    return wrapper
optimum/rbln/utils/hub.py CHANGED
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import json
15
16
  from pathlib import Path
16
17
  from typing import List, Optional, Union
17
18
 
@@ -67,15 +68,25 @@ def validate_files(
67
68
  location: str,
68
69
  ):
69
70
  """Validate the presence and count of required files."""
70
- if len(files) == 0:
71
- raise FileNotFoundError(f"Could not find any rbln model file in {location}")
72
-
73
71
  if len(config_files) == 0:
74
72
  raise FileNotFoundError(f"Could not find `rbln_config.json` file in {location}")
75
73
 
76
74
  if len(config_files) > 1:
77
75
  raise FileExistsError(f"Multiple rbln_config.json files found in {location}. This is not expected.")
78
76
 
77
+ try:
78
+ with open(config_files[0], "r") as f:
79
+ config_data = json.load(f)
80
+ compile_cfgs = config_data.get("_compile_cfgs", [])
81
+ if len(compile_cfgs) == 0:
82
+ # If compile_cfgs is empty, we don't need .rbln files
83
+ return
84
+ except (json.JSONDecodeError, KeyError, OSError):
85
+ pass
86
+
87
+ if len(files) == 0:
88
+ raise FileNotFoundError(f"Could not find any rbln model file in {location}")
89
+
79
90
 
80
91
  def _get_huggingface_token(token: Union[bool, str]) -> str:
81
92
  if isinstance(token, str):
@@ -136,13 +136,32 @@ def is_rbln_available() -> bool:
136
136
 
137
137
  def check_version_compats() -> None:
138
138
  warnings.filterwarnings(action="always", category=ImportWarning, module="optimum.rbln")
139
- my_version = importlib.metadata.version("optimum-rbln")
139
+ try:
140
+ my_version = importlib.metadata.version("optimum-rbln")
141
+ except importlib.metadata.PackageNotFoundError:
142
+ # Common dev case: running from source (e.g. PYTHONPATH=src) without installing the package.
143
+ # package metadata doesn't exist, so fall back to the in-repo version file.
144
+ try:
145
+ from optimum.rbln.__version__ import __version__ as my_version # type: ignore
146
+ except Exception:
147
+ warnings.warn(
148
+ "Could not determine optimum-rbln version (package metadata missing). "
149
+ "If you are running from source, consider `pip install -e .` to install metadata.",
150
+ ImportWarning,
151
+ stacklevel=2,
152
+ )
153
+ return
154
+
140
155
  target_version = list(filter(lambda v: Version(my_version) >= Version(v), RBLN_VERSION_COMPATS.keys()))[0]
141
156
  for compat in RBLN_VERSION_COMPATS[target_version]:
142
157
  try:
143
158
  dep_version = importlib.metadata.version(compat.package_name)
144
159
  except importlib.metadata.PackageNotFoundError:
145
- warnings.warn(f"optimum-rbln requires {compat.package_name} to be installed.", ImportWarning)
160
+ warnings.warn(
161
+ f"optimum-rbln requires {compat.package_name} to be installed.",
162
+ ImportWarning,
163
+ stacklevel=2,
164
+ )
146
165
  continue
147
166
  # For versions 0.7.2 and above, don't show warning for rebel-compiler if base versions match
148
167
 
@@ -160,6 +179,7 @@ def check_version_compats() -> None:
160
179
  f"For optimal performance and compatibility, please ensure both packages share the same major and minor version numbers. "
161
180
  "Please refer to our SDK release notes at https://docs.rbln.ai/about_atom/release_note.html",
162
181
  ImportWarning,
182
+ stacklevel=2,
163
183
  )
164
184
  else:
165
185
  if not Version(compat.min_version) <= Version(dep_version) < Version(compat.max_version):
@@ -167,4 +187,5 @@ def check_version_compats() -> None:
167
187
  f"optimum-rbln v{my_version} is compatible to {compat.package_name} v{compat.min_version} to v{compat.max_version}. (you are currently using v{dep_version})\n"
168
188
  "Please refer to our SDK release notes at https://docs.rbln.ai/about_atom/release_note.html",
169
189
  ImportWarning,
190
+ stacklevel=2,
170
191
  )
@@ -20,6 +20,42 @@ import rebel
20
20
  import torch
21
21
 
22
22
 
23
+ def is_compiler_supports_buffer_resize() -> bool:
24
+ return hasattr(rebel.RBLNCompiledModel, "exp_multiply_buffer_size")
25
+
26
+
27
+ def get_available_dram(npu: Optional[str] = None) -> int:
28
+ """
29
+ Get the available DRAM size of the specified NPU.
30
+
31
+ Args:
32
+ npu : Optional[str], default=None
33
+ The NPU to get the available DRAM size.
34
+ If None, the function will attempt to retrieve through `ensure_valid_npu()`
35
+
36
+ Returns:
37
+ int
38
+ The available DRAM size in bytes.
39
+ """
40
+ if npu is None:
41
+ if not rebel.npu_is_available(0):
42
+ raise RuntimeError("No NPU is available to get available DRAM size.")
43
+
44
+ npu = rebel.get_npu_name(0)
45
+
46
+ if npu.startswith("RBLN-CR"):
47
+ # TODO(jongho): Assuming 4 chiplets.
48
+ DRAM_NBYTES = 144 * 2**30
49
+ SYS_DRAM_NBYTES = 4 * 2**30
50
+ elif npu.startswith("RBLN-CA"):
51
+ DRAM_NBYTES = 16 * 2**30
52
+ SYS_DRAM_NBYTES = 288 * 2**20
53
+ else:
54
+ raise ValueError(f"Unknown npu name: {npu}")
55
+
56
+ return DRAM_NBYTES - SYS_DRAM_NBYTES
57
+
58
+
23
59
  def normalize_npu(npu: str) -> str:
24
60
  """Normalize the NPU string by removing the form factor."""
25
61
  match = re.match(r"(RBLN-CA|RBLN-CR)(\d+)", npu)
@@ -43,12 +79,6 @@ def tp_and_devices_are_ok(
43
79
  if tensor_parallel_size is None:
44
80
  tensor_parallel_size = 1
45
81
 
46
- if rebel.device_count() < tensor_parallel_size:
47
- return (
48
- f"Tensor parallel size {tensor_parallel_size} is greater than "
49
- f"the number of available devices {rebel.device_count()}."
50
- )
51
-
52
82
  if device is None:
53
83
  device = list(range(tensor_parallel_size))
54
84
  elif isinstance(device, int):
@@ -71,6 +101,12 @@ def tp_and_devices_are_ok(
71
101
  f"Device {device_id} is not a valid NPU device. Please check your NPU status with 'rbln-stat' command."
72
102
  )
73
103
 
104
+ if rebel.device_count() < tensor_parallel_size:
105
+ return (
106
+ f"Tensor parallel size {tensor_parallel_size} is greater than "
107
+ f"the number of available devices {rebel.device_count()}."
108
+ )
109
+
74
110
  if npu is not None:
75
111
  for device_id in device:
76
112
  npu_name = rebel.get_npu_name(device_id)
@@ -37,7 +37,9 @@ class SubModulesMixin:
37
37
 
38
38
  _rbln_submodules: List[Dict[str, Any]] = []
39
39
 
40
- def __init__(self, *, rbln_submodules: List["RBLNModel"] = [], **kwargs) -> None:
40
+ def __init__(self, *, rbln_submodules: Optional[List["RBLNModel"]] = None, **kwargs) -> None:
41
+ if rbln_submodules is None:
42
+ rbln_submodules = []
41
43
  for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
42
44
  setattr(self, submodule_meta["name"], submodule)
43
45
 
@@ -59,12 +61,25 @@ class SubModulesMixin:
59
61
  ):
60
62
  return rbln_config
61
63
 
64
+ @classmethod
65
+ def _update_submodule_rbln_config(
66
+ cls,
67
+ submodule_name: str,
68
+ submodule_cls: Type["RBLNModel"],
69
+ model: "PreTrainedModel",
70
+ submodule_config: PretrainedConfig,
71
+ submodule_rbln_config: RBLNModelConfig,
72
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
73
+ ):
74
+ return submodule_rbln_config
75
+
62
76
  @classmethod
63
77
  def _export_submodules_from_model(
64
78
  cls, model: "PreTrainedModel", model_save_dir: str, rbln_config: RBLNModelConfig, **kwargs
65
79
  ) -> List["RBLNModel"]:
66
80
  rbln_submodules = []
67
81
  submodule_prefix = getattr(cls, "_rbln_submodule_prefix", None)
82
+ submodule_postfix = getattr(cls, "_rbln_submodule_postfix", None)
68
83
  preprocessors = kwargs.pop("preprocessors", [])
69
84
 
70
85
  for submodule in cls._rbln_submodules:
@@ -72,6 +87,9 @@ class SubModulesMixin:
72
87
  if submodule_prefix is not None:
73
88
  torch_submodule: PreTrainedModel = getattr(model, submodule_prefix)
74
89
  torch_submodule = getattr(torch_submodule, submodule_name)
90
+ elif submodule_postfix is not None:
91
+ torch_submodule: PreTrainedModel = getattr(model, submodule_name)
92
+ torch_submodule = getattr(torch_submodule, submodule_postfix)
75
93
  else:
76
94
  torch_submodule: PreTrainedModel = getattr(model, submodule_name)
77
95
 
@@ -90,6 +108,14 @@ class SubModulesMixin:
90
108
  filtered_kwargs["cls_name"] = submodule_config_cls.__name__
91
109
  submodule_rbln_config = submodule_config_cls(**filtered_kwargs)
92
110
 
111
+ submodule_rbln_config = cls._update_submodule_rbln_config(
112
+ submodule_name=submodule_name,
113
+ submodule_cls=submodule_cls,
114
+ model=model,
115
+ submodule_config=torch_submodule.config,
116
+ submodule_rbln_config=submodule_rbln_config,
117
+ preprocessors=preprocessors,
118
+ )
93
119
  setattr(rbln_config, submodule_name, submodule_rbln_config)
94
120
  submodule_rbln_config = submodule_cls._update_submodule_config(model, submodule_rbln_config, preprocessors)
95
121
 
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: optimum-rbln
3
- Version: 0.9.3rc0
3
+ Version: 0.9.5a4
4
4
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
5
5
  Project-URL: Homepage, https://rebellions.ai
6
6
  Project-URL: Documentation, https://docs.rbln.ai
7
- Project-URL: Repository, https://github.com/rebellions-sw/optimum-rbln
7
+ Project-URL: Repository, https://github.com/rbln-sw/optimum-rbln
8
8
  Author-email: "Rebellions Inc." <support@rebellions.ai>
9
9
  License-Expression: Apache-2.0
10
10
  License-File: LICENSE
@@ -24,12 +24,12 @@ Classifier: Programming Language :: Python :: 3.13
24
24
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
25
25
  Requires-Python: <3.14,>=3.9
26
26
  Requires-Dist: accelerate>=1.0.1
27
- Requires-Dist: diffusers==0.35.1
27
+ Requires-Dist: diffusers==0.36.0
28
28
  Requires-Dist: packaging>=24.1
29
29
  Requires-Dist: torch==2.8.0
30
30
  Requires-Dist: torchaudio<=2.8.0
31
31
  Requires-Dist: torchvision<=0.23.0
32
- Requires-Dist: transformers==4.57.1
32
+ Requires-Dist: transformers==4.57.3
33
33
  Description-Content-Type: text/markdown
34
34
 
35
35
 
@@ -40,7 +40,7 @@ Description-Content-Type: text/markdown
40
40
  <img src="assets/rbln_logo.png" width="60%"/>
41
41
 
42
42
  [![PyPI version](https://badge.fury.io/py/optimum-rbln.svg)](https://badge.fury.io/py/optimum-rbln)
43
- [![License](https://img.shields.io/github/license/rebellions-sw/optimum-rbln)](https://github.com/rebellions-sw/optimum-rbln/blob/main/LICENSE)
43
+ [![License](https://img.shields.io/github/license/rbln-sw/optimum-rbln)](https://github.com/rbln-sw/optimum-rbln/blob/main/LICENSE)
44
44
  [![Documentation](https://img.shields.io/badge/docs-available-brightgreen)](https://docs.rbln.ai/software/optimum/optimum_rbln.html)
45
45
  [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](CODE_OF_CONDUCT.md)
46
46
 
@@ -113,7 +113,7 @@ pip install optimum-rbln --extra-index-url https://download.pytorch.org/whl/cpu
113
113
  The below command installs `optimum-rbln` along with its dependencies.
114
114
 
115
115
  ```bash
116
- git clone https://github.com/rebellions-sw/optimum-rbln.git
116
+ git clone https://github.com/rbln-sw/optimum-rbln.git
117
117
  cd optimum-rbln
118
118
  ./scripts/uv-sync.sh
119
119
  ```