optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -12,28 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
-from transformers import (
-    CLIPTextConfig,
-    CLIPTextModel,
-    CLIPVisionConfig,
-    CLIPVisionModel,
-)
-from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
 
-from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
+from .configuration_clip import RBLNCLIPTextModelConfig, RBLNCLIPVisionModelConfig
 
 
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel, PreTrainedModel
+
+    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 class _TextEncoder(torch.nn.Module):
@@ -48,53 +44,55 @@ class _TextEncoder(torch.nn.Module):
 
 class RBLNCLIPTextModel(RBLNModel):
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_config: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "CLIPTextConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-        rbln_batch_size: Optional[int] = None,
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        model_config.return_dict = False
-
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "CLIPTextConfig" = None,
+        rbln_config: Optional[RBLNCLIPTextModelConfig] = None,
+    ) -> RBLNCLIPTextModelConfig:
         input_info = [
             (
                 "input_ids",
                 [
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.max_position_embeddings,
                 ],
                 "int64",
             ),
         ]
 
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
 
-    def forward(self, input_ids: "torch.Tensor", **kwargs):
-        text_output = super().forward(input_ids)
-        return CLIPTextModelOutput(
-            text_embeds=text_output[0],
-            last_hidden_state=text_output[1],
-            hidden_states=text_output[2:],
-        )
+    def forward(self, input_ids: torch.LongTensor, return_dict: bool = None, **kwargs) -> torch.FloatTensor:
+        # To ignore using attention_mask, we override forward method.
+        output = super().forward(input_ids, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return CLIPTextModelOutput(
+                text_embeds=output[0],
+                last_hidden_state=output[1],
+                hidden_states=output[2:],
+            )
 
 
 class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
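
In practice, the text-encoder refactor replaces the old rbln_kwargs dict with a typed RBLNCLIPTextModelConfig object and routes tuple-vs-dataclass output selection through _prepare_output. A minimal usage sketch follows; it assumes the classes are re-exported from optimum.rbln, that the RBLN compiler and an NPU are available, and that the checkpoint id and the export/rbln_config keyword arguments follow the usual optimum-rbln calling convention (none of this is taken verbatim from the diff):

import torch
from optimum.rbln import RBLNCLIPTextModel, RBLNCLIPTextModelConfig  # assumed re-exports

# Typed config replaces the loose rbln_kwargs dict used by the old _get_rbln_config.
text_cfg = RBLNCLIPTextModelConfig(batch_size=1)

model = RBLNCLIPTextModel.from_pretrained(
    "openai/clip-vit-base-patch32",  # hypothetical checkpoint, for illustration only
    export=True,                     # compile the PyTorch checkpoint for the RBLN NPU
    rbln_config=text_cfg,
)

# The graph is traced with a fixed [batch_size, max_position_embeddings] input_ids
# shape (77 for this CLIP variant).
input_ids = torch.zeros(1, 77, dtype=torch.int64)
out_tuple = model(input_ids, return_dict=False)  # plain tuple via _prepare_output
out_model = model(input_ids, return_dict=True)   # transformers CLIPTextModelOutput
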
@@ -113,30 +111,30 @@ class _VisionEncoder(torch.nn.Module):
 
 class RBLNCLIPVisionModel(RBLNModel):
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
         return _VisionEncoder(model).eval()
 
     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "CLIPVisionConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", 1)
-        rbln_image_size = rbln_kwargs.get("image_size", None)
-
-        if rbln_image_size is None:
-            rbln_image_size = getattr(model_config, "image_size", None)
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "CLIPVisionConfig" = None,
+        rbln_config: Optional[RBLNCLIPVisionModelConfig] = None,
+    ) -> RBLNCLIPVisionModelConfig:
+        if rbln_config.image_size is None:
+            rbln_config.image_size = getattr(model_config, "image_size", None)
 
-        if isinstance(rbln_image_size, int):
-            rbln_image_size = (rbln_image_size, rbln_image_size)
+        if isinstance(rbln_config.image_size, int):
+            rbln_config.image_size = (rbln_config.image_size, rbln_config.image_size)
 
-        if rbln_image_size is None:
+        if rbln_config.image_size is None:
             raise ValueError("`rbln_image_size` should be specified!")
 
         rbln_compile_config = RBLNCompileConfig(
@@ -144,45 +142,44 @@ class RBLNCLIPVisionModel(RBLNModel):
                 (
                     "pixel_values",
                     [
-                        rbln_batch_size,
+                        rbln_config.batch_size,
                         3,
-                        rbln_image_size[0],
-                        rbln_image_size[1],
+                        rbln_config.image_height,
+                        rbln_config.image_width,
                     ],
                     "float32",
                 )
             ]
         )
 
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "image_size": rbln_image_size,
-            }
-        )
-
+        rbln_config.set_compile_cfgs([rbln_compile_config])
         return rbln_config
 
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
+        return_dict: bool = None,
         **kwargs,
-    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+    ) -> Union[Tuple, CLIPVisionModelOutput]:
         if len(kwargs) > 0 and any(kwargs.values()):
             logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
 
-        output = super().forward(pixel_values)
-        return BaseModelOutputWithPooling(
-            last_hidden_state=output[0],
-            pooler_output=output[1],
-            hidden_states=output[2:],
-        )
+        output = super().forward(pixel_values, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return CLIPVisionModelOutput(
+                image_embeds=output[0],
+                last_hidden_state=output[1],
+                hidden_states=output[2:],
+            )
 
 
 class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
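
The vision encoder follows the same pattern; the model-specific detail is that image_size may be given as a single int and is normalized to a (height, width) pair before the pixel_values shape is fixed at compile time. A sketch under the same assumptions as the text-encoder example above (re-exported names, RBLN toolchain installed, illustrative checkpoint; the image_size constructor argument is inferred from the rbln_config.image_size accesses in this hunk):

from optimum.rbln import RBLNCLIPVisionModel, RBLNCLIPVisionModelConfig  # assumed re-exports

# An int image_size is treated as a square input; when omitted, _update_rbln_config
# falls back to model_config.image_size and raises if neither is set.
vision_cfg = RBLNCLIPVisionModelConfig(batch_size=1, image_size=224)

model = RBLNCLIPVisionModel.from_pretrained(
    "openai/clip-vit-base-patch32",  # hypothetical checkpoint, for illustration only
    export=True,
    rbln_config=vision_cfg,
)
# The compiled graph then expects pixel_values of shape
# [batch_size, 3, image_height, image_width] == [1, 3, 224, 224].
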
optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -12,4 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ....ops import (
+    paged_attn_decode,
+    paged_attn_prefill,
+    paged_causal_attn_decode,
+    paged_causal_attn_prefill,
+    paged_flash_attn_decode,
+    paged_flash_attn_prefill,
+    paged_flash_causal_attn_decode,
+    paged_flash_causal_attn_prefill,
+)
+from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py (new file)
@@ -0,0 +1,90 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import rebel
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+from ...utils.rbln_quantization import QuantizationManager
+
+
+logger = get_logger()
+
+
+class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        use_inputs_embeds: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
+        attn_impl: Optional[str] = None,
+        kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
+        quantization: Optional[Dict[str, Any]] = None,
+        prefill_chunk_size: Optional[int] = None,
+        kvcache_num_blocks: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            max_seq_len (Optional[int]): The maximum sequence length supported by the model.
+            use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
+            use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
+                for RBLN-CA02 devices.
+            attn_impl (Optional[str]): The attention implementation to use.
+            kvcache_partition_len (Optional[int]): The length of each KV cache partition.
+            kvcache_block_size (Optional[int]): The block size for KV cache.
+            quantization (Optional[Dict[str, Any]]): Configuration for model quantization.
+            prefill_chunk_size (Optional[int]): The chunk size for prefilling the KV cache. Defaults to 128,
+                and must be a positive integer divisible by 64.
+            kvcache_num_blocks (Optional[int]): The number of blocks in the KV cache.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer or if prefill_chunk_size is not
+                a positive integer divisible by 64.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.max_seq_len = max_seq_len
+        self.use_inputs_embeds = use_inputs_embeds or False
+
+        self.use_attention_mask = use_attention_mask
+        npu = self.npu or rebel.get_npu_name()
+        if npu == "RBLN-CA02":
+            if self.use_attention_mask is False:
+                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+            self.use_attention_mask = True
+        else:
+            self.use_attention_mask = self.use_attention_mask or False
+
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+        self.kvcache_block_size = kvcache_block_size
+        self.quantization = quantization or {}
+        if self.quantization:
+            QuantizationManager.validate_quantization_config(self.quantization)
+
+        self.prefill_chunk_size = prefill_chunk_size or 128
+        if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
+            raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
+
+        self.kvcache_num_blocks = kvcache_num_blocks
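
The new configuration module gives the decoder-only compile options a documented, typed home in place of the loose rbln_* kwargs used in 0.7.3. A minimal construction sketch, assuming optimum-rbln 0.7.4 and the rebel SDK are installed (the constructor may call rebel.get_npu_name() to detect the device); the import path mirrors the file added in this diff:

from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
    RBLNDecoderOnlyModelForCausalLMConfig,
)

cfg = RBLNDecoderOnlyModelForCausalLMConfig(
    batch_size=2,          # defaults to 1 when omitted
    max_seq_len=8192,
    use_inputs_embeds=False,
)

# prefill_chunk_size defaults to 128 and must be a positive multiple of 64;
# on RBLN-CA02 devices use_attention_mask is forced to True (with a warning
# if it was explicitly set to False).
assert cfg.prefill_chunk_size == 128
# RBLNDecoderOnlyModelForCausalLMConfig(prefill_chunk_size=100)  # would raise ValueError
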