optimum-rbln 0.8.2a7-py3-none-any.whl → 0.8.3a0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (90)
  1. optimum/rbln/__init__.py +8 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +6 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +10 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +58 -257
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  81. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  82. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  83. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  84. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  85. optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
  86. optimum/rbln/utils/runtime_utils.py +3 -3
  87. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
  88. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
  89. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
  90. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py

@@ -18,29 +18,19 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from transformers import (
-    AutoModelForVision2Seq,
-    LlavaNextForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForVision2Seq, LlavaNextForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyForCausalLMOutput
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
 
 
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class LoopVisionTower:
@@ -258,7 +248,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
     def _update_model_kwargs_for_generation(
         self,
-        outputs: RBLNDecoderOnlyForCausalLMOutput,
+        outputs: RBLNDecoderOnlyOutput,
         model_kwargs: Dict[str, Any],
         **kwargs,
     ) -> Dict[str, Any]:
@@ -359,7 +349,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
        batch_idx: Optional[int] = None,
         **kwargs,
-    ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
+    ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         vision_feature_layer = (
             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
         )
@@ -418,7 +408,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
             cache_position=cache_position,
         )
         logits = output.logits
-        return RBLNDecoderOnlyForCausalLMOutput(logits=logits, generate_idx=generate_idx)
+        return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
 
     # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
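
These four hunks are a pure rename: `RBLNDecoderOnlyForCausalLMOutput` becomes `RBLNDecoderOnlyOutput`, which per the qwen2_5_vl hunks below is also available from the new `optimum/rbln/transformers/modeling_outputs.py` added in this release. A minimal migration sketch for downstream code that imported the old name; the try/except fallback is an illustrative pattern, not part of the package:

```python
# Hedged migration sketch: prefer the renamed class, fall back to aliasing the
# old name on releases (<= 0.8.2a7) that only ship the old spelling.
try:
    from optimum.rbln.transformers.modeling_outputs import RBLNDecoderOnlyOutput
except ImportError:
    from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import (
        RBLNDecoderOnlyForCausalLMOutput as RBLNDecoderOnlyOutput,
    )
```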

optimum/rbln/transformers/models/opt/modeling_opt.py

@@ -70,24 +70,10 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     @classmethod
     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapper_cfg = {
-            "max_seq_len": rbln_config.max_seq_len,
-            "attn_impl": rbln_config.attn_impl,
-            "kvcache_partition_len": rbln_config.kvcache_partition_len,
-            "kvcache_block_size": rbln_config.kvcache_block_size,
-            "use_rotary_emb": cls._use_rotary_emb,
-            "use_attention_mask": rbln_config.use_attention_mask,
-            "use_position_ids": rbln_config.use_position_ids,
-            "use_inputs_embeds": rbln_config.use_inputs_embeds,
-            "cache_impl": rbln_config.cache_impl,
-            "sliding_window": rbln_config.sliding_window,
-            "sliding_window_layers": rbln_config.sliding_window_layers,
-        }
-
         for i in range(len(model.model.decoder.layers)):
             model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])
 
-        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+        return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
 
 
 class RBLNOPTModel(RBLNDecoderOnlyModel):
@@ -110,21 +96,7 @@ class RBLNOPTModel(RBLNDecoderOnlyModel):
 
     @classmethod
     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapper_cfg = {
-            "max_seq_len": rbln_config.max_seq_len,
-            "attn_impl": rbln_config.attn_impl,
-            "kvcache_partition_len": rbln_config.kvcache_partition_len,
-            "kvcache_block_size": rbln_config.kvcache_block_size,
-            "use_rotary_emb": cls._use_rotary_emb,
-            "use_attention_mask": rbln_config.use_attention_mask,
-            "use_position_ids": rbln_config.use_position_ids,
-            "use_inputs_embeds": rbln_config.use_inputs_embeds,
-            "cache_impl": rbln_config.cache_impl,
-            "sliding_window": rbln_config.sliding_window,
-            "sliding_window_layers": rbln_config.sliding_window_layers,
-        }
-
         for i in range(len(model.decoder.layers)):
             model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
 
-        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+        return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
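
Both OPT wrappers previously copied eleven `rbln_config` fields into a hand-built `wrapper_cfg` dict; the wrapper now receives the config object whole, plus the one class-level flag (`use_rotary_emb`) that is not part of the config. A minimal sketch of the receiving side, using a hypothetical stand-in class since this diff does not show the real `DecoderOnlyWrapper` signature:

```python
import torch.nn as nn


class SketchDecoderWrapper(nn.Module):
    # Hypothetical stand-in for DecoderOnlyWrapper after this refactor.
    def __init__(self, model, rbln_config, use_rotary_emb):
        super().__init__()
        self.model = model
        # Keep the whole config and read fields at the point of use
        # (e.g. self.rbln_config.max_seq_len, .attn_impl, .kvcache_block_size),
        # as the qwen2_5_vl architecture hunk below does with
        # self.rbln_config.use_inputs_embeds.
        self.rbln_config = rbln_config
        self.use_rotary_emb = use_rotary_emb
```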

optimum/rbln/transformers/models/pegasus/configuration_pegasus.py

@@ -24,6 +24,8 @@ class RBLNPegasusModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
     RBLN-optimized PEGASUS models for feature extraction tasks.
     """
 
+    rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+
 
 class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
     """
@@ -32,3 +34,5 @@ class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized PEGASUS models for conditional text generation tasks.
     """
+
+    support_paged_attention = True

optimum/rbln/transformers/models/pegasus/modeling_pegasus.py

@@ -39,6 +39,8 @@ class RBLNPegasusModel(RBLNTransformerEncoderForFeatureExtraction):
     on RBLN devices, optimized for feature extraction use cases.
     """
 
+    rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+
 
 class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
     """

optimum/rbln/transformers/models/pegasus/pegasus_architecture.py

@@ -16,9 +16,7 @@ from typing import Tuple
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask,
-)
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.utils import logging
 
 from ..seq2seq.seq2seq_architecture import (

optimum/rbln/transformers/models/pixtral/configuration_pixtral.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional, Tuple
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -23,7 +23,7 @@ class RBLNPixtralVisionModelConfig(RBLNModelConfig):
         max_image_size: Tuple = None,
         batch_size: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:

optimum/rbln/transformers/models/pixtral/modeling_pixtral.py

@@ -21,10 +21,7 @@ import torch.nn as nn
 from transformers import PixtralVisionConfig, PixtralVisionModel
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.modeling_utils import no_init_weights
-from transformers.models.pixtral.modeling_pixtral import (
-    PixtralRMSNorm,
-    PixtralRotaryEmbedding,
-)
+from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralRotaryEmbedding
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel

optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
 from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
@@ -33,7 +33,7 @@ class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         self,
         visual: Optional[RBLNModelConfig] = None,
         use_inputs_embeds: bool = True,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
         if not self.use_inputs_embeds:
@@ -53,7 +53,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
     mechanisms for processing images and videos.
     """
 
-    def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs: Dict[str, Any]):
+    def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs: Any):
         """
         Args:
             max_seq_lens (Optional[Union[int, List[int]]]): Maximum sequence lengths for Vision

optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

@@ -17,12 +17,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from transformers import (
-    AutoModelForVision2Seq,
-    PretrainedConfig,
-    PreTrainedModel,
-    Qwen2_5_VLForConditionalGeneration,
-)
+from transformers import AutoModelForVision2Seq, PretrainedConfig, PreTrainedModel, Qwen2_5_VLForConditionalGeneration
 from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionPatchEmbed,
@@ -34,7 +29,8 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyForCausalLMOutput, RBLNDecoderOnlyModelForCausalLM
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .configuration_qwen2_5_vl import (
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
@@ -45,12 +41,7 @@ from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_VL_LanguageModelWrapper
 
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
@@ -595,7 +586,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         generate_idx: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> RBLNDecoderOnlyForCausalLMOutput:
+    ) -> RBLNDecoderOnlyOutput:
         # Prefill
         if cache_position is None:
             inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
@@ -637,7 +628,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         if not return_dict:
             return logits, generate_idx
         else:
-            return RBLNDecoderOnlyForCausalLMOutput(
+            return RBLNDecoderOnlyOutput(
                 logits=logits,
                 generate_idx=generate_idx,
             )

optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py

@@ -4,10 +4,7 @@ from typing import Tuple
 import torch
 import torch.nn as nn
 
-from ..decoderonly.decoderonly_architecture import (
-    DecoderOnlyWrapper,
-    apply_rotary_pos_emb,
-)
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper, apply_rotary_pos_emb
 
 
 class Qwen2_5_VisionTransformerWrapper(nn.Module):
@@ -159,15 +156,15 @@ class Qwen2_5_VLVisionWindowAttention(nn.Module):
 class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
     def prepare_forward_args(self, *args):
         args = list(args)
-        input_ids = None if self.use_inputs_embeds else args.pop(0)
-        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
         cache_position = args.pop(0)
         global_block_tables = args.pop(0)
         local_block_tables = None
         position_embeds = args.pop(0)
         query_position = args.pop(0) if self.phase == "prefill" else None
         position_ids = None
-        attention_mask = args.pop(0) if self.use_attention_mask else None
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:

optimum/rbln/transformers/models/qwen3/modeling_qwen3.py

@@ -28,12 +28,60 @@ from .qwen3_architecture import Qwen3Wrapper
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        PretrainedConfig,
-    )
+    from transformers import PretrainedConfig
 
 
 class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Qwen3 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based Qwen3ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen3ForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen3ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen3ForCausalLMConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen3ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen3ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 40_960,
+        "tensor_parallel_size": 4,
+        "kvcache_partition_len": 8192,
+    }
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNQwen3ForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNQwen3ForCausalLMConfig
+    config = RBLNQwen3ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=40_960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=8192,
+    )
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
     _decoder_wrapper_cls = Qwen3Wrapper
 
     @classmethod
@@ -55,5 +103,31 @@ class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
 
 class RBLNQwen3Model(RBLNDecoderOnlyModel):
+    """
+    The bare Qwen3 Model outputting raw hidden-states without any specific head on top.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based Qwen3Model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen3Model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen3ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen3ModelConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen3ModelConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen3Model
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen3Model.from_pretrained(
+        "Qwen/Qwen3-Embedding-4B",
+        export=True,
+        rbln_batch_size=1,
+        rbln_max_seq_len=40_960,
+        rbln_tensor_parallel_size=4,
+        rbln_kvcache_partition_len=8192,
+    )
+    """
+
     _decoder_wrapper_cls = Qwen3Wrapper
     _use_rotary_emb = True

optimum/rbln/transformers/models/qwen3/qwen3_architecture.py

@@ -13,10 +13,7 @@
 # limitations under the License.
 
 
-from ..decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyWrapper,
-)
+from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyWrapper
 
 
 class Qwen3Wrapper(DecoderOnlyWrapper):

optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
@@ -22,6 +22,8 @@ logger = get_logger()
 
 
 class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+    support_paged_attention = None
+
     def __init__(
         self,
         batch_size: Optional[int] = None,
@@ -29,7 +31,9 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
         dec_max_seq_len: Optional[int] = None,
         use_attention_mask: Optional[bool] = None,
         pad_token_id: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        kvcache_num_blocks: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -38,6 +42,10 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
            use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
+            kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+                PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+            kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
@@ -54,3 +62,12 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
         self.use_attention_mask = use_attention_mask
 
         self.pad_token_id = pad_token_id
+
+        if self.support_paged_attention:
+            self.kvcache_num_blocks = kvcache_num_blocks
+            self.kvcache_block_size = kvcache_block_size
+        else:
+            if kvcache_num_blocks is not None or kvcache_block_size is not None:
+                raise ValueError(
+                    "You cannot set kvcache_num_blocks or kvcache_block_size as paged attention is not supported for the model."
+                )

optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

@@ -38,7 +38,7 @@ if TYPE_CHECKING:
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+    def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
         output = super().forward(*args, **kwargs)
         return BaseModelOutput(last_hidden_state=output)
 
@@ -181,6 +181,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
+    @classmethod
+    def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
+        rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+        rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+        if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+            raise NotImplementedError(
+                f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+            )
+
+        if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+            raise NotImplementedError(
+                f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+            )
+
     @classmethod
     def _update_rbln_config(
         cls,
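
`_update_paged_attention_config` first fills in defaults and then rejects anything but those defaults, since flash attention is not yet supported for seq2seq models: effectively one KV-cache block per batch row, each block spanning the full decoder sequence. Reduced to plain arithmetic with assumed values:

```python
# Assumed values for illustration.
batch_size, dec_max_seq_len = 2, 1024
kvcache_num_blocks, kvcache_block_size = None, None

# Defaulting step, mirroring the hunk above.
kvcache_num_blocks = kvcache_num_blocks or batch_size        # -> 2
kvcache_block_size = kvcache_block_size or dec_max_seq_len   # -> 1024

# Validation step: any other combination raises NotImplementedError.
assert kvcache_num_blocks == batch_size
assert kvcache_block_size == dec_max_seq_len
```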

@@ -238,6 +253,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
             raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
 
+        if rbln_config.support_paged_attention:
+            cls._update_paged_attention_config(model_config, rbln_config)
+
         # model input info
         enc_input_info = [
             ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
@@ -310,6 +328,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
         rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
+
         return rbln_config
 
     @classmethod

optimum/rbln/transformers/models/siglip/__init__.py

@@ -12,9 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_siglip import (
-    RBLNSiglipVisionModelConfig,
-)
-from .modeling_siglip import (
-    RBLNSiglipVisionModel,
-)
+from .configuration_siglip import RBLNSiglipVisionModelConfig
+from .modeling_siglip import RBLNSiglipVisionModel

optimum/rbln/transformers/models/siglip/modeling_siglip.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
 
 import torch
 from transformers import SiglipVisionConfig, SiglipVisionModel
@@ -126,7 +126,7 @@ class RBLNSiglipVisionModel(RBLNModel):
         output_attentions: bool = None,
         output_hidden_states: bool = None,
         interpolate_pos_encoding: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
             logger.warning(

optimum/rbln/transformers/models/t5/configuration_t5.py

@@ -32,3 +32,5 @@ class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized T5 models for conditional text generation tasks.
     """
+
+    support_paged_attention = False

optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         num_parallel_samples: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:

optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py

@@ -23,24 +23,20 @@
 
 import inspect
 import logging
-from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    PretrainedConfig,
-    TimeSeriesTransformerForPrediction,
-    TimeSeriesTransformerModel,
-)
-from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
+from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
+from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
 from transformers.modeling_utils import no_init_weights
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
 from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
 from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
 
@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         )
 
 
-@dataclass
-class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
-    last_hidden_states: torch.FloatTensor = None
-    params: Tuple[torch.FloatTensor] = None
-
-
 class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
     """
     The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
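
`RBLNSeq2SeqTSDecoderOutput` is moved rather than deleted: the import hunk above now pulls it from the new shared `optimum/rbln/transformers/modeling_outputs.py`. A sketch of the relocated definition, assuming its fields are unchanged apart from the move:

```python
from dataclasses import dataclass
from typing import Tuple

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
    # Same two fields as the in-module definition removed above.
    last_hidden_states: torch.FloatTensor = None
    params: Tuple[torch.FloatTensor] = None
```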

optimum/rbln/transformers/models/whisper/configuration_whisper.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+from typing import Any
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
@@ -36,7 +36,9 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
         use_attention_mask: bool = None,
         enc_max_seq_len: int = None,
         dec_max_seq_len: int = None,
-        **kwargs: Dict[str, Any],
+        kvcache_num_blocks: int = None,
+        kvcache_block_size: int = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -45,6 +47,10 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
            use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
            enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
            dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
+            kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
+                PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+            kvcache_block_size (int, optional): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
@@ -62,3 +68,5 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
 
         self.use_attention_mask = use_attention_mask
         self.use_attention_mask = self.use_attention_mask or False
+        self.kvcache_num_blocks = kvcache_num_blocks
+        self.kvcache_block_size = kvcache_block_size
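
Whisper gains the same two KV-cache knobs, stored unconditionally here since `RBLNWhisperForConditionalGenerationConfig` derives from `RBLNModelConfig` directly and has no `support_paged_attention` gate. A hypothetical construction; per the docstring the values default to `batch_size` and `dec_max_seq_len` when omitted:

```python
from optimum.rbln import RBLNWhisperForConditionalGenerationConfig

cfg = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,            # assumed to be accepted by this config, as in its siblings
    kvcache_num_blocks=1,    # one self-attention KV-cache block per batch row
    kvcache_block_size=448,  # tokens per block; 448 is Whisper's usual max_target_positions
)
```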