optimum-rbln 0.8.1rc1__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (119)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +5 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  18. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  19. optimum/rbln/diffusers/models/__init__.py +3 -13
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  24. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  25. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  26. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -28
  27. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  30. optimum/rbln/modeling.py +4 -5
  31. optimum/rbln/modeling_base.py +18 -14
  32. optimum/rbln/ops/kv_cache_update.py +5 -0
  33. optimum/rbln/ops/linear.py +7 -0
  34. optimum/rbln/transformers/__init__.py +60 -0
  35. optimum/rbln/transformers/configuration_generic.py +4 -4
  36. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  37. optimum/rbln/transformers/modeling_generic.py +1 -4
  38. optimum/rbln/transformers/models/__init__.py +45 -30
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  44. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  45. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  46. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  47. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  48. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  49. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  51. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  52. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  53. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  54. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  55. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  56. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  57. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  58. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  59. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  60. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  61. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  63. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  64. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  65. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  66. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  67. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  68. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  69. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  70. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  71. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  72. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  73. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  74. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  75. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  76. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  77. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  78. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  79. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  80. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  81. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  82. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  83. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  84. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  85. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  86. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  87. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  88. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  89. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  90. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  91. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  92. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  93. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  94. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  95. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  96. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  97. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  98. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  99. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  100. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  101. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  102. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  103. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  104. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  105. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  106. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  108. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  110. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  111. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  112. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  113. optimum/rbln/utils/depreacate_utils.py +16 -0
  114. optimum/rbln/utils/hub.py +8 -47
  115. optimum/rbln/utils/runtime_utils.py +31 -5
  116. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  117. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +119 -102
  118. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  119. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
+from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
 
 from ....configuration_utils import RBLNCompileConfig
@@ -111,12 +112,27 @@ class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
 
 
 class _VisionEncoder(torch.nn.Module):
-    def __init__(self, enc: CLIPVisionModel):
+    def __init__(
+        self,
+        enc: CLIPVisionModel,
+        interpolate_pos_encoding: bool,
+        output_hidden_states: bool,
+        output_attentions: bool,
+    ):
         super().__init__()
         self.enc = enc
+        self.interpolate_pos_encoding = interpolate_pos_encoding
+        self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions
 
     def forward(self, inp):
-        enc_out = self.enc(inp, output_hidden_states=True, return_dict=False)
+        enc_out = self.enc(
+            inp,
+            output_hidden_states=self.output_hidden_states,
+            interpolate_pos_encoding=self.interpolate_pos_encoding,
+            output_attentions=self.output_attentions,
+            return_dict=False,
+        )
         return enc_out
 
 
@@ -130,7 +146,12 @@ class RBLNCLIPVisionModel(RBLNModel):
 
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
-        return _VisionEncoder(model).eval()
+        wrapper_cfg = {
+            "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
+            "output_hidden_states": rbln_config.output_hidden_states,
+            "output_attentions": rbln_config.output_attentions,
+        }
+        return _VisionEncoder(model, **wrapper_cfg).eval()
 
     @classmethod
     def update_rbln_config_using_pipe(
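The wrapper change above bakes `interpolate_pos_encoding`, `output_hidden_states`, and `output_attentions` into the module at wrap time instead of passing them per call: because the wrapped module is compiled into a static graph, any flag that changes the number of returned tensors has to be fixed before compilation. A minimal sketch of why, using only the public transformers API (the tiny config values are illustrative, not the library's defaults):

import torch
from transformers import CLIPVisionConfig, CLIPVisionModel

# Tiny randomly initialized vision tower so the sketch is self-contained.
cfg = CLIPVisionConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2,
    num_attention_heads=2, image_size=32, patch_size=16,
)
enc = CLIPVisionModel(cfg).eval()
pixel_values = torch.randn(1, 3, 32, 32)

def run(output_hidden_states: bool, output_attentions: bool):
    # Same call the new _VisionEncoder makes, with the flags frozen as attributes.
    return enc(
        pixel_values,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=False,
    )

print(len(run(False, False)))  # 2 entries: last_hidden_state, pooled_output
print(len(run(True, False)))   # 3 entries: + tuple of hidden states
print(len(run(True, True)))    # 4 entries: + tuple of attention maps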
@@ -155,6 +176,12 @@ class RBLNCLIPVisionModel(RBLNModel):
         if rbln_config.image_size is None:
             raise ValueError("`rbln_image_size` should be specified!")
 
+        if rbln_config.output_attentions is None:
+            rbln_config.output_attentions = getattr(model_config, "output_attentions", False)
+
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
         rbln_compile_config = RBLNCompileConfig(
             input_info=[
                 (
@@ -176,27 +203,76 @@ class RBLNCLIPVisionModel(RBLNModel):
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        return_dict: bool = None,
+        return_dict: bool = True,
+        output_attentions: bool = None,
+        output_hidden_states: bool = None,
+        interpolate_pos_encoding: bool = False,
         **kwargs,
-    ) -> Union[Tuple, CLIPVisionModelOutput]:
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
             logger.warning(
                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
             )
+
+        output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_attentions != self.rbln_config.output_attentions:
+            raise ValueError(
+                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
+                f"Please compile again with the correct argument."
+            )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
+            raise ValueError(
+                f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
+                f"Please compile again with the correct argument."
+            )
+
         output = super().forward(pixel_values, return_dict=return_dict)
         return output
 
     def _prepare_output(self, output, return_dict):
         # Prepare model output based on return_dict flag.
         # This method can be overridden by subclasses to provide task-specific output handling.
+        last_hidden_state = output.pop(0)
+        pooler_output = output.pop(0)
+        vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+        if self.rbln_config.output_hidden_states:
+            hidden_states = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers + 1):
+                hidden_states += (output.pop(0),)
+        else:
+            hidden_states = None
+
+        if self.rbln_config.output_attentions:
+            attentions = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers):
+                attentions += (output.pop(0),)
+        else:
+            attentions = None
 
         if not return_dict:
-            return (output,) if not isinstance(output, (tuple, list)) else output
+            return tuple(
+                item for item in (last_hidden_state, pooler_output, hidden_states, attentions) if item is not None
+            )
         else:
-            return CLIPVisionModelOutput(
-                image_embeds=output[0],
-                last_hidden_state=output[1],
-                hidden_states=output[2:],
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+                attentions=attentions,
             )
 
 
@@ -208,21 +284,40 @@ class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
     multimodal embedding alignment tasks.
     """
 
-    def forward(
-        self,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CLIPVisionModelOutput]:
-        if len(kwargs) > 0 and any(kwargs.values()):
-            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
-
-        output = super().forward(pixel_values)
-        image_embeds = output[0]
-        last_hidden_state = output[1]
-        hidden_states = output[2:]
-
-        return CLIPVisionModelOutput(
-            image_embeds=image_embeds,
-            last_hidden_state=last_hidden_state,
-            hidden_states=hidden_states,
-        )
+    def _prepare_output(self, output, return_dict):
+        # Prepare model output based on return_dict flag.
+        # This method can be overridden by subclasses to provide task-specific output handling.
+
+        image_embeds = output.pop(0) if isinstance(output, (tuple, list)) else output
+        last_hidden_state = output.pop(0)
+
+        vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+        if self.rbln_config.output_hidden_states:
+            hidden_states = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers + 1):
+                hidden_states += (output.pop(0),)
+        else:
+            hidden_states = None
+
+        if self.rbln_config.output_attentions:
+            attentions = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers):
+                attentions += (output.pop(0),)
+        else:
+            attentions = None
+
+        if not return_dict:
+            return tuple(
+                item for item in (image_embeds, last_hidden_state, hidden_states, attentions) if item is not None
+            )
+
+        else:
+            return CLIPVisionModelOutput(
+                image_embeds=image_embeds,
+                last_hidden_state=last_hidden_state,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
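Both `_prepare_output` overrides above reassemble a flat list of tensors returned by the compiled runtime: the leading one or two tensors are the primary outputs, followed by `num_hidden_layers + 1` hidden states and `num_hidden_layers` attention maps when those were compiled in. A standalone sketch of that unpacking order (the shapes and the plain Python list standing in for the runtime output are assumptions for illustration):

from typing import List, Optional, Tuple

import torch

def unpack_vision_outputs(
    flat: List[torch.Tensor],
    num_hidden_layers: int,
    output_hidden_states: bool,
    output_attentions: bool,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[tuple], Optional[tuple]]:
    # Mirrors RBLNCLIPVisionModel._prepare_output: consume the flat outputs in
    # the fixed order the model was compiled with.
    flat = list(flat)
    last_hidden_state = flat.pop(0)
    pooler_output = flat.pop(0)
    hidden_states = (
        tuple(flat.pop(0) for _ in range(num_hidden_layers + 1)) if output_hidden_states else None
    )
    attentions = (
        tuple(flat.pop(0) for _ in range(num_hidden_layers)) if output_attentions else None
    )
    return last_hidden_state, pooler_output, hidden_states, attentions

# Fake flat output for a 2-layer model compiled with output_hidden_states=True.
fake = [torch.zeros(1, 5, 32), torch.zeros(1, 32)] + [torch.zeros(1, 5, 32) for _ in range(3)]
lhs, pooled, hs, attn = unpack_vision_outputs(
    fake, num_hidden_layers=2, output_hidden_states=True, output_attentions=False
)
print(lhs.shape, pooled.shape, len(hs), attn)  # (1, 5, 32) (1, 32) 3 None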
@@ -4,10 +4,7 @@ import torch
 from torch import nn
 from transformers import GemmaForCausalLM, GemmaModel
 
-from ..decoderonly.decoderonly_architecture import (
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-)
+from ..decoderonly.decoderonly_architecture import RotaryEmbedding, apply_rotary_pos_emb
 
 
 def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -50,7 +50,7 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
         max_seq_lens: Union[int, List[int]] = None,
         output_hidden_states: Optional[bool] = None,
         vision_tower: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -17,10 +17,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
-from transformers import (
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
@@ -32,12 +29,7 @@ from .colpali_architecture import RBLNColPaliForRetrievalWrapper
 
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class LoopVisionTower:
@@ -22,5 +22,5 @@ from ....ops import (
     paged_flash_causal_attn_decode,
     paged_flash_causal_attn_prefill,
 )
-from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
-from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
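With the export change above, the plain decoder-only backbone is importable next to the causal-LM classes, matching the RBLNDecoderOnlyModelConfig / RBLNDecoderOnlyModelForCausalLMConfig split introduced in the configuration hunks below. An import sketch (the dotted path is inferred from the file list at the top of this diff; the top-level optimum.rbln namespace may also re-export these names):

# Backbone (hidden-state extraction, batch_size fixed to 1 per the new config check)
# and the generative causal-LM variant live side by side in the same subpackage.
from optimum.rbln.transformers.models.decoderonly import (
    RBLNDecoderOnlyModel,
    RBLNDecoderOnlyModelConfig,
    RBLNDecoderOnlyModelForCausalLM,
    RBLNDecoderOnlyModelForCausalLMConfig,
)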
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Literal, Optional, Union
-
-import rebel
+from typing import Any, Dict, List, Literal, Optional, Union, get_args
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
@@ -24,11 +22,12 @@ from ...utils.rbln_quantization import RBLNQuantizationConfig
 logger = get_logger()
 
 CacheImplType = Literal["static", "sliding_window", "hybrid"]
+PhaseType = Literal["prefill", "image_prefill", "decode"]
 
 
-class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
+class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
     """
-    Configuration class for RBLN decoder-only models for Causal Language Modeling.
+    Configuration class for RBLN decoder-only models.
 
     This class extends RBLNModelConfig with parameters specific to decoder-only transformer
     architectures optimized for RBLN devices. It controls aspects like attention implementation,
@@ -48,7 +47,6 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
         prefill_chunk_size: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
-        decoder_batch_sizes: Optional[List[int]] = None,
         cache_impl: Optional[CacheImplType] = None,
         sliding_window: Optional[int] = None,
         sliding_window_layers: Optional[List[int]] = None,
@@ -76,20 +74,12 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
             kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
                 in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
                 section below for details.
-            quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
-                quantization. Specifies format, etc.
             prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
                 processing input sequences. Defaults to 128. Must be a positive integer
                 divisible by 64. Affects prefill performance and memory usage.
             kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
                 PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
                 section below for details.
-            decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
-                This allows the model to handle varying batch sizes efficiently during generation. If not specified,
-                defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
-                1) All values must be less than or equal to the main batch size.
-                2) The list will be sorted in descending order (larger batch sizes first).
-                3) If using multiple decoders, at least one batch size should match the main batch size.
             cache_impl (Optional[CacheImplType]): Specifies the KV cache implementation strategy. Defaults to "static".
                 - "static": Uses a fixed-size global KV cache for all layers, suitable for standard attention patterns.
                 - "sliding_window": Implements a sliding window KV cache, where each layer maintains a local cache of recent tokens.
@@ -166,58 +156,237 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         self.batch_size = batch_size or 1
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+        if self.batch_size > 1:
+            raise NotImplementedError("Batch size > 1 is not supported for RBLNDecoderOnlyModel.")
 
         self.max_seq_len = max_seq_len
         self.use_inputs_embeds = use_inputs_embeds or False
         self.use_position_ids = use_position_ids or False
-        self.use_attention_mask = use_attention_mask
-
-        npu = self.npu or rebel.get_npu_name()
-        if npu == "RBLN-CA02":
-            if self.use_attention_mask is False:
-                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-            self.use_attention_mask = True
-        else:
-            self.use_attention_mask = self.use_attention_mask or False
+        self.use_attention_mask = use_attention_mask or False
 
         if self.use_position_ids and not self.use_attention_mask:
             raise ValueError("Position IDs should be used with attention mask.")
 
-        self.attn_impl = attn_impl
-        self.kvcache_partition_len = kvcache_partition_len
-        self.kvcache_block_size = kvcache_block_size
         self.quantization = quantization or {}
         if self.quantization and isinstance(self.quantization, dict):
             self.quantization = RBLNQuantizationConfig(**self.quantization)
 
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+        self.kvcache_block_size = kvcache_block_size
         self.prefill_chunk_size = prefill_chunk_size or 128
         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
             raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
 
         self.kvcache_num_blocks = kvcache_num_blocks
-        self.decoder_batch_sizes = decoder_batch_sizes
-        if self.decoder_batch_sizes is None:
-            self.decoder_batch_sizes = [self.batch_size]
-
-        if self.use_multiple_decoder:
-            if max(self.decoder_batch_sizes) > self.batch_size:
-                raise ValueError(
-                    f"Decoder batch size ({max(self.decoder_batch_sizes)}) must be less than or equal to the runtime batch size ({self.batch_size})."
-                )
-            if max(self.decoder_batch_sizes) < self.batch_size:
-                logger.warning(
-                    f"Maximum decoder batch size ({max(self.decoder_batch_sizes)}) is less than the model's batch size ({self.batch_size}). "
-                    "Appending the model's batch size to the decoder batch size."
-                )
-                self.decoder_batch_sizes.append(self.batch_size)
-
-            # Larger batch size should be at the beginning of the list.
-            self.decoder_batch_sizes.sort(reverse=True)
-
         self.cache_impl = cache_impl or "static"
         self.sliding_window = sliding_window
         self.sliding_window_layers = sliding_window_layers or []
 
+    @property
+    def use_global_attention(self):
+        return self.cache_impl in ["static", "hybrid"]
+
+    @property
+    def use_local_attention(self):
+        return self.cache_impl in ["sliding_window", "hybrid"]
+
+
+class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN decoder-only models for Causal Language Modeling.
+
+    This class extends RBLNModelConfig with parameters specific to decoder-only transformer
+    architectures optimized for RBLN devices. It controls aspects like attention implementation,
+    KV cache management, and batching for inference.
+    """
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        use_inputs_embeds: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
+        use_position_ids: Optional[bool] = None,
+        attn_impl: Optional[str] = None,
+        kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
+        quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
+        prefill_chunk_size: Optional[int] = None,
+        kvcache_num_blocks: Optional[int] = None,
+        decoder_batch_sizes: Optional[List[int]] = None,
+        cache_impl: Optional[CacheImplType] = None,
+        sliding_window: Optional[int] = None,
+        sliding_window_layers: Optional[List[int]] = None,
+        phases: Optional[List[PhaseType]] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            max_seq_len (Optional[int]): The maximum sequence length supported by the model.
+                If not provided, it attempts to infer from the model's configuration
+                (`max_position_embeddings` or `n_positions`). Must be specified if not available
+                in the model config.
+            use_inputs_embeds (Optional[bool]): Whether to use input embeddings (`inputs_embeds`)
+                directly instead of `input_ids`. Defaults to False. Requires the model to be
+                compiled with this option enabled.
+            use_attention_mask (Optional[bool]): Whether the model requires attention masks during
+                inference. This is typically determined based on the target device and model
+                architecture. Defaults are often set automatically based on the model and RBLN NPU.
+            use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
+            attn_impl (Optional[str]): Specifies the attention implementation to use.
+                See the "Attention Implementation (`attn_impl`)" section below for details.
+            kvcache_partition_len (Optional[int]): Defines the partition length for the KV cache
+                when using "flash_attn". See the "KV Cache Partition Length (`kvcache_partition_len`)"
+                section below for details.
+            kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
+                section below for details.
+            quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
+                quantization. Specifies format, etc.
+            prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
+                processing input sequences. Defaults to 128. Must be a positive integer
+                divisible by 64. Affects prefill performance and memory usage.
+            kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+                PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
+                section below for details.
+            decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
+                This allows the model to handle varying batch sizes efficiently during generation. If not specified,
+                defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
+                1) All values must be less than or equal to the main batch size.
+                2) The list will be sorted in descending order (larger batch sizes first).
+                3) If using multiple decoders, at least one batch size should match the main batch size.
+            cache_impl (Optional[CacheImplType]): Specifies the KV cache implementation strategy. Defaults to "static".
+                - "static": Uses a fixed-size global KV cache for all layers, suitable for standard attention patterns.
+                - "sliding_window": Implements a sliding window KV cache, where each layer maintains a local cache of recent tokens.
+                - "hybrid": Combines both static and sliding window approaches, allowing different layers to use different cache strategies.
+                The choice affects memory usage and attention patterns. When using "sliding_window" or "hybrid",
+                you must specify the `sliding_window` size and optionally `sliding_window_layers` for hybrid mode.
+            sliding_window (Optional[int]): The size of the sliding window. Defaults to None.
+            sliding_window_layers (Optional[List[int]]): The layers to use for the sliding window used in the hybrid model. Defaults to None.
+            phases (Optional[List[PhaseType]]): The phases to compile the model for. Defaults to None.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If `batch_size` is not a positive integer.
+            ValueError: If `prefill_chunk_size` is not a positive integer divisible by 64.
+            ValueError: If `max_seq_len` cannot be determined and is required.
+            ValueError: If attention parameter constraints are violated (e.g., `max_seq_len` vs
+                `kvcache_partition_len` for flash attention).
+
+
+        Attention Implementation:
+            `attn_impl` determines the underlying attention mechanism used by the model.
+
+            - **`"eager"`** (Default if `kvcache_partition_len` is not set): Uses the standard PyTorch
+            attention implementation. Suitable for sequences up to a certain limit (e.g., 32,768 tokens).
+            - **`"flash_attn"`**: Utilizes an optimized Flash Attention implementation, beneficial for
+            longer sequences and potentially faster execution. Requires `max_seq_len` to be at least
+            8,192. If `kvcache_partition_len` is specified, `attn_impl` automatically defaults
+            to `"flash_attn"`. When using `"flash_attn"`, `kvcache_block_size` must equal
+            `kvcache_partition_len`.
+
+            The choice impacts performance and memory usage, especially for long sequences.
+            Constraints related to `max_seq_len` and `kvcache_partition_len` apply when using
+            `"flash_attn"`.
+
+
+        KV Cache Partition Length:
+            `kvcache_partition_len` is relevant **only** when `attn_impl` is `"flash_attn"`.
+
+            - It defines the length (number of tokens) of each partition within the Key-Value (KV) cache.
+            - Must be between 4,096 and 32,768 (inclusive).
+            - When using `"flash_attn"`, `max_seq_len` must be a multiple of `kvcache_partition_len`
+            and at least twice its value (`max_seq_len >= 2 * kvcache_partition_len`).
+            - If `attn_impl` is `"flash_attn"` and `kvcache_partition_len` is `None`, it defaults to
+            16,384.
+
+
+        KV Cache Number of Blocks:
+            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache.
+            Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+            - **Automatic Estimation (Default)**: If `kvcache_num_blocks` is `None`, the system estimates
+            the maximum number of blocks that can fit into the available RBLN device memory. This
+            calculation considers the model size (kernel memory), required buffer memory, the number
+            of layers and heads, `kvcache_block_size`, tensor parallelism, and available RBLN NPU DRAM.
+            This aims to maximize cache capacity for potentially better performance with long sequences
+            or larger batches without manual tuning.
+            - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
+            but requires careful consideration of memory limits. Setting it too high may lead to
+            compilation errors if it exceeds available memory. The system will issue warnings if your
+            setting exceeds the estimated maximum.
+            - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
+            which is beneficial for tasks involving many long sequences or large batch sizes, enabling
+            higher throughput. However, allocating more blocks consumes more memory.
+            - **Minimum Requirement**: The system requires a minimum number of blocks to function,
+            calculated based on `max_seq_len`, `kvcache_block_size`, and `batch_size`. The number of
+            allocated blocks must be sufficient to hold at least one full sequence length per item
+            in the batch concurrently. The system will log warnings or raise errors if constraints
+            are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).
+
+            The optimal value depends on the specific model, task, hardware, and desired trade-off
+            between performance and memory usage. The automatic estimation provides a robust starting point.
+        """
+
+        super().__init__(
+            max_seq_len=max_seq_len,
+            use_inputs_embeds=use_inputs_embeds,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            attn_impl=attn_impl,
+            kvcache_partition_len=kvcache_partition_len,
+            kvcache_block_size=kvcache_block_size,
+            quantization=quantization,
+            prefill_chunk_size=prefill_chunk_size,
+            kvcache_num_blocks=kvcache_num_blocks,
+            cache_impl=cache_impl,
+            sliding_window=sliding_window,
+            sliding_window_layers=sliding_window_layers,
+            **kwargs,
+        )
+
+        # override batch_size for causal lm
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        if phases is not None:
+            self.validate_phases_type(phases)
+        self.phases = phases or ["prefill", "decode"]
+
+        if "decode" in self.phases:
+            self.decoder_batch_sizes = decoder_batch_sizes
+            if self.decoder_batch_sizes is None:
+                self.decoder_batch_sizes = [self.batch_size]
+
+            if self.use_multiple_decoder:
+                if max(self.decoder_batch_sizes) > self.batch_size:
+                    raise ValueError(
+                        f"Decoder batch size ({max(self.decoder_batch_sizes)}) must be less than or equal to the runtime batch size ({self.batch_size})."
+                    )
+                if max(self.decoder_batch_sizes) < self.batch_size:
+                    logger.warning(
+                        f"Maximum decoder batch size ({max(self.decoder_batch_sizes)}) is less than the model's batch size ({self.batch_size}). "
+                        "Appending the model's batch size to the decoder batch size."
+                    )
+                    self.decoder_batch_sizes.append(self.batch_size)
+
+                # Larger batch size should be at the beginning of the list.
+                self.decoder_batch_sizes.sort(reverse=True)
+
+    @staticmethod
+    def validate_phases_type(phases: List[PhaseType]):
+        if not isinstance(phases, list):
+            raise ValueError("`phases` must be a list.")
+        if not all(phase in get_args(PhaseType) for phase in phases):
+            raise ValueError(f"All elements in `phases` must be of type `PhaseType`({get_args(PhaseType)}).")
+
     @property
     def use_multiple_decoder(self):
         return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
+
+    @property
+    def can_generate(self):
+        return "decode" in self.phases
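Taken together, the split moves the generation-specific knobs (`decoder_batch_sizes`, `phases`) onto the causal-LM config while the attention and KV-cache options stay on the new base class. A construction sketch based only on the constraints visible in this hunk (the concrete values are illustrative, and constructing the config offline like this is an assumption about the surrounding RBLNModelConfig machinery):

from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLMConfig

# Default phases are ["prefill", "decode"]; two decoder graphs are compiled so that
# single-sequence decoding does not pay for the full batch-4 graph.
cfg = RBLNDecoderOnlyModelForCausalLMConfig(
    batch_size=4,
    max_seq_len=8192,
    attn_impl="flash_attn",       # docstring: max_seq_len must be >= 2 * kvcache_partition_len
    kvcache_partition_len=4096,   # must lie in [4096, 32768]
    kvcache_block_size=4096,      # must equal kvcache_partition_len for flash_attn
    decoder_batch_sizes=[1, 4],   # re-sorted to [4, 1] (largest first)
)
assert cfg.phases == ["prefill", "decode"] and cfg.can_generate
assert cfg.decoder_batch_sizes == [4, 1] and cfg.use_multiple_decoder

# A prefill-only compile (e.g. a backbone embedded in a multimodal wrapper) skips the
# decoder graphs entirely, so decoder_batch_sizes is never populated.
prefill_only = RBLNDecoderOnlyModelForCausalLMConfig(max_seq_len=4096, phases=["prefill"])
assert not prefill_only.can_generate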