optimum-rbln 0.9.5a4__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +196 -52
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +5 -4
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_attention_utils.py +15 -9
  17. optimum/rbln/transformers/models/__init__.py +10 -0
  18. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  19. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  20. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  21. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +2 -2
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +26 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +2 -1
  24. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +45 -21
  25. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  26. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  27. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  28. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  29. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +4 -176
  30. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +4 -3
  31. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +10 -7
  32. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  33. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  34. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  35. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  36. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +7 -7
  37. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  38. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +2 -0
  39. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +2 -0
  40. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  41. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  42. optimum/rbln/utils/deprecation.py +78 -1
  43. optimum/rbln/utils/hub.py +93 -2
  44. optimum/rbln/utils/runtime_utils.py +2 -2
  45. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +1 -1
  46. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +49 -42
  47. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  48. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  49. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
@@ -17,7 +17,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa
17
17
 
18
18
  class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
19
19
  """
20
- Configuration class for RBLN GPT-OSS models.
20
+ Configuration class for RBLN GptOss models.
21
21
 
22
22
  This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
23
23
 
@@ -28,14 +28,15 @@ class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
28
28
  # Create a configuration object
29
29
  config = RBLNGptOssForCausalLMConfig(
30
30
  batch_size=1,
31
- tensor_parallel_size=4
31
+ tensor_parallel_size=8,
32
+ kvcache_partition_len=8192,
32
33
  )
33
34
 
34
35
  # Use the configuration with from_pretrained
35
36
  model = RBLNGptOssForCausalLM.from_pretrained(
36
37
  "openai/gpt-oss-20b",
37
38
  export=True,
38
- rbln_config=config
39
+ rbln_config=config,
39
40
  )
40
41
  ```
41
42
  """
@@ -41,8 +41,8 @@ class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
41
41
  The GPT-OSS Model transformer with a language modeling head (linear layer) on top.
42
42
  This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
43
43
 
44
- A class to convert and run pre-trained transformers based GPT-OSSForCausalLM model on RBLN devices.
45
- It implements the methods to convert a pre-trained transformers GPT-OSSForCausalLM model into a RBLN transformer model by:
44
+ A class to convert and run pre-trained transformers based GptOssForCausalLM model on RBLN devices.
45
+ It implements the methods to convert a pre-trained transformers GptOssForCausalLM model into a RBLN transformer model by:
46
46
  - transferring the checkpoint weights of the original into an optimized RBLN graph,
47
47
  - compiling the resulting graph using the RBLN compiler.
48
48
 
@@ -62,19 +62,21 @@ class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
62
62
  "openai/gpt-oss-20b",
63
63
  export=True,
64
64
  rbln_batch_size=1,
65
- rbln_tensor_parallel_size=4,
65
+ rbln_tensor_parallel_size=8,
66
+ rbln_kvcache_partition_len=8192,
66
67
  )
67
68
 
68
69
 
69
70
  # Using a config dictionary
70
71
  rbln_config = {
71
72
  "batch_size": 1,
72
- "tensor_parallel_size": 4,
73
+ "tensor_parallel_size": 8,
74
+ "kvcache_partition_len": 8192,
73
75
  }
74
76
  model = RBLNGptOssForCausalLM.from_pretrained(
75
77
  "openai/gpt-oss-20b",
76
78
  export=True,
77
- rbln_config=rbln_config
79
+ rbln_config=rbln_config,
78
80
  )
79
81
 
80
82
 
@@ -83,12 +85,13 @@ class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
83
85
 
84
86
  config = RBLNGptOssForCausalLMConfig(
85
87
  batch_size=1,
86
- tensor_parallel_size=4
88
+ tensor_parallel_size=8,
89
+ kvcache_partition_len=8192,
87
90
  )
88
91
  model = RBLNGptOssForCausalLM.from_pretrained(
89
92
  "openai/gpt-oss-20b",
90
93
  export=True,
91
- rbln_config=config
94
+ rbln_config=config,
92
95
  )
93
96
  ```
94
97
  """
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_mixtral import RBLNMixtralForCausalLMConfig
16
+ from .modeling_mixtral import RBLNMixtralForCausalLM
@@ -0,0 +1,38 @@
1
+ # Copyright 2026 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
16
+
17
+
18
class RBLNMixtralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """
    Configuration class for RBLN Mixtral models.

    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig: it defines no options of
    its own, so it accepts exactly the same arguments as its parent.

    Example usage:
    ```python
    from optimum.rbln import RBLNMixtralForCausalLM, RBLNMixtralForCausalLMConfig

    # Create a configuration object
    config = RBLNMixtralForCausalLMConfig(
        batch_size=1,
        max_seq_len=32768,
        tensor_parallel_size=4,
    )

    # Use the configuration with from_pretrained
    model = RBLNMixtralForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        export=True,
        rbln_config=config,
    )
    ```
    """
@@ -0,0 +1,76 @@
1
+ # Copyright 2026 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
21
+ from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
22
+
23
+
24
class MixtralWrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper whose layers are built as :class:`MixtralLayer` (sparse-MoE MLP)."""

    def get_rbln_layer_class(self):
        # Override hook: presumably DecoderOnlyWrapper calls this to pick the class used to
        # wrap each HF decoder layer — confirm against the base implementation.
        layer_cls = MixtralLayer
        return layer_cls
27
+
28
+
29
class MixtralLayer(DecoderOnlyLayer):
    """Decoder layer whose feed-forward part is the RBLN sparse Mixture-of-Experts block."""

    # Attribute name(s) under which the HF layer stores its feed-forward module; for Mixtral
    # this is `block_sparse_moe`. Presumably read by DecoderOnlyLayer to populate `self.mlp`
    # (which the constructor below relies on) — confirm against the base class.
    _MLP_ATTR = ("block_sparse_moe",)

    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
        super().__init__(layer, self_attn, lora_config)
        # Replace the HF MoE module with the RBLN version built from it.
        moe_block = MixtralSparseMoeBlock(self.mlp)
        self.block_sparse_moe = moe_block

    def get_mlp(self) -> nn.Module:
        """Return the RBLN sparse-MoE block acting as this layer's MLP."""
        return self.block_sparse_moe
38
+
39
+
40
class MixtralSparseMoeBlock(nn.Module):
    """RBLN replacement for the HF Mixtral sparse MoE block.

    Keeps the original router (``gate``) and ``top_k`` value and replaces the per-expert MLP
    list with a single :class:`MixtralBlockSparseTop2MLP`, which evaluates all experts through
    one RBLN custom op.
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: The original MoE module. It must expose ``top_k``, ``gate`` (the router) and
                ``experts`` (the list of per-expert MLPs).
        """
        super().__init__()
        self.top_k = model.top_k
        self.gate = model.gate
        self.experts = MixtralBlockSparseTop2MLP(model.experts, self.top_k)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through its selected experts.

        Args:
            hidden_states: Tensor of shape (batch_size, sequence_length, hidden_dim).

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten to (batch * sequence_length, hidden_dim). `reshape` (unlike `view`) also
        # accepts non-contiguous inputs instead of raising.
        flat_hidden_states = hidden_states.reshape(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(flat_hidden_states)
        expert_output = self.experts(flat_hidden_states, router_logits)
        return expert_output.reshape(batch_size, sequence_length, hidden_dim)
56
+
57
+
58
class MixtralBlockSparseTop2MLP(nn.Module):
    """All experts of one Mixtral MoE layer packed into three stacked weight tensors.

    The per-expert ``w1``/``w2``/``w3`` weights are stacked along a new leading expert
    dimension and passed, together with the router logits, to the RBLN custom op
    ``custom_moe_glu``.
    """

    def __init__(self, expert_list, top_k):
        """
        Args:
            expert_list: Sequence of per-expert MLP modules. Each must expose ``hidden_dim``,
                ``ffn_dim`` and ``w1``/``w2``/``w3`` sub-modules with a ``weight`` tensor. All
                experts must have identically shaped weights (``torch.stack`` raises otherwise).
            top_k: Number of experts selected per token.
        """
        super().__init__()
        self.hidden_dim = expert_list[0].hidden_dim
        self.ffn_dim = expert_list[0].ffn_dim
        self.top_k = top_k

        self.num_experts = len(expert_list)
        # Each container's own weight is immediately replaced by the stacked expert weights,
        # so it is created on the meta device. The previous version allocated and randomly
        # initialized a full hidden_dim x (num_experts * ffn_dim) CPU matrix per projection.
        # NOTE: in_features/out_features describe a flattened layout and do not match the
        # stacked 3-D weights; only `.weight` is consumed (see forward).
        self.w1 = self._stacked_linear(expert_list, "w1", self.hidden_dim, self.num_experts * self.ffn_dim)
        self.w2 = self._stacked_linear(expert_list, "w2", self.num_experts * self.ffn_dim, self.hidden_dim)
        self.w3 = self._stacked_linear(expert_list, "w3", self.hidden_dim, self.num_experts * self.ffn_dim)

    @staticmethod
    def _stacked_linear(expert_list, proj_name, in_features, out_features) -> nn.Linear:
        """Build a bias-free ``nn.Linear`` whose weight is every expert's ``proj_name`` weight stacked on dim 0."""
        linear = nn.Linear(in_features, out_features, bias=False, device="meta")
        # `.data` yields detached tensors, so the stacked result carries no autograd history.
        stacked = torch.stack([getattr(expert, proj_name).weight.data for expert in expert_list], dim=0)
        # Assigning a Parameter re-registers `weight`, replacing the meta placeholder.
        linear.weight = nn.Parameter(stacked)
        return linear

    def forward(self, x, router_logits):
        """Apply the top-k routed experts to ``x`` through the RBLN custom MoE op.

        Args:
            x: Flattened hidden states, (num_tokens, hidden_dim) as produced by
                MixtralSparseMoeBlock.
            router_logits: Router outputs, (num_tokens, num_experts).
        """
        # Weights are passed as (w1, w3, w2), presumably gate, up, down projections following
        # the HF Mixtral naming. NOTE(review): the trailing `True` is an opaque custom-op flag
        # (possibly top-k weight renormalization) — confirm against rbln_custom_ops.
        return torch.ops.rbln_custom_ops.custom_moe_glu(
            x, self.w1.weight, self.w3.weight, self.w2.weight, router_logits, self.top_k, True
        )
@@ -0,0 +1,68 @@
1
+ # Copyright 2026 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
16
+ from .mixtral_architecture import MixtralWrapper
17
+
18
+
19
class RBLNMixtralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The Mixtral Model transformer with a language modeling head (linear layer) on top.
    Mixtral is a sparse Mixture-of-Experts (MoE) decoder-only model: in every decoder layer a
    router selects the top-k feed-forward experts for each token.

    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based MixtralForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers MixtralForCausalLM model into a RBLN transformer model by:

    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    **Configuration:**
    This model uses [`RBLNMixtralForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
    the `rbln_config` parameter should be an instance of [`RBLNMixtralForCausalLMConfig`] or a dictionary conforming to its structure.

    See the [`RBLNMixtralForCausalLMConfig`] class for all available configuration options.

    Examples:
    ```python
    from optimum.rbln import RBLNMixtralForCausalLM

    # Simple usage using rbln_* arguments
    # `max_seq_len` is automatically inferred from the model config
    model = RBLNMixtralForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        export=True,
        rbln_batch_size=1,
        rbln_tensor_parallel_size=4,
    )

    # Using a config dictionary
    rbln_config = {
        "batch_size": 1,
        "max_seq_len": 32768,
        "tensor_parallel_size": 4,
    }
    model = RBLNMixtralForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        export=True,
        rbln_config=rbln_config,
    )

    # Using a RBLNMixtralForCausalLMConfig instance (recommended for type checking)
    from optimum.rbln import RBLNMixtralForCausalLMConfig

    config = RBLNMixtralForCausalLMConfig(
        batch_size=1,
        max_seq_len=32768,
        tensor_parallel_size=4,
    )
    model = RBLNMixtralForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        export=True,
        rbln_config=config,
    )
    ```
    """

    # Wrapper that swaps each HF decoder layer for a MixtralLayer (sparse-MoE MLP).
    _decoder_wrapper_cls = MixtralWrapper
@@ -525,13 +525,13 @@ class RBLNPaliGemmaModel(RBLNModel):
525
525
  Forward pass for the RBLN-optimized PaliGemmaModel model.
526
526
 
527
527
  Args:
528
- input_ids (torch.LongTensor of shape (batch_size, sequence_length)) Indices of input sequence tokens in the vocabulary.
529
- pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)) The tensors corresponding to the input images.
530
- attention_mask (torch.Tensor of shape (batch_size, sequence_length)) Mask to avoid performing attention on padding token indices.
531
- position_ids (torch.LongTensor of shape (batch_size, sequence_length)) Indices of positions of each input sequence tokens in the position embeddings.
532
- token_type_ids (torch.LongTensor of shape (batch_size, sequence_length)) Segment token indices to indicate first and second portions of the inputs.
533
- output_hidden_states (bool, optional) Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
534
- return_dict (bool, optional) Whether or not to return a ModelOutput instead of a plain tuple.
528
+ input_ids (torch.LongTensor of shape (batch_size, sequence_length)): Indices of input sequence tokens in the vocabulary.
529
+ pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
530
+ attention_mask (torch.Tensor of shape (batch_size, sequence_length)): Mask to avoid performing attention on padding token indices.
531
+ position_ids (torch.LongTensor of shape (batch_size, sequence_length)): Indices of positions of each input sequence tokens in the position embeddings.
532
+ token_type_ids (torch.LongTensor of shape (batch_size, sequence_length)): Segment token indices to indicate first and second portions of the inputs.
533
+ output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
534
+ return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
535
535
 
536
536
  Returns:
537
537
  PaligemmaModelOutputWithPast or tuple(torch.FloatTensor)
@@ -297,13 +297,17 @@ class RBLNPixtralVisionModel(RBLNModel):
297
297
  Forward pass for the RBLN-optimized Pixtral vision model.
298
298
 
299
299
  Args:
300
- pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images. Pixel values can be obtained using PixtralImageProcessor. See PixtralImageProcessor.call() for details (PixtralProcessor uses PixtralImageProcessor for processing images).
301
- image_sizes (torch.Tensor of shape (batch_size, 2), optional) The sizes of the images in the batch, being (height, width) for each image.
302
- output_hidden_states (bool, optional) Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
303
- return_dict (bool, optional) Whether or not to return a ModelOutput instead of a plain tuple.
300
+ pixel_values: Input images as a tensor of shape (batch_size, num_channels, image_size, image_size).
301
+ Pixel values can be obtained using PixtralImageProcessor. See PixtralImageProcessor.__call__()
302
+ for details (PixtralProcessor uses PixtralImageProcessor for processing images).
303
+ image_sizes: The sizes of the images in the batch as a tensor of shape (batch_size, 2),
304
+ being (height, width) for each image. Optional.
305
+ output_hidden_states: Whether or not to return the hidden states of all layers. Optional.
306
+ See hidden_states under returned tensors for more detail.
307
+ return_dict: Whether or not to return a ModelOutput instead of a plain tuple. Optional.
304
308
 
305
309
  Returns:
306
- BaseModelOutput or tuple(torch.FloatTensor)
310
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutput object.
307
311
  """
308
312
  output_hidden_states = (
309
313
  output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
@@ -61,6 +61,8 @@ class RBLNQwen2_5_VLModelConfig(RBLNDecoderOnlyModelConfig):
61
61
  Configuration class for RBLNQwen2_5_VLModel.
62
62
  """
63
63
 
64
+ submodules = ["visual"]
65
+
64
66
  def __init__(self, visual: Optional[RBLNModelConfig] = None, **kwargs: Any):
65
67
  super().__init__(**kwargs)
66
68
  self.visual = self.initialize_submodule_config(submodule_config=visual)
@@ -53,6 +53,8 @@ class RBLNQwen2VLModelConfig(RBLNDecoderOnlyModelConfig):
53
53
  Configuration class for RBLNQwen2VLModel.
54
54
  """
55
55
 
56
+ submodules = ["visual"]
57
+
56
58
  def __init__(self, visual: Optional[RBLNModelConfig] = None, **kwargs: Dict[str, Any]):
57
59
  super().__init__(**kwargs)
58
60
  self.visual = self.initialize_submodule_config(submodule_config=visual)
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
- from typing import Optional
16
+ from typing import Any, Optional, Tuple, Union
17
17
 
18
18
  from ...configuration_generic import RBLNModelForImageClassificationConfig
19
19
 
@@ -26,17 +26,23 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConf
26
26
  RBLN-optimized ResNet models for image classification tasks.
27
27
  """
28
28
 
29
- def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
29
+ def __init__(
30
+ self,
31
+ image_size: Optional[Union[int, Tuple[int, int]]] = None,
32
+ batch_size: Optional[int] = None,
33
+ output_hidden_states: Optional[bool] = None,
34
+ **kwargs: Any,
35
+ ):
30
36
  """
31
37
  Args:
32
38
  image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
33
39
  Can be an integer for square images or a tuple (height, width).
34
40
  batch_size (Optional[int]): The batch size for inference. Defaults to 1.
35
- output_hidden_states (bool, optional) Whether or not to return the hidden states of all layers.
41
+ output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
36
42
  kwargs: Additional arguments passed to the parent RBLNModelConfig.
37
43
 
38
44
  Raises:
39
45
  ValueError: If batch_size is not a positive integer.
40
46
  """
41
- super().__init__(**kwargs)
47
+ super().__init__(image_size=image_size, batch_size=batch_size, **kwargs)
42
48
  self.output_hidden_states = output_hidden_states
@@ -51,22 +51,22 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
51
51
  return_segments: Optional[bool] = None,
52
52
  return_timestamps: Optional[bool] = None,
53
53
  return_token_timestamps: Optional[bool] = None,
54
- **kwargs,
54
+ **kwargs: Optional[Dict[str, Any]],
55
55
  ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
56
56
  """
57
57
  The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
58
58
  Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
59
59
 
60
60
  Args:
61
- input_features(torch.Tensor, optional): The input features to the model.
62
- attention_mask(torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
63
- generation_config(GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
61
+ input_features (torch.Tensor, optional): The input features to the model.
62
+ attention_mask (torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
63
+ generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
64
64
  If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
65
65
  Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
66
- return_segments(bool, optional): Whether to return segments.
67
- return_timestamps(bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
68
- return_token_timestamps(bool, optional): Whether to return token timestamps.
69
- kwargs(dict[str, Any], optional): Additional arguments passed to the generate function.
66
+ return_segments (bool, optional): Whether to return segments.
67
+ return_timestamps (bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
68
+ return_token_timestamps (bool, optional): Whether to return token timestamps.
69
+ kwargs (dict[str, Any], optional): Additional arguments passed to the generate function.
70
70
 
71
71
  Returns:
72
72
  Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
@@ -194,7 +194,7 @@ def deprecate_kwarg(
194
194
  message = f"{message} {additional_message}"
195
195
 
196
196
  # update minimum_action if argument is ALREADY deprecated (current version >= deprecated version)
197
- if is_greater_or_equal_version:
197
+ if is_greater_or_equal_version and message is not None:
198
198
  # change to NOTIFY -> RAISE in case we want to raise error for already deprecated arguments
199
199
  if raise_if_greater_or_equal_version:
200
200
  minimum_action = Action.RAISE
@@ -211,3 +211,80 @@ def deprecate_kwarg(
211
211
  return wrapped_func
212
212
 
213
213
  return wrapper
214
+
215
+
216
def deprecate_method(
    version: str,
    new_method: Optional[str] = None,
    raise_if_greater_or_equal_version: bool = True,
    additional_message: Optional[str] = None,
):
    """
    Decorator to mark a method as deprecated, optionally pointing to a replacement method.

    This decorator allows you to:
    - Notify users when a method is deprecated.
    - Optionally specify a new method name that should be used instead.
    - Raise an error if the deprecated method is called after the specified version.

    Parameters:
        version (`str`):
            The version in which the method was (or will be) deprecated.
        new_method (`Optional[str]`, *optional*):
            The name of the new method to use instead. If specified, users will be directed to use this method.
        raise_if_greater_or_equal_version (`bool`, *optional*, defaults to `True`):
            Whether to raise `ValueError` if current `optimum.rbln` version is greater than or equal to the deprecated version.
            When `False`, a warning is logged instead and the method still runs.
        additional_message (`Optional[str]`, *optional*):
            An additional message to append to the default deprecation message.

    Returns:
        Callable:
            A wrapped function that handles the deprecation warning or error.

    Notes:
        If the decorated function has a parameter named `self` (or `cls`), the message names the
        method as `ClassName.method`, taking the class from the first positional argument. Apply
        this decorator directly to the function (i.e. below `@classmethod` / `@staticmethod`) so
        `inspect.signature` can see those parameter names.

    Examples:
        >>> class MyClass:
        ...     @deprecate_method(version="0.12.0", new_method="from_pretrained")
        ...     def load(self, path):
        ...         return self.from_pretrained(path)
    """

    deprecated_version = packaging.version.parse(version)
    current_version = packaging.version.parse(__version__)
    is_greater_or_equal_version = current_version >= deprecated_version

    if is_greater_or_equal_version:
        version_message = f"and removed starting from version {version}"
    else:
        version_message = f"and will be removed in version {version}"

    def wrapper(func):
        function_named_args = set(inspect.signature(func).parameters)
        is_instance_method = "self" in function_named_args
        is_class_method = "cls" in function_named_args

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # Qualify the name with the owning class for a clearer message. Guard on `args`:
            # when `self`/`cls` is passed by keyword (or not at all) there is no positional
            # argument to inspect, and indexing would raise IndexError before `func` runs.
            method_name = func.__name__
            if args and is_instance_method:
                method_name = f"{args[0].__class__.__name__}.{method_name}"
            elif args and is_class_method:
                method_name = f"{args[0].__name__}.{method_name}"

            # Build deprecation message
            if new_method is not None:
                message = f"`{method_name}` is deprecated {version_message}. Use `{new_method}` instead."
            else:
                message = f"`{method_name}` is deprecated {version_message}."

            if additional_message is not None:
                message = f"{message} {additional_message}"

            # Past the deprecation version: fail hard unless the caller opted out.
            if is_greater_or_equal_version and raise_if_greater_or_equal_version:
                raise ValueError(message)
            # stacklevel=2 attributes the warning to the caller of the deprecated method.
            logger.warning(message, stacklevel=2)

            return func(*args, **kwargs)

        return wrapped_func

    return wrapper
optimum/rbln/utils/hub.py CHANGED
@@ -16,7 +16,8 @@ import json
16
16
  from pathlib import Path
17
17
  from typing import List, Optional, Union
18
18
 
19
- from huggingface_hub import HfApi, get_token, hf_hub_download
19
+ from huggingface_hub import HfApi, get_token, hf_hub_download, try_to_load_from_cache
20
+ from huggingface_hub.errors import LocalEntryNotFoundError
20
21
 
21
22
 
22
23
  def pull_compiled_model_from_hub(
@@ -29,6 +30,97 @@ def pull_compiled_model_from_hub(
29
30
  local_files_only: bool,
30
31
  ) -> Path:
31
32
  """Pull model files from the HuggingFace Hub."""
33
+ config_filename = "rbln_config.json" if subfolder == "" else f"{subfolder}/rbln_config.json"
34
+
35
+ # Try to find config file in cache first.
36
+ config_cache_path = try_to_load_from_cache(
37
+ repo_id=str(model_id),
38
+ filename=config_filename,
39
+ revision=revision,
40
+ cache_dir=cache_dir,
41
+ )
42
+
43
+ # If config is cached and we're not forcing download, try to use cached files
44
+ if config_cache_path and isinstance(config_cache_path, str) and not force_download:
45
+ config_path = Path(config_cache_path)
46
+ if config_path.exists():
47
+ cache_dir_path = config_path.parent
48
+
49
+ # Look for .rbln files in the same directory
50
+ pattern_rbln = "*.rbln"
51
+ rbln_files = list(cache_dir_path.glob(pattern_rbln))
52
+
53
+ # Validate files found in cache
54
+ rbln_config_filenames = [config_path] if config_path.exists() else []
55
+ validate_files(rbln_files, rbln_config_filenames, f"cached repository {model_id}")
56
+
57
+ # If local_files_only is True, return cached directory without API call
58
+ if local_files_only:
59
+ return cache_dir_path
60
+
61
+ # If local_files_only is False, ensure all files are downloaded
62
+ # Download config file (will use cache if available, download if missing)
63
+ rbln_config_cache_path = hf_hub_download(
64
+ repo_id=model_id,
65
+ filename=config_filename,
66
+ token=token,
67
+ revision=revision,
68
+ cache_dir=cache_dir,
69
+ force_download=force_download,
70
+ local_files_only=False,
71
+ )
72
+ cache_dir_path = Path(rbln_config_cache_path).parent
73
+
74
+ # Download all .rbln files found in cache (hf_hub_download will use cache if available)
75
+ for rbln_file in rbln_files:
76
+ filename = rbln_file.name if subfolder == "" else f"{subfolder}/{rbln_file.name}"
77
+ try:
78
+ hf_hub_download(
79
+ repo_id=model_id,
80
+ filename=filename,
81
+ token=token,
82
+ revision=revision,
83
+ cache_dir=cache_dir,
84
+ force_download=force_download,
85
+ local_files_only=False,
86
+ )
87
+ except LocalEntryNotFoundError:
88
+ # File might not exist in repo, skip it
89
+ pass
90
+
91
+ # Note: We skip the API call here since we're using cached files
92
+ # If there are additional files in the repo that aren't cached,
93
+ # they won't be downloaded.
94
+ # If the user needs all files, they should use force_download=True
95
+ return cache_dir_path
96
+
97
+ # If local_files_only is True and config not found in cache, try to download with local_files_only
98
+ if local_files_only:
99
+ try:
100
+ rbln_config_cache_path = hf_hub_download(
101
+ repo_id=model_id,
102
+ filename=config_filename,
103
+ token=token,
104
+ revision=revision,
105
+ cache_dir=cache_dir,
106
+ force_download=force_download,
107
+ local_files_only=True,
108
+ )
109
+ cache_dir_path = Path(rbln_config_cache_path).parent
110
+ rbln_files = list(cache_dir_path.glob("*.rbln"))
111
+ rbln_config_filenames = [Path(rbln_config_cache_path)] if Path(rbln_config_cache_path).exists() else []
112
+ validate_files(rbln_files, rbln_config_filenames, f"cached repository {model_id}")
113
+ return cache_dir_path
114
+ except LocalEntryNotFoundError as err:
115
+ raise FileNotFoundError(
116
+ f"Could not find compiled model files for {model_id} in local cache. "
117
+ f"Set local_files_only=False to download from HuggingFace Hub."
118
+ ) from err
119
+
120
+ # List files from repository. This only happens when:
121
+ # 1. Config is not cached, OR
122
+ # 2. force_download=True, OR
123
+ # 3. local_files_only=False and we need to discover all files in the repo
32
124
  huggingface_token = _get_huggingface_token(token)
33
125
  repo_files = list(
34
126
  map(
@@ -51,7 +143,6 @@ def pull_compiled_model_from_hub(
51
143
  rbln_config_cache_path = hf_hub_download(
52
144
  repo_id=model_id,
53
145
  filename=filename,
54
- subfolder=subfolder,
55
146
  token=token,
56
147
  revision=revision,
57
148
  cache_dir=cache_dir,
@@ -98,7 +98,7 @@ def tp_and_devices_are_ok(
98
98
  return None
99
99
  if rebel.get_npu_name(device_id) is None:
100
100
  return (
101
- f"Device {device_id} is not a valid NPU device. Please check your NPU status with 'rbln-stat' command."
101
+ f"Device {device_id} is not a valid NPU device. Please check your NPU status with 'rbln-smi' command."
102
102
  )
103
103
 
104
104
  if rebel.device_count() < tensor_parallel_size:
@@ -185,7 +185,7 @@ class UnavailableRuntime:
185
185
  "This model was loaded with create_runtimes=False. To use this model for inference:\n"
186
186
  "1. Load the model with runtime creation enabled:\n"
187
187
  " model = RBLNModel.from_pretrained(..., rbln_create_runtimes=True)\n"
188
- "2. Ensure your NPU hardware is properly configured (check with 'rbln-stat' command)\n"
188
+ "2. Ensure your NPU hardware is properly configured (check with 'rbln-smi' command)\n"
189
189
  "3. If you're on a machine without NPU hardware, you need to transfer the model files\n"
190
190
  " to a compatible system with NPU support."
191
191
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: optimum-rbln
3
- Version: 0.9.5a4
3
+ Version: 0.10.0.post1
4
4
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
5
5
  Project-URL: Homepage, https://rebellions.ai
6
6
  Project-URL: Documentation, https://docs.rbln.ai