optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
@@ -12,28 +12,24 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
15
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
16
16
 
17
17
  import torch
18
- from transformers import (
19
- CLIPTextConfig,
20
- CLIPTextModel,
21
- CLIPVisionConfig,
22
- CLIPVisionModel,
23
- )
24
- from transformers.modeling_outputs import BaseModelOutputWithPooling
18
+ from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
25
19
  from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
26
20
 
27
- from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
21
+ from ....configuration_utils import RBLNCompileConfig
28
22
  from ....modeling import RBLNModel
29
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
30
23
  from ....utils.logging import get_logger
24
+ from .configuration_clip import RBLNCLIPTextModelConfig, RBLNCLIPVisionModelConfig
31
25
 
32
26
 
33
27
  logger = get_logger(__name__)
34
28
 
35
29
  if TYPE_CHECKING:
36
- from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel
30
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel, PreTrainedModel
31
+
32
+ from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
37
33
 
38
34
 
39
35
  class _TextEncoder(torch.nn.Module):
@@ -48,53 +44,55 @@ class _TextEncoder(torch.nn.Module):
48
44
 
49
45
  class RBLNCLIPTextModel(RBLNModel):
50
46
  @classmethod
51
- def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
47
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
52
48
  return _TextEncoder(model).eval()
53
49
 
54
50
  @classmethod
55
- def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
51
+ def update_rbln_config_using_pipe(
52
+ cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_config: str
53
+ ) -> "RBLNDiffusionMixinConfig":
56
54
  return rbln_config
57
55
 
58
56
  @classmethod
59
- def _get_rbln_config(
57
+ def _update_rbln_config(
60
58
  cls,
61
59
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
62
- model_config: "CLIPTextConfig",
63
- rbln_kwargs: Dict[str, Any] = {},
64
- rbln_batch_size: Optional[int] = None,
65
- ) -> RBLNConfig:
66
- rbln_batch_size = rbln_kwargs.get("batch_size", None)
67
- if rbln_batch_size is None:
68
- rbln_batch_size = 1
69
-
70
- model_config.return_dict = False
71
-
60
+ model: Optional["PreTrainedModel"] = None,
61
+ model_config: "CLIPTextConfig" = None,
62
+ rbln_config: Optional[RBLNCLIPTextModelConfig] = None,
63
+ ) -> RBLNCLIPTextModelConfig:
72
64
  input_info = [
73
65
  (
74
66
  "input_ids",
75
67
  [
76
- rbln_batch_size,
68
+ rbln_config.batch_size,
77
69
  model_config.max_position_embeddings,
78
70
  ],
79
71
  "int64",
80
72
  ),
81
73
  ]
82
74
 
83
- rbln_compile_config = RBLNCompileConfig(input_info=input_info)
84
- rbln_config = RBLNConfig(
85
- rbln_cls=cls.__name__,
86
- compile_cfgs=[rbln_compile_config],
87
- rbln_kwargs=rbln_kwargs,
88
- )
75
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
89
76
  return rbln_config
90
77
 
91
- def forward(self, input_ids: "torch.Tensor", **kwargs):
92
- text_output = super().forward(input_ids)
93
- return CLIPTextModelOutput(
94
- text_embeds=text_output[0],
95
- last_hidden_state=text_output[1],
96
- hidden_states=text_output[2:],
97
- )
78
+ def forward(self, input_ids: torch.LongTensor, return_dict: bool = None, **kwargs) -> torch.FloatTensor:
79
+ # To ignore using attention_mask, we override forward method.
80
+ output = super().forward(input_ids, return_dict=return_dict)
81
+ return output
82
+
83
+ def _prepare_output(self, output, return_dict):
84
+ """
85
+ Prepare model output based on return_dict flag.
86
+ This method can be overridden by subclasses to provide task-specific output handling.
87
+ """
88
+ if not return_dict:
89
+ return (output,) if not isinstance(output, (tuple, list)) else output
90
+ else:
91
+ return CLIPTextModelOutput(
92
+ text_embeds=output[0],
93
+ last_hidden_state=output[1],
94
+ hidden_states=output[2:],
95
+ )
98
96
 
99
97
 
100
98
  class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
@@ -113,30 +111,30 @@ class _VisionEncoder(torch.nn.Module):
113
111
 
114
112
  class RBLNCLIPVisionModel(RBLNModel):
115
113
  @classmethod
116
- def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
114
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
117
115
  return _VisionEncoder(model).eval()
118
116
 
119
117
  @classmethod
120
- def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
118
+ def update_rbln_config_using_pipe(
119
+ cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
120
+ ) -> "RBLNDiffusionMixinConfig":
121
121
  return rbln_config
122
122
 
123
123
  @classmethod
124
- def _get_rbln_config(
124
+ def _update_rbln_config(
125
125
  cls,
126
126
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
127
- model_config: "CLIPVisionConfig",
128
- rbln_kwargs: Dict[str, Any] = {},
129
- ) -> RBLNConfig:
130
- rbln_batch_size = rbln_kwargs.get("batch_size", 1)
131
- rbln_image_size = rbln_kwargs.get("image_size", None)
132
-
133
- if rbln_image_size is None:
134
- rbln_image_size = getattr(model_config, "image_size", None)
127
+ model: Optional["PreTrainedModel"] = None,
128
+ model_config: "CLIPVisionConfig" = None,
129
+ rbln_config: Optional[RBLNCLIPVisionModelConfig] = None,
130
+ ) -> RBLNCLIPVisionModelConfig:
131
+ if rbln_config.image_size is None:
132
+ rbln_config.image_size = getattr(model_config, "image_size", None)
135
133
 
136
- if isinstance(rbln_image_size, int):
137
- rbln_image_size = (rbln_image_size, rbln_image_size)
134
+ if isinstance(rbln_config.image_size, int):
135
+ rbln_config.image_size = (rbln_config.image_size, rbln_config.image_size)
138
136
 
139
- if rbln_image_size is None:
137
+ if rbln_config.image_size is None:
140
138
  raise ValueError("`rbln_image_size` should be specified!")
141
139
 
142
140
  rbln_compile_config = RBLNCompileConfig(
@@ -144,45 +142,44 @@ class RBLNCLIPVisionModel(RBLNModel):
144
142
  (
145
143
  "pixel_values",
146
144
  [
147
- rbln_batch_size,
145
+ rbln_config.batch_size,
148
146
  3,
149
- rbln_image_size[0],
150
- rbln_image_size[1],
147
+ rbln_config.image_height,
148
+ rbln_config.image_width,
151
149
  ],
152
150
  "float32",
153
151
  )
154
152
  ]
155
153
  )
156
154
 
157
- rbln_config = RBLNConfig(
158
- rbln_cls=cls.__name__,
159
- compile_cfgs=[rbln_compile_config],
160
- rbln_kwargs=rbln_kwargs,
161
- )
162
-
163
- rbln_config.model_cfg.update(
164
- {
165
- "batch_size": rbln_batch_size,
166
- "image_size": rbln_image_size,
167
- }
168
- )
169
-
155
+ rbln_config.set_compile_cfgs([rbln_compile_config])
170
156
  return rbln_config
171
157
 
172
158
  def forward(
173
159
  self,
174
160
  pixel_values: Optional[torch.FloatTensor] = None,
161
+ return_dict: bool = None,
175
162
  **kwargs,
176
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
163
+ ) -> Union[Tuple, CLIPVisionModelOutput]:
177
164
  if len(kwargs) > 0 and any(kwargs.values()):
178
165
  logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
179
166
 
180
- output = super().forward(pixel_values)
181
- return BaseModelOutputWithPooling(
182
- last_hidden_state=output[0],
183
- pooler_output=output[1],
184
- hidden_states=output[2:],
185
- )
167
+ output = super().forward(pixel_values, return_dict=return_dict)
168
+ return output
169
+
170
+ def _prepare_output(self, output, return_dict):
171
+ """
172
+ Prepare model output based on return_dict flag.
173
+ This method can be overridden by subclasses to provide task-specific output handling.
174
+ """
175
+ if not return_dict:
176
+ return (output,) if not isinstance(output, (tuple, list)) else output
177
+ else:
178
+ return CLIPVisionModelOutput(
179
+ image_embeds=output[0],
180
+ last_hidden_state=output[1],
181
+ hidden_states=output[2:],
182
+ )
186
183
 
187
184
 
188
185
  class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
@@ -22,4 +22,5 @@ from ....ops import (
22
22
  paged_flash_causal_attn_decode,
23
23
  paged_flash_causal_attn_prefill,
24
24
  )
25
+ from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
25
26
  from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
@@ -0,0 +1,90 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional
16
+
17
+ import rebel
18
+
19
+ from ....configuration_utils import RBLNModelConfig
20
+ from ....utils.logging import get_logger
21
+ from ...utils.rbln_quantization import QuantizationManager
22
+
23
+
24
+ logger = get_logger()
25
+
26
+
27
+ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
28
+ def __init__(
29
+ self,
30
+ batch_size: Optional[int] = None,
31
+ max_seq_len: Optional[int] = None,
32
+ use_inputs_embeds: Optional[bool] = None,
33
+ use_attention_mask: Optional[bool] = None,
34
+ attn_impl: Optional[str] = None,
35
+ kvcache_partition_len: Optional[int] = None,
36
+ kvcache_block_size: Optional[int] = None,
37
+ quantization: Optional[Dict[str, Any]] = None,
38
+ prefill_chunk_size: Optional[int] = None,
39
+ kvcache_num_blocks: Optional[int] = None,
40
+ **kwargs,
41
+ ):
42
+ """
43
+ Args:
44
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
45
+ max_seq_len (Optional[int]): The maximum sequence length supported by the model.
46
+ use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
47
+ use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
48
+ for RBLN-CA02 devices.
49
+ attn_impl (Optional[str]): The attention implementation to use.
50
+ kvcache_partition_len (Optional[int]): The length of each KV cache partition.
51
+ kvcache_block_size (Optional[int]): The block size for KV cache.
52
+ quantization (Optional[Dict[str, Any]]): Configuration for model quantization.
53
+ prefill_chunk_size (Optional[int]): The chunk size for prefilling the KV cache. Defaults to 128,
54
+ and must be a positive integer divisible by 64.
55
+ kvcache_num_blocks (Optional[int]): The number of blocks in the KV cache.
56
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
57
+
58
+ Raises:
59
+ ValueError: If batch_size is not a positive integer or if prefill_chunk_size is not
60
+ a positive integer divisible by 64.
61
+ """
62
+ super().__init__(**kwargs)
63
+ self.batch_size = batch_size or 1
64
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
65
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
66
+
67
+ self.max_seq_len = max_seq_len
68
+ self.use_inputs_embeds = use_inputs_embeds or False
69
+
70
+ self.use_attention_mask = use_attention_mask
71
+ npu = self.npu or rebel.get_npu_name()
72
+ if npu == "RBLN-CA02":
73
+ if self.use_attention_mask is False:
74
+ logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
75
+ self.use_attention_mask = True
76
+ else:
77
+ self.use_attention_mask = self.use_attention_mask or False
78
+
79
+ self.attn_impl = attn_impl
80
+ self.kvcache_partition_len = kvcache_partition_len
81
+ self.kvcache_block_size = kvcache_block_size
82
+ self.quantization = quantization or {}
83
+ if self.quantization:
84
+ QuantizationManager.validate_quantization_config(self.quantization)
85
+
86
+ self.prefill_chunk_size = prefill_chunk_size or 128
87
+ if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
88
+ raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
89
+
90
+ self.kvcache_num_blocks = kvcache_num_blocks
@@ -32,30 +32,39 @@ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
32
32
  MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
33
33
 
34
34
 
35
- def validate_attention_method(
36
- rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_kvcache_block_size: int, rbln_max_seq_len: int
37
- ) -> Tuple[str, int]:
38
- if rbln_kvcache_partition_len is not None:
39
- if rbln_attn_impl == "eager":
40
- raise ValueError(
41
- f"`rbln_kvcache_partition_len` is set to {rbln_kvcache_partition_len}, but KV cache partitioning"
42
- " is not supported with 'eager' attention. Please set `rbln_kvcache_partition_len` to None, "
43
- "or switch `rbln_attn_impl` to 'flash_attn' to use KV cache partitioning."
44
- )
45
- elif rbln_attn_impl is None:
46
- rbln_attn_impl = "flash_attn"
35
+ def set_default_values(
36
+ attn_impl: Optional[str] = None,
37
+ kvcache_partition_len: Optional[int] = None,
38
+ kvcache_block_size: Optional[int] = None,
39
+ max_seq_len: Optional[int] = None,
40
+ ) -> Tuple[str, int, int]:
41
+ if attn_impl is None:
42
+ attn_impl = "eager"
43
+
44
+ if kvcache_partition_len is not None:
45
+ if attn_impl == "eager":
46
+ attn_impl = "flash_attn"
47
47
  logger.warning(
48
- "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
49
- "Since KV cache partitioning is only supported with flash attention, "
50
- "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
48
+ "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
49
+ "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
50
+ "`attn_impl` has been automatically switched to 'flash_attn'."
51
51
  )
52
52
 
53
- rbln_attn_impl = "eager" if rbln_attn_impl is None else rbln_attn_impl
54
- if rbln_attn_impl not in ["eager", "flash_attn"]:
55
- raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
53
+ if kvcache_partition_len is None and attn_impl == "flash_attn":
54
+ kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
55
+
56
+ if kvcache_block_size is None:
57
+ if attn_impl == "eager":
58
+ kvcache_block_size = max_seq_len
59
+ else:
60
+ kvcache_block_size = kvcache_partition_len
56
61
 
57
- if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
58
- rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
62
+ return attn_impl, kvcache_partition_len, kvcache_block_size
63
+
64
+
65
+ def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
66
+ if attn_impl not in ["eager", "flash_attn"]:
67
+ raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
59
68
 
60
69
  ## Checking Constraints...
61
70
  # Constraint of eager attention:
@@ -65,47 +74,45 @@ def validate_attention_method(
65
74
  # 1. `max_seq_len` should be multiple of `partition_len`.
66
75
  # 2. 4k <= `partition_len` <= 32k.
67
76
  # 3. `max_seq_len` should be larger then 8k.
68
- if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
77
+ if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
69
78
  raise ValueError(
70
- f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
79
+ f"`max_seq_len` is set to {max_seq_len}, "
71
80
  f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
72
- f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
73
- " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
81
+ f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
82
+ " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
74
83
  )
75
84
 
76
- if rbln_attn_impl == "flash_attn":
77
- if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
85
+ if attn_impl == "flash_attn":
86
+ if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
78
87
  raise ValueError(
79
- f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
88
+ f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
80
89
  f"when using 'flash_attn'. Please adjust either value to meet this requirement."
81
90
  )
82
- elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
91
+ elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
83
92
  raise ValueError(
84
- f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
85
- f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
93
+ f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
94
+ f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
86
95
  f"Please provide a valid value within this range."
87
96
  )
88
- elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
97
+ elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
89
98
  raise ValueError(
90
- f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
91
- f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
92
- "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
99
+ f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
100
+ f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
101
+ "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
93
102
  )
94
103
 
95
- if rbln_kvcache_block_size is not None:
96
- if rbln_attn_impl == "flash_attn" and rbln_kvcache_partition_len != rbln_kvcache_block_size:
104
+ if kvcache_block_size is not None:
105
+ if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
97
106
  raise ValueError(
98
- f" When using 'flash attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
99
- f"must always be set equal to the `rbln_kvcache_partition_len` {rbln_kvcache_partition_len}."
107
+ f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
108
+ f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
100
109
  )
101
- elif rbln_attn_impl == "eager" and rbln_kvcache_block_size != rbln_max_seq_len:
110
+ elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
102
111
  raise ValueError(
103
- f" When using 'eager attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
104
- f"must always be set equal to the `rbln_max_seq_len` {rbln_max_seq_len}."
112
+ f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
113
+ f"must always be set equal to the `max_seq_len` {max_seq_len}."
105
114
  )
106
115
 
107
- return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size
108
-
109
116
 
110
117
  class DecoderOnlyWrapper(nn.Module):
111
118
  """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.