optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. See the registry's advisory page for this release for more details.

Files changed (91):
  1. optimum/rbln/__init__.py +4 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +2 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +6 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +19 -250
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  81. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  82. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  83. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  84. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  85. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  86. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  87. optimum/rbln/utils/runtime_utils.py +3 -3
  88. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/METADATA +1 -1
  89. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/RECORD +91 -87
  90. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/WHEEL +0 -0
  91. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/licenses/LICENSE +0 -0
@@ -12,9 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_siglip import (
16
- RBLNSiglipVisionModelConfig,
17
- )
18
- from .modeling_siglip import (
19
- RBLNSiglipVisionModel,
20
- )
15
+ from .configuration_siglip import RBLNSiglipVisionModelConfig
16
+ from .modeling_siglip import RBLNSiglipVisionModel
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
15
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
16
16
 
17
17
  import torch
18
18
  from transformers import SiglipVisionConfig, SiglipVisionModel
@@ -126,7 +126,7 @@ class RBLNSiglipVisionModel(RBLNModel):
126
126
  output_attentions: bool = None,
127
127
  output_hidden_states: bool = None,
128
128
  interpolate_pos_encoding: bool = False,
129
- **kwargs: Dict[str, Any],
129
+ **kwargs: Any,
130
130
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
131
131
  if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
132
132
  logger.warning(
@@ -32,3 +32,5 @@ class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
32
32
  This configuration class stores the configuration parameters specific to
33
33
  RBLN-optimized T5 models for conditional text generation tasks.
34
34
  """
35
+
36
+ support_paged_attention = False
@@ -126,7 +126,14 @@ class T5Decoder(Seq2SeqDecoder):
126
126
  b_size = attention_mask.shape[0]
127
127
  batch_decoder_position_bias = []
128
128
  for i in range(b_size):
129
- batch_position_bias = self._dec_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
129
+ if torch.compiler.is_exporting():
130
+ cache_pos = cache_position[i][0].item()
131
+ torch._check_is_size(cache_pos)
132
+ torch._check(cache_pos >= 0)
133
+ torch._check(cache_pos < self._dec_position_bias.shape[2])
134
+ else:
135
+ cache_pos = cache_position[i][0]
136
+ batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
130
137
  batch_decoder_position_bias.append(batch_position_bias)
131
138
  position_bias = torch.cat(batch_decoder_position_bias, dim=0)
132
139
 
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from ....configuration_utils import RBLNModelConfig
4
4
 
@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
17
17
  enc_max_seq_len: Optional[int] = None,
18
18
  dec_max_seq_len: Optional[int] = None,
19
19
  num_parallel_samples: Optional[int] = None,
20
- **kwargs: Dict[str, Any],
20
+ **kwargs: Any,
21
21
  ):
22
22
  """
23
23
  Args:
@@ -23,24 +23,20 @@
23
23
 
24
24
  import inspect
25
25
  import logging
26
- from dataclasses import dataclass
27
26
  from pathlib import Path
28
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
27
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
29
28
 
30
29
  import rebel
31
30
  import torch
32
31
  from rebel.compile_context import CompileContext
33
- from transformers import (
34
- PretrainedConfig,
35
- TimeSeriesTransformerForPrediction,
36
- TimeSeriesTransformerModel,
37
- )
38
- from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
32
+ from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
33
+ from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
39
34
  from transformers.modeling_utils import no_init_weights
40
35
 
41
36
  from ....configuration_utils import RBLNCompileConfig
42
37
  from ....modeling import RBLNModel
43
38
  from ....utils.runtime_utils import RBLNPytorchRuntime
39
+ from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
44
40
  from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
45
41
  from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
46
42
 
@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
113
109
  )
114
110
 
115
111
 
116
- @dataclass
117
- class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
118
- last_hidden_states: torch.FloatTensor = None
119
- params: Tuple[torch.FloatTensor] = None
120
-
121
-
122
112
  class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
123
113
  """
124
114
  The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict
15
+ from typing import Any
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
  from ....utils.logging import get_logger
@@ -36,7 +36,9 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
36
36
  use_attention_mask: bool = None,
37
37
  enc_max_seq_len: int = None,
38
38
  dec_max_seq_len: int = None,
39
- **kwargs: Dict[str, Any],
39
+ kvcache_num_blocks: int = None,
40
+ kvcache_block_size: int = None,
41
+ **kwargs: Any,
40
42
  ):
41
43
  """
42
44
  Args:
@@ -45,6 +47,10 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
45
47
  use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
46
48
  enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
47
49
  dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
50
+ kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
51
+ PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
52
+ kvcache_block_size (int, optional): Sets the size (in number of tokens) of each block
53
+ in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
48
54
  **kwargs: Additional arguments passed to the parent RBLNModelConfig.
49
55
 
50
56
  Raises:
@@ -62,3 +68,5 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
62
68
 
63
69
  self.use_attention_mask = use_attention_mask
64
70
  self.use_attention_mask = self.use_attention_mask or False
71
+ self.kvcache_num_blocks = kvcache_num_blocks
72
+ self.kvcache_block_size = kvcache_block_size
@@ -46,7 +46,7 @@ if TYPE_CHECKING:
46
46
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
47
47
  mandatory_members = ["main_input_name"]
48
48
 
49
- def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
49
+ def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
50
50
  output = super().forward(*args, **kwargs)
51
51
  return BaseModelOutput(last_hidden_state=output)
52
52
 
@@ -253,6 +253,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
253
253
 
254
254
  return {"encoder": compiled_encoder, "decoder": compiled_decoder}
255
255
 
256
+ @classmethod
257
+ def _update_paged_attention_config(
258
+ cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
259
+ ):
260
+ rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
261
+ rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
262
+
263
+ if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
264
+ raise NotImplementedError(
265
+ f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
266
+ )
267
+
268
+ if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
269
+ raise NotImplementedError(
270
+ f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
271
+ )
272
+
256
273
  @classmethod
257
274
  def _update_rbln_config(
258
275
  cls,
@@ -270,6 +287,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
270
287
  if rbln_config.dec_max_seq_len is None:
271
288
  rbln_config.dec_max_seq_len = model_config.max_length
272
289
 
290
+ cls._update_paged_attention_config(model_config, rbln_config)
291
+
273
292
  enc_input_info = [
274
293
  ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
275
294
  ("block_tables", [1], "int16"),
@@ -12,14 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_xlm_roberta import (
16
- RBLNXLMRobertaForSequenceClassificationConfig,
17
- RBLNXLMRobertaModelConfig,
18
- )
19
- from .modeling_xlm_roberta import (
20
- RBLNXLMRobertaForSequenceClassification,
21
- RBLNXLMRobertaModel,
22
- )
15
+ from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
16
+ from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel
23
17
 
24
18
 
25
19
  __all__ = [