optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
Files changed (107)
  1. optimum/rbln/__init__.py +12 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -6
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +12 -8
  38. optimum/rbln/transformers/configuration_generic.py +0 -27
  39. optimum/rbln/transformers/modeling_attention_utils.py +242 -109
  40. optimum/rbln/transformers/modeling_generic.py +2 -61
  41. optimum/rbln/transformers/modeling_outputs.py +1 -0
  42. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  43. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  44. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  45. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  46. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  47. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  48. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  49. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  50. optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
  51. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  52. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
  53. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  54. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  55. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
  56. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  57. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
  59. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  60. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  61. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  62. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  63. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  64. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  67. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  68. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
  69. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  70. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  71. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  72. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  73. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  74. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  75. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  76. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
  77. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
  78. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  79. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  80. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  81. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  82. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  83. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  84. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  85. optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
  86. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  87. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  88. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  89. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  90. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  91. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  92. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  93. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  94. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  95. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  96. optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
  97. optimum/rbln/utils/deprecation.py +213 -0
  98. optimum/rbln/utils/hub.py +14 -3
  99. optimum/rbln/utils/import_utils.py +7 -1
  100. optimum/rbln/utils/runtime_utils.py +32 -0
  101. optimum/rbln/utils/submodule.py +3 -1
  102. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
  103. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
  104. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
  105. optimum/rbln/utils/depreacate_utils.py +0 -16
  106. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  107. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
@@ -89,7 +89,7 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig
     ):
         return Qwen2VisionTransformerWrapper(model).eval()
@@ -112,8 +112,8 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         model_config: "PretrainedConfig" = None,
         rbln_config: Optional[RBLNQwen2VisionTransformerPretrainedModelConfig] = None,
     ) -> RBLNQwen2VisionTransformerPretrainedModelConfig:
-        hidden_size = getattr(model_config, "embed_dim")
-        num_heads = getattr(model_config, "num_heads")
+        hidden_size = model_config.embed_dim
+        num_heads = model_config.num_heads
         head_dim = hidden_size // num_heads
 
         input_infos = []
@@ -200,10 +200,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         try:
             cu_index = torch.searchsorted(self.max_seq_lens, cu_seq_len).item()
             max_seq_len = self.max_seq_lens[cu_index]
-        except Exception:
+        except Exception as e:
             raise ValueError(
                 f"Required seq_len({cu_seq_len}) is larger than available max_seq_lens({self.max_seq_lens.tolist()})."
-            )
+            ) from e
 
         # Padding for Full Attention Layers
         hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks = (
@@ -282,8 +282,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         return True
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         model.model.lm_head = model.lm_head
         model.lm_head = None
         del model.lm_head
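Across these hunks the compile-time hooks move to underscore-prefixed names: wrap_model_if_needed becomes _wrap_model_if_needed, and the get_pytorch_model override is replaced by _reconstruct_model_if_needed. Downstream code that subclasses RBLN models and overrides the old hook names has to follow the rename. A minimal sketch, assuming a hypothetical custom model; MyVisionWrapper and RBLNMyVisionModel are illustrative names, not part of optimum-rbln:

    # Hypothetical subclass showing the 0.9.4 hook rename; only the method name
    # changes, the (cls, model, rbln_config) signature stays the same.
    import torch
    from optimum.rbln import RBLNModel


    class MyVisionWrapper(torch.nn.Module):
        # Illustrative wrapper that fixes the traced forward signature.
        def __init__(self, model: torch.nn.Module):
            super().__init__()
            self.model = model

        def forward(self, pixel_values: torch.Tensor):
            return self.model(pixel_values)


    class RBLNMyVisionModel(RBLNModel):
        @classmethod
        def _wrap_model_if_needed(cls, model, rbln_config):  # was wrap_model_if_needed before 0.9.4a2
            return MyVisionWrapper(model).eval()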
@@ -20,7 +20,7 @@ class Qwen2VisionTransformerWrapper(nn.Module):
 
     def wrap_vision_blocks(self, blocks: torch.nn.ModuleList):
         wrapped_blocks = []
-        for i, block in enumerate(blocks):
+        for _, block in enumerate(blocks):
             wrapped_blocks.append(Qwen2VLVisionBlock(block))
         return nn.ModuleList(wrapped_blocks)
 
@@ -12,24 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
-
-from transformers import PretrainedConfig
 
 from ....utils import logging
 from ...models.decoderonly import (
     RBLNDecoderOnlyModel,
     RBLNDecoderOnlyModelForCausalLM,
-    RBLNDecoderOnlyModelForCausalLMConfig,
 )
 from .qwen3_architecture import Qwen3Wrapper
 
 
 logger = logging.get_logger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import PretrainedConfig
-
 
 class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
@@ -84,19 +77,6 @@ class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     _decoder_wrapper_cls = Qwen3Wrapper
 
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        # https://github.com/huggingface/transformers/issues/35896
-        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
-        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
-
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-        return rbln_config
-
     def forward(self, *args, **kwargs):
         kwargs["return_dict"] = True
         return super().forward(*args, **kwargs)
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 from ...configuration_generic import RBLNModelForImageClassificationConfig
 
 
@@ -23,3 +25,18 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized ResNet models for image classification tasks.
     """
+
+    def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
+        """
+        Args:
+            image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
+                Can be an integer for square images or a tuple (height, width).
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.output_hidden_states = output_hidden_states
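RBLNResNetForImageClassificationConfig now takes output_hidden_states directly; when left as None it falls back to the Hugging Face model config (see the _update_rbln_config hunk in the modeling_resnet.py changes that follow). A hedged construction sketch, assuming the config class is re-exported from optimum.rbln like the other RBLN*Config classes:

    # Sketch: output_hidden_states is fixed at compile time through the config.
    from optimum.rbln import RBLNResNetForImageClassificationConfig

    rbln_config = RBLNResNetForImageClassificationConfig(
        batch_size=1,                # handled by the parent image-classification config
        output_hidden_states=True,   # new in 0.9.4a2; None means "inherit from model_config"
    )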
@@ -13,7 +13,17 @@
 # limitations under the License.
 
 
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
+
 from ...modeling_generic import RBLNModelForImageClassification
+from .configuration_resnet import RBLNResNetForImageClassificationConfig
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
@@ -24,3 +34,66 @@ class RBLNResNetForImageClassification(RBLNModelForImageClassification):
     on RBLN devices, supporting image classification with convolutional neural networks
     designed for computer vision tasks.
     """
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional["RBLNResNetForImageClassificationConfig"] = None,
+    ) -> "RBLNResNetForImageClassificationConfig":
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
+        rbln_config = super()._update_rbln_config(
+            preprocessors=preprocessors,
+            model=model,
+            model_config=model_config,
+            rbln_config=rbln_config,
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: "RBLNResNetForImageClassificationConfig"
+    ) -> torch.nn.Module:
+        class _ResNetForImageClassification(torch.nn.Module):
+            def __init__(self, model: torch.nn.Module, output_hidden_states: bool):
+                super().__init__()
+                self.model = model
+                self.output_hidden_states = output_hidden_states
+
+            def forward(self, *args, **kwargs):
+                output = self.model(*args, output_hidden_states=self.output_hidden_states, **kwargs)
+                return output
+
+        return _ResNetForImageClassification(model, rbln_config.output_hidden_states)
+
+    def forward(
+        self, pixel_values: torch.Tensor, output_hidden_states: bool = None, return_dict: bool = None, **kwargs
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
+        """
+        Foward pass for the RBLN-optimized ResNet model for image classification.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
+            output_hidden_states (bool, *optional*, defaults to False): Whether or not to return the hidden states of all layers.
+                See hidden_states under returned tensors for more details.
+            return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a ImageClassifierOutputWithNoAttention object.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        return super().forward(pixel_values=pixel_values, return_dict=return_dict, **kwargs)
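RBLNResNetForImageClassification bakes output_hidden_states into the compiled graph, so the runtime forward only accepts the value it was compiled with and raises a ValueError otherwise. A hedged end-to-end sketch; the checkpoint name and the rbln_config dict form of from_pretrained are assumptions based on the usual optimum-rbln compile flow, not shown in this diff:

    # Sketch: compile once with hidden states enabled, then query them at inference.
    import torch
    from optimum.rbln import RBLNResNetForImageClassification

    model = RBLNResNetForImageClassification.from_pretrained(
        "microsoft/resnet-50",                       # illustrative checkpoint
        export=True,                                 # compile for the RBLN NPU
        rbln_config={"output_hidden_states": True},
    )

    pixel_values = torch.randn(1, 3, 224, 224)
    outputs = model(pixel_values)                             # uses the compiled setting
    outputs = model(pixel_values, output_hidden_states=True)  # OK: matches rbln_config
    # model(pixel_values, output_hidden_states=False) would raise ValueError,
    # because the flag cannot change after compilation.
    print(len(outputs.hidden_states))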
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
 from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
 
 
@@ -26,6 +31,19 @@ class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Union[Tuple, MaskedLMOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
+
 
 class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
@@ -37,3 +55,18 @@ class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for sequence classification tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
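The RoBERTa classes gain explicit forward signatures purely for documentation; the call pattern itself is unchanged. A hedged sketch of the masked-LM path (the checkpoint name and the compiled sequence length of 512 are illustrative):

    # Sketch: masked-LM inference with the documented (input_ids, attention_mask) signature.
    from transformers import AutoTokenizer
    from optimum.rbln import RBLNRobertaForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = RBLNRobertaForMaskedLM.from_pretrained("roberta-base", export=True)

    # Encoder models are compiled for a fixed sequence length, so pad to it (assumed 512 here).
    inputs = tokenizer(
        "The capital of France is <mask>.",
        return_tensors="pt",
        padding="max_length",
        max_length=512,
    )
    outputs = model(inputs.input_ids, inputs.attention_mask)
    print(outputs.logits.shape)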
@@ -15,6 +15,7 @@
 from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.deprecation import deprecate_kwarg
 from ....utils.logging import get_logger
 
 
@@ -24,13 +25,13 @@ logger = get_logger()
 class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
     support_paged_attention = None
 
+    @deprecate_kwarg(old_name="pad_token_id", version="0.10.0")
     def __init__(
         self,
         batch_size: Optional[int] = None,
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         use_attention_mask: Optional[bool] = None,
-        pad_token_id: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         **kwargs: Any,
@@ -41,7 +42,6 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
             use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
-            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
             kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
                 PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
             kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
@@ -61,8 +61,6 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
 
         self.use_attention_mask = use_attention_mask
 
-        self.pad_token_id = pad_token_id
-
         if self.support_paged_attention:
             self.kvcache_num_blocks = kvcache_num_blocks
             self.kvcache_block_size = kvcache_block_size
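pad_token_id is removed from RBLNModelForSeq2SeqLMConfig; encoder-input padding now comes from the Hugging Face model config (self.config.pad_token_id, see the modeling_seq2seq.py hunks further down), and the new deprecate_kwarg decorator from the added optimum/rbln/utils/deprecation.py keeps the old keyword accepted until 0.10.0. A hedged sketch of what callers can expect; the exact warning behavior lives in the new deprecation module, which this excerpt does not show:

    # Sketch: the deprecated keyword is still accepted on a seq2seq config subclass,
    # but is expected to emit a deprecation warning and to be rejected from 0.10.0.
    from optimum.rbln import RBLNT5ForConditionalGenerationConfig

    cfg = RBLNT5ForConditionalGenerationConfig(
        batch_size=1,
        enc_max_seq_len=512,
        dec_max_seq_len=256,
        pad_token_id=0,  # deprecated: padding is now read from the HF model config instead
    )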
@@ -20,8 +20,9 @@ import rebel
 import torch
 from rebel.compile_context import CompileContext
 from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerationMixin
-from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
@@ -33,7 +34,7 @@ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, GenerationConfig, PretrainedConfig
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -140,7 +141,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
@@ -209,8 +210,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
         if not cls.support_causal_attn:
             rbln_config.use_attention_mask = True
 
-        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
-        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        n_layer = getattr(model_config, "decoder_layers", None) or model_config.num_layers
+        n_head = getattr(model_config, "decoder_attention_heads", None) or model_config.num_heads
         d_kv = (
             model_config.d_kv
             if hasattr(model_config, "d_kv")
@@ -221,12 +222,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
             model_config, "max_position_embeddings", None
         )
 
-        pad_token_id = getattr(model_config, "pad_token_id", None)
-        pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
-        pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
-        pad_token_id = pad_token_id or -1
-        rbln_config.pad_token_id = pad_token_id
-
         if rbln_config.enc_max_seq_len is None:
             enc_max_seq_len = max_position_embeddings
             for tokenizer in preprocessors:
@@ -432,7 +427,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
             inputs_tensor = torch.nn.functional.pad(
                 inputs_tensor,
                 (0, self.rbln_config.enc_max_seq_len - input_len),
-                value=self.rbln_config.pad_token_id,
+                value=self.config.pad_token_id,
             )
             model_kwargs["attention_mask"] = torch.nn.functional.pad(
                 model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
@@ -451,3 +446,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
         model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
 
         return model_kwargs
+
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        **kwargs,
+    ) -> Union[ModelOutput, torch.LongTensor]:
+        """
+        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
+
+        Args:
+            input_ids (torch.LongTensor): The input ids to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+        Returns:
+            Generates sequences of token ids for models with a language modeling head.
+        """
+        if generation_config is not None:
+            kwargs["generation_config"] = generation_config
+        if attention_mask is not None:
+            kwargs["attention_mask"] = attention_mask
+
+        return super().generate(input_ids, **kwargs)
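The new generate() override only fixes the documented signature and forwards everything to GenerationMixin.generate. A hedged usage sketch with T5; the checkpoint name and the rbln_config keys are illustrative:

    # Sketch: standard HF-style generation through the new explicit signature.
    from transformers import AutoTokenizer
    from optimum.rbln import RBLNT5ForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
    model = RBLNT5ForConditionalGeneration.from_pretrained(
        "google-t5/t5-small",
        export=True,
        rbln_config={"batch_size": 1, "enc_max_seq_len": 512, "dec_max_seq_len": 128},
    )

    inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
    ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=32)
    print(tokenizer.decode(ids[0], skip_special_tokens=True))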
@@ -66,7 +66,9 @@ class RBLNSiglipVisionModel(RBLNModel):
     _tp_support = False
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
+    ) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
@@ -122,6 +124,20 @@ class RBLNSiglipVisionModel(RBLNModel):
         interpolate_pos_encoding: bool = False,
         **kwargs: Any,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        """
+        Forward pass for the RBLN-optimized SigLIP vision model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.call() for details (processor_class uses ViTImageProcessor for processing images).
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            interpolate_pos_encoding (bool, defaults to False): Whether to interpolate the pre-trained position encodings.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
+
         output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
@@ -203,7 +203,7 @@ class _SwinBackbone(torch.nn.Module):
 
 class RBLNSwinBackbone(RBLNModel):
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
         for layer in model.encoder.layers:
             for block in layer.blocks:
                 block.get_attn_mask = types.MethodType(get_attn_mask, block)
@@ -278,6 +278,19 @@ class RBLNSwinBackbone(RBLNModel):
         output_hidden_states: bool = None,
         **kwargs,
     ) -> Union[Tuple, BackboneOutput]:
+        """
+        Forward pass for the RBLN-optimized Swin backbone model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.call() for details (processor_class uses ViTImageProcessor for processing images).
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BackboneOutput object.
+        """
+
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
             logger.warning(
                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
@@ -314,19 +327,19 @@ class RBLNSwinBackbone(RBLNModel):
         output = self.model[0](padded_pixel_values)
 
         feature_maps = ()
-        for i in range(len(self.config.out_features)):
+        for _ in range(len(self.config.out_features)):
            feature_maps += (output.pop(0),)
 
         if self.rbln_config.output_hidden_states:
            hidden_states = ()
-            for i in range(len(self.config.stage_names)):
+            for _ in range(len(self.config.stage_names)):
                hidden_states += (output.pop(0),)
         else:
            hidden_states = None
 
         if self.rbln_config.output_attentions:
            attentions = ()
-            for i in range(len(self.config.depths)):
+            for _ in range(len(self.config.depths)):
                attentions += (output.pop(0),)
         else:
            attentions = None
@@ -68,7 +68,7 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
     output_class = BaseModelOutputWithPastAndCrossAttentions
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
         return T5EncoderWrapper(model)
 
     @classmethod
@@ -113,7 +113,7 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = False
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
         return T5Wrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
         )
@@ -39,7 +39,7 @@ class T5Wrapper:
 
 class T5EncoderWrapper(Seq2SeqEncoderWrapper):
     def __post_init__(self, model: nn.Module):
-        self.n_layer = getattr(self.config, "num_layers")
+        self.n_layer = self.config.num_layers
         self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().block)
         self.num_heads = self.config.num_heads
         self.d_kv = self.config.d_kv
@@ -153,7 +153,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         return redirect(val)
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
     ):
         return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
@@ -161,7 +161,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
@@ -184,14 +184,6 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
             if "key_value_states" in name:
                 context.mark_static_address(tensor)
 
-        compiled_decoder = cls.compile(
-            wrapped_model.decoder,
-            dec_compile_config,
-            create_runtimes=rbln_config.create_runtimes,
-            device=rbln_config.device,
-            example_inputs=dec_example_inputs,
-            compile_context=context,
-        )
         compiled_encoder = cls.compile(
             wrapped_model.encoder,
             enc_compile_config,
@@ -201,6 +193,15 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
             compile_context=context,
         )
 
+        compiled_decoder = cls.compile(
+            wrapped_model.decoder,
+            dec_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
+            example_inputs=dec_example_inputs,
+            compile_context=context,
+        )
+
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
@@ -353,6 +354,20 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         static_real_features: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> SampleTSPredictionOutput:
+        """
+        Generate pass for the RBLN-optimized Time Series Transformer model for time series forecasting.
+
+        Args:
+            past_values (torch.FloatTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size)): Past values of the time series, that serve as context in order to predict the future.
+            past_time_features (torch.FloatTensor of shape (batch_size, sequence_length, num_features)): Required time features, which the model internally will add to past_values.
+            future_time_features (torch.FloatTensor of shape (batch_size, prediction_length, num_features)): Required time features for the prediction window, which the model internally will add to future_values.
+            past_observed_mask (torch.BoolTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size), optional): Boolean mask to indicate which past_values were observed and which were missing.
+            static_categorical_features (torch.LongTensor of shape (batch_size, number of static categorical features), optional): Optional static categorical features for which the model will learn an embedding, which it will add to the values of the time series.
+            static_real_features (torch.FloatTensor of shape (batch_size, number of static real features), optional): Optional static real features which the model will add to the values of the time series.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SampleTSPredictionOutput object.
+        """
         self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})
 
         outputs = self.encoder(
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import ImageClassifierOutput
+
 from ...modeling_generic import RBLNModelForImageClassification
 
 
@@ -23,3 +28,17 @@ class RBLNViTForImageClassification(RBLNModelForImageClassification):
     on RBLN devices, supporting image classification with transformer-based architectures
     that process images as sequences of patches.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[ImageClassifierOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized Vision Transformer model for image classification.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)):
+                The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutput object.
+
+        """
+        return super().forward(pixel_values, **kwargs)
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ...configuration_generic import RBLNModelForMaskedLMConfig
+from typing import Any, Optional
 
+from ....configuration_utils import RBLNModelConfig
 
-class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
+
+class RBLNWav2Vec2ForCTCConfig(RBLNModelConfig):
     """
     Configuration class for RBLNWav2Vec2ForCTC.
 
@@ -23,4 +25,14 @@ class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
     RBLN-optimized Wav2Vec2 models for Connectionist Temporal Classification (CTC) tasks.
     """
 
-    rbln_model_input_names = ["input_values"]
+    def __init__(
+        self,
+        max_seq_len: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.max_seq_len = max_seq_len
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
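RBLNWav2Vec2ForCTCConfig is rebased from RBLNModelForMaskedLMConfig onto RBLNModelConfig and now declares max_seq_len and batch_size itself instead of reusing the masked-LM input-name plumbing. A hedged construction sketch; treating max_seq_len as the padded length of input_values is an assumption, since this excerpt does not spell it out:

    # Sketch: the reworked Wav2Vec2 CTC config (0.9.4a2 signature).
    from optimum.rbln import RBLNWav2Vec2ForCTCConfig

    cfg = RBLNWav2Vec2ForCTCConfig(
        max_seq_len=160000,  # assumed: number of raw audio samples the model is compiled for
        batch_size=1,        # defaults to 1 when omitted
    )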