optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. optimum/rbln/__init__.py +164 -36
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +107 -78
  31. optimum/rbln/transformers/__init__.py +87 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +108 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  76. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  77. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  78. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  79. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  80. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  81. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  82. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  83. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  84. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  85. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  86. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  87. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  88. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  89. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  90. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  91. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  92. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  93. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  94. optimum/rbln/utils/runtime_utils.py +33 -2
  95. optimum/rbln/utils/submodule.py +26 -43
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
  97. optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
  98. optimum/rbln/modeling_config.py +0 -310
  99. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  100. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
  101. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
@@ -12,28 +12,24 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Optional, Tuple, Union

  import torch
- from transformers import (
-     CLIPTextConfig,
-     CLIPTextModel,
-     CLIPVisionConfig,
-     CLIPVisionModel,
- )
- from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
  from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput

- from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
+ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
  from ....utils.logging import get_logger
+ from .configuration_clip import RBLNCLIPTextModelConfig, RBLNCLIPVisionModelConfig


  logger = get_logger(__name__)

  if TYPE_CHECKING:
-     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel, PreTrainedModel
+
+     from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig


  class _TextEncoder(torch.nn.Module):
@@ -48,53 +44,55 @@ class _TextEncoder(torch.nn.Module):

  class RBLNCLIPTextModel(RBLNModel):
      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
          return _TextEncoder(model).eval()

      @classmethod
-     def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_config: str
+     ) -> "RBLNDiffusionMixinConfig":
          return rbln_config

      @classmethod
-     def _get_rbln_config(
+     def _update_rbln_config(
          cls,
          preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "CLIPTextConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-         rbln_batch_size: Optional[int] = None,
-     ) -> RBLNConfig:
-         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-         if rbln_batch_size is None:
-             rbln_batch_size = 1
-
-         model_config.return_dict = False
-
+         model: Optional["PreTrainedModel"] = None,
+         model_config: "CLIPTextConfig" = None,
+         rbln_config: Optional[RBLNCLIPTextModelConfig] = None,
+     ) -> RBLNCLIPTextModelConfig:
          input_info = [
              (
                  "input_ids",
                  [
-                     rbln_batch_size,
+                     rbln_config.batch_size,
                      model_config.max_position_embeddings,
                  ],
                  "int64",
              ),
          ]

-         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[rbln_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
+         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
          return rbln_config

-     def forward(self, input_ids: "torch.Tensor", **kwargs):
-         text_output = super().forward(input_ids)
-         return CLIPTextModelOutput(
-             text_embeds=text_output[0],
-             last_hidden_state=text_output[1],
-             hidden_states=text_output[2:],
-         )
+     def forward(self, input_ids: torch.LongTensor, return_dict: bool = None, **kwargs) -> torch.FloatTensor:
+         # To ignore using attention_mask, we override forward method.
+         output = super().forward(input_ids, return_dict=return_dict)
+         return output
+
+     def _prepare_output(self, output, return_dict):
+         """
+         Prepare model output based on return_dict flag.
+         This method can be overridden by subclasses to provide task-specific output handling.
+         """
+         if not return_dict:
+             return (output,) if not isinstance(output, (tuple, list)) else output
+         else:
+             return CLIPTextModelOutput(
+                 text_embeds=output[0],
+                 last_hidden_state=output[1],
+                 hidden_states=output[2:],
+             )


  class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
@@ -113,30 +111,30 @@ class _VisionEncoder(torch.nn.Module):

  class RBLNCLIPVisionModel(RBLNModel):
      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
          return _VisionEncoder(model).eval()

      @classmethod
-     def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
          return rbln_config

      @classmethod
-     def _get_rbln_config(
+     def _update_rbln_config(
          cls,
          preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "CLIPVisionConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_batch_size = rbln_kwargs.get("batch_size", 1)
-         rbln_image_size = rbln_kwargs.get("image_size", None)
-
-         if rbln_image_size is None:
-             rbln_image_size = getattr(model_config, "image_size", None)
+         model: Optional["PreTrainedModel"] = None,
+         model_config: "CLIPVisionConfig" = None,
+         rbln_config: Optional[RBLNCLIPVisionModelConfig] = None,
+     ) -> RBLNCLIPVisionModelConfig:
+         if rbln_config.image_size is None:
+             rbln_config.image_size = getattr(model_config, "image_size", None)

-         if isinstance(rbln_image_size, int):
-             rbln_image_size = (rbln_image_size, rbln_image_size)
+         if isinstance(rbln_config.image_size, int):
+             rbln_config.image_size = (rbln_config.image_size, rbln_config.image_size)

-         if rbln_image_size is None:
+         if rbln_config.image_size is None:
              raise ValueError("`rbln_image_size` should be specified!")

          rbln_compile_config = RBLNCompileConfig(
@@ -144,45 +142,44 @@ class RBLNCLIPVisionModel(RBLNModel):
                  (
                      "pixel_values",
                      [
-                         rbln_batch_size,
+                         rbln_config.batch_size,
                          3,
-                         rbln_image_size[0],
-                         rbln_image_size[1],
+                         rbln_config.image_height,
+                         rbln_config.image_width,
                      ],
                      "float32",
                  )
              ]
          )

-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[rbln_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update(
-             {
-                 "batch_size": rbln_batch_size,
-                 "image_size": rbln_image_size,
-             }
-         )
-
+         rbln_config.set_compile_cfgs([rbln_compile_config])
          return rbln_config

      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
+         return_dict: bool = None,
          **kwargs,
-     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+     ) -> Union[Tuple, CLIPVisionModelOutput]:
          if len(kwargs) > 0 and any(kwargs.values()):
              logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")

-         output = super().forward(pixel_values)
-         return BaseModelOutputWithPooling(
-             last_hidden_state=output[0],
-             pooler_output=output[1],
-             hidden_states=output[2:],
-         )
+         output = super().forward(pixel_values, return_dict=return_dict)
+         return output
+
+     def _prepare_output(self, output, return_dict):
+         """
+         Prepare model output based on return_dict flag.
+         This method can be overridden by subclasses to provide task-specific output handling.
+         """
+         if not return_dict:
+             return (output,) if not isinstance(output, (tuple, list)) else output
+         else:
+             return CLIPVisionModelOutput(
+                 image_embeds=output[0],
+                 last_hidden_state=output[1],
+                 hidden_states=output[2:],
+             )


  class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
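The reworked CLIP wrappers now route the compiled model's raw outputs through `_prepare_output`, which picks between a plain tuple and a typed transformers output object based on `return_dict`. A minimal standalone sketch of that dispatch, copied from the logic shown above and fed placeholder tensors rather than real RBLN runtime outputs:

import torch
from transformers.models.clip.modeling_clip import CLIPVisionModelOutput

def prepare_output(output, return_dict):
    # Mirrors RBLNCLIPVisionModel._prepare_output: tuple passthrough vs. typed output.
    if not return_dict:
        return (output,) if not isinstance(output, (tuple, list)) else output
    return CLIPVisionModelOutput(
        image_embeds=output[0],
        last_hidden_state=output[1],
        hidden_states=output[2:],
    )

# Placeholder tensors standing in for the compiled model's raw outputs.
raw = (torch.zeros(1, 512), torch.zeros(1, 50, 768), torch.zeros(1, 50, 768))
print(type(prepare_output(raw, return_dict=False)))  # tuple
print(type(prepare_output(raw, return_dict=True)))   # CLIPVisionModelOutput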
@@ -22,4 +22,5 @@ from ....ops import (
      paged_flash_causal_attn_decode,
      paged_flash_causal_attn_prefill,
  )
+ from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
  from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
@@ -0,0 +1,90 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional
+
+ import rebel
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+ from ...utils.rbln_quantization import QuantizationManager
+
+
+ logger = get_logger()
+
+
+ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         max_seq_len: Optional[int] = None,
+         use_inputs_embeds: Optional[bool] = None,
+         use_attention_mask: Optional[bool] = None,
+         attn_impl: Optional[str] = None,
+         kvcache_partition_len: Optional[int] = None,
+         kvcache_block_size: Optional[int] = None,
+         quantization: Optional[Dict[str, Any]] = None,
+         prefill_chunk_size: Optional[int] = None,
+         kvcache_num_blocks: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             max_seq_len (Optional[int]): The maximum sequence length supported by the model.
+             use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
+             use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
+                 for RBLN-CA02 devices.
+             attn_impl (Optional[str]): The attention implementation to use.
+             kvcache_partition_len (Optional[int]): The length of each KV cache partition.
+             kvcache_block_size (Optional[int]): The block size for KV cache.
+             quantization (Optional[Dict[str, Any]]): Configuration for model quantization.
+             prefill_chunk_size (Optional[int]): The chunk size for prefilling the KV cache. Defaults to 128,
+                 and must be a positive integer divisible by 64.
+             kvcache_num_blocks (Optional[int]): The number of blocks in the KV cache.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer or if prefill_chunk_size is not
+                 a positive integer divisible by 64.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.max_seq_len = max_seq_len
+         self.use_inputs_embeds = use_inputs_embeds or False
+
+         self.use_attention_mask = use_attention_mask
+         npu = self.npu or rebel.get_npu_name()
+         if npu == "RBLN-CA02":
+             if self.use_attention_mask is False:
+                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+                 self.use_attention_mask = True
+         else:
+             self.use_attention_mask = self.use_attention_mask or False
+
+         self.attn_impl = attn_impl
+         self.kvcache_partition_len = kvcache_partition_len
+         self.kvcache_block_size = kvcache_block_size
+         self.quantization = quantization or {}
+         if self.quantization:
+             QuantizationManager.validate_quantization_config(self.quantization)
+
+         self.prefill_chunk_size = prefill_chunk_size or 128
+         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
+             raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
+
+         self.kvcache_num_blocks = kvcache_num_blocks
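The constructor above encodes its defaulting and validation rules directly. A minimal sketch of constructing the config with the documented arguments; the import path follows the file list above (a shorter top-level re-export may also exist), and note that instantiation calls rebel.get_npu_name() when no npu is set, so it assumes an environment with the rebel runtime installed:

from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
    RBLNDecoderOnlyModelForCausalLMConfig,
)

# batch_size defaults to 1; prefill_chunk_size defaults to 128.
cfg = RBLNDecoderOnlyModelForCausalLMConfig(
    batch_size=2,
    max_seq_len=8192,
    attn_impl="flash_attn",
    kvcache_partition_len=4096,
)

# prefill_chunk_size must be a positive multiple of 64, so this raises ValueError.
try:
    RBLNDecoderOnlyModelForCausalLMConfig(prefill_chunk_size=100)
except ValueError as err:
    print(err)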
@@ -13,7 +13,7 @@
  # limitations under the License.

  import math
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Tuple, Union

  import torch
  from torch import nn
@@ -32,30 +32,39 @@ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
  MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768


- def validate_attention_method(
-     rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_kvcache_block_size: int, rbln_max_seq_len: int
- ) -> Tuple[str, int]:
-     if rbln_kvcache_partition_len is not None:
-         if rbln_attn_impl == "eager":
-             raise ValueError(
-                 f"`rbln_kvcache_partition_len` is set to {rbln_kvcache_partition_len}, but KV cache partitioning"
-                 " is not supported with 'eager' attention. Please set `rbln_kvcache_partition_len` to None, "
-                 "or switch `rbln_attn_impl` to 'flash_attn' to use KV cache partitioning."
-             )
-         elif rbln_attn_impl is None:
-             rbln_attn_impl = "flash_attn"
+ def set_default_values(
+     attn_impl: Optional[str] = None,
+     kvcache_partition_len: Optional[int] = None,
+     kvcache_block_size: Optional[int] = None,
+     max_seq_len: Optional[int] = None,
+ ) -> Tuple[str, int, int]:
+     if attn_impl is None:
+         attn_impl = "eager"
+
+     if kvcache_partition_len is not None:
+         if attn_impl == "eager":
+             attn_impl = "flash_attn"
              logger.warning(
-                 "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
-                 "Since KV cache partitioning is only supported with flash attention, "
-                 "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
+                 "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
+                 "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
+                 "`attn_impl` has been automatically switched to 'flash_attn'."
              )

-     rbln_attn_impl = "eager" if rbln_attn_impl is None else rbln_attn_impl
-     if rbln_attn_impl not in ["eager", "flash_attn"]:
-         raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
+     if kvcache_partition_len is None and attn_impl == "flash_attn":
+         kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+     if kvcache_block_size is None:
+         if attn_impl == "eager":
+             kvcache_block_size = max_seq_len
+         else:
+             kvcache_block_size = kvcache_partition_len
+
+     return attn_impl, kvcache_partition_len, kvcache_block_size

-     if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
-         rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+ def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
+     if attn_impl not in ["eager", "flash_attn"]:
+         raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")

      ## Checking Constraints...
      # Constraint of eager attention:
@@ -65,47 +74,45 @@ def validate_attention_method(
      # 1. `max_seq_len` should be multiple of `partition_len`.
      # 2. 4k <= `partition_len` <= 32k.
      # 3. `max_seq_len` should be larger then 8k.
-     if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+     if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
          raise ValueError(
-             f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
+             f"`max_seq_len` is set to {max_seq_len}, "
              f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-             f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-             " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
+             f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+             " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
          )

-     if rbln_attn_impl == "flash_attn":
-         if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
+     if attn_impl == "flash_attn":
+         if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
              raise ValueError(
-                 f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
+                 f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
                  f"when using 'flash_attn'. Please adjust either value to meet this requirement."
              )
-         elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+         elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
              raise ValueError(
-                 f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                 f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+                 f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                 f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
                  f"Please provide a valid value within this range."
              )
-         elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+         elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
              raise ValueError(
-                 f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
-                 f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
-                 "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
+                 f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
+                 f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
+                 "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
              )

-     if rbln_kvcache_block_size is not None:
-         if rbln_attn_impl == "flash_attn" and rbln_kvcache_partition_len != rbln_kvcache_block_size:
+     if kvcache_block_size is not None:
+         if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
              raise ValueError(
-                 f" When using 'flash attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
-                 f"must always be set equal to the `rbln_kvcache_partition_len` {rbln_kvcache_partition_len}."
+                 f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                 f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
              )
-         elif rbln_attn_impl == "eager" and rbln_kvcache_block_size != rbln_max_seq_len:
+         elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
              raise ValueError(
-                 f" When using 'eager attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
-                 f"must always be set equal to the `rbln_max_seq_len` {rbln_max_seq_len}."
+                 f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                 f"must always be set equal to the `max_seq_len` {max_seq_len}."
              )

-     return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size
-


  class DecoderOnlyWrapper(nn.Module):
@@ -213,6 +220,53 @@ class DecoderOnlyWrapper(nn.Module):
          self._phase = phase
          self.causal_lm.phase = phase

+     def forward_common(
+         self,
+         input_ids_or_inputs_embeds: torch.Tensor,
+         cache_position: torch.Tensor,
+         attention_mask: torch.Tensor,
+         query_position: torch.Tensor,
+         block_tables: torch.Tensor,
+         rotary_emb: Union[nn.Module, torch.Tensor],
+         *past_key_values: List[torch.Tensor],
+     ):
+         if input_ids_or_inputs_embeds.ndim == 2:
+             input_ids = input_ids_or_inputs_embeds
+             inputs_embeds = None
+         elif input_ids_or_inputs_embeds.ndim == 3:
+             input_ids = None
+             inputs_embeds = input_ids_or_inputs_embeds
+         else:
+             raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+         if len(past_key_values) != 2 * self.num_hidden_layers:
+             raise ValueError(
+                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+             )
+
+         # [key, value] * n_layer -> ( (key, value) ) * n_layer
+         # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+         _past_key_values = []
+         for i in range(self.config.num_hidden_layers):
+             key_states = past_key_values[i * 2]
+             value_states = past_key_values[i * 2 + 1]
+             past_key_value = [key_states, value_states]
+             _past_key_values.append(past_key_value)
+         past_key_values = _past_key_values
+
+         logit = self.causal_lm(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             query_position=query_position,
+             past_key_values=past_key_values,
+             rotary_emb=rotary_emb,
+             block_tables=block_tables,
+         )
+
+         return logit
+
      def forward(self, *args):
          if self.phase == "decode":
              if self.use_attention_mask:
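forward_common receives the KV cache as a flat varargs sequence, two tensors per layer, and regroups it into per-layer [key, value] pairs before calling the causal LM. A tiny standalone sketch of just that regrouping, with placeholder shapes following the "batch, n_heads, 1, max_seq_len, head_dim" comment above:

import torch

num_hidden_layers = 2
# Flat layout handed to forward_common: [k0, v0, k1, v1, ...].
flat_kv = [torch.zeros(1, 8, 1, 4096, 64) for _ in range(2 * num_hidden_layers)]
assert len(flat_kv) == 2 * num_hidden_layers

# Same regrouping as the loop in forward_common.
paired = [[flat_kv[i * 2], flat_kv[i * 2 + 1]] for i in range(num_hidden_layers)]
assert len(paired) == num_hidden_layers and len(paired[0]) == 2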
@@ -255,43 +309,16 @@ class DecoderOnlyWrapper(nn.Module):
          else:
              raise ValueError(f"Unknown phase: {self.phase}")

-         if input_ids_or_inputs_embeds.ndim == 2:
-             input_ids = input_ids_or_inputs_embeds
-             inputs_embeds = None
-         elif input_ids_or_inputs_embeds.ndim == 3:
-             input_ids = None
-             inputs_embeds = input_ids_or_inputs_embeds
-         else:
-             raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
-
-         if len(past_key_values) != 2 * self.num_hidden_layers:
-             raise ValueError(
-                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
-             )
-
-         # [key, value] * n_layer -> ( (key, value) ) * n_layer
-         # cache shape : batch, n_heads, 1, max_seq_len, head_dim
-         _past_key_values = []
-         for i in range(self.config.num_hidden_layers):
-             key_states = past_key_values[i * 2]
-             value_states = past_key_values[i * 2 + 1]
-             past_key_value = [key_states, value_states]
-             _past_key_values.append(past_key_value)
-         past_key_values = _past_key_values
-
-         logit = self.causal_lm(
-             input_ids=input_ids,
-             inputs_embeds=inputs_embeds,
-             attention_mask=attention_mask,
-             cache_position=cache_position,
-             query_position=query_position,
-             past_key_values=past_key_values,
-             rotary_emb=self.rotary_emb,
-             block_tables=block_tables,
+         return self.forward_common(
+             input_ids_or_inputs_embeds,
+             cache_position,
+             attention_mask,
+             query_position,
+             block_tables,
+             self.rotary_emb,
+             *past_key_values,
          )

-         return logit
-

  class DecoderOnlyForCausalLM(nn.Module):
      """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
@@ -315,12 +342,13 @@ class DecoderOnlyForCausalLM(nn.Module):
          _phase: Current processing phase ("prefill" or "decode")
      """

-     def __init__(self, causal_lm: PreTrainedModel, model):
+     def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
          super().__init__()
          self.config = causal_lm.config
          self._original_mod = causal_lm
          self.model = model
          self._phase = "prefill"
+         self.lm_head = self._original_mod.lm_head

      @property
      def phase(self):
@@ -356,7 +384,7 @@ class DecoderOnlyForCausalLM(nn.Module):
          if self.phase == "prefill":
              hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

-         logits = self._original_mod.lm_head(hidden_states)
+         logits = self.lm_head(hidden_states)
          return logits

@@ -448,8 +476,12 @@ class DecoderOnlyModel(nn.Module):

          # get cos,sin vector if needed
          if rotary_emb is not None:
-             cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
-             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+             if isinstance(rotary_emb, torch.Tensor):
+                 cos = rotary_emb[0]
+                 sin = rotary_emb[1]
+             else:
+                 cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
+                 cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
          else:
              batch_size = inputs_embeds.shape[0]
              if cache_position.shape[0] > 1:
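With this change, rotary_emb may be either the usual rotary-embedding module or a precomputed tensor whose first and second entries are the cos and sin tables. A hedged sketch of building such a stacked tensor with a standard RoPE frequency table; the shapes and base are illustrative placeholders, not what any particular caller in this release actually passes:

import torch

max_seq_len, head_dim, base = 4096, 64, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(max_seq_len).float()
freqs = torch.outer(positions, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

# Stacked so that rotary_emb[0] is cos and rotary_emb[1] is sin, matching the
# isinstance(rotary_emb, torch.Tensor) branch above.
rotary_emb = torch.stack((emb.cos(), emb.sin()))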
@@ -826,7 +858,6 @@ def rotate_half(x):

  def apply_rotary_pos_emb(q, k, cos, sin):
      """Applies Rotary Position Embedding to the query and key tensors."""
-
      q_embed = (q * cos) + (rotate_half(q) * sin)
      k_embed = (k * cos) + (rotate_half(k) * sin)
      return q_embed, k_embed