optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
@@ -68,6 +68,8 @@ _import_structure = {
         "RBLNDecoderOnlyModelForCausalLMConfig",
         "RBLNDecoderOnlyModelConfig",
         "RBLNDecoderOnlyModel",
+        "RBLNDetrForObjectDetection",
+        "RBLNDetrForObjectDetectionConfig",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
@@ -78,6 +80,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaModel",
         "RBLNGemmaModelConfig",
+        "RBLNGemma2ForCausalLM",
+        "RBLNGemma2ForCausalLMConfig",
+        "RBLNGemma2Model",
+        "RBLNGemma2ModelConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
@@ -88,6 +94,8 @@ _import_structure = {
         "RBLNGPT2LMHeadModelConfig",
         "RBLNGPT2Model",
         "RBLNGPT2ModelConfig",
+        "RBLNGptOssForCausalLM",
+        "RBLNGptOssForCausalLMConfig",
         "RBLNGroundingDinoDecoder",
         "RBLNGroundingDinoDecoderConfig",
         "RBLNGroundingDinoForObjectDetection",
@@ -110,6 +118,10 @@ _import_structure = {
         "RBLNPegasusForConditionalGenerationConfig",
         "RBLNPegasusModel",
         "RBLNPegasusModelConfig",
+        "RBLNPaliGemmaForConditionalGeneration",
+        "RBLNPaliGemmaForConditionalGenerationConfig",
+        "RBLNPaliGemmaModel",
+        "RBLNPaliGemmaModelConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNLoRAAdapterConfig",
@@ -120,6 +132,8 @@ _import_structure = {
         "RBLNMistralForCausalLMConfig",
         "RBLNMistralModel",
         "RBLNMistralModelConfig",
+        "RBLNMixtralForCausalLM",
+        "RBLNMixtralForCausalLMConfig",
         "RBLNOPTForCausalLM",
         "RBLNOPTForCausalLMConfig",
         "RBLNOPTModel",
@@ -134,14 +148,22 @@ _import_structure = {
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2_5_VLModel",
+        "RBLNQwen2_5_VLModelConfig",
         "RBLNQwen2VisionTransformerPretrainedModel",
         "RBLNQwen2VisionTransformerPretrainedModelConfig",
         "RBLNQwen2VLForConditionalGeneration",
         "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2VLModel",
+        "RBLNQwen2VLModelConfig",
         "RBLNQwen2Model",
         "RBLNQwen2ModelConfig",
         "RBLNQwen2ForCausalLM",
         "RBLNQwen2ForCausalLMConfig",
+        "RBLNQwen2MoeForCausalLM",
+        "RBLNQwen2MoeForCausalLMConfig",
+        "RBLNQwen3MoeForCausalLM",
+        "RBLNQwen3MoeForCausalLMConfig",
         "RBLNQwen3ForCausalLM",
         "RBLNQwen3ForCausalLMConfig",
         "RBLNQwen3Model",
@@ -228,12 +250,18 @@ if TYPE_CHECKING:
         RBLNDecoderOnlyModelForCausalLMConfig,
         RBLNDepthAnythingForDepthEstimation,
         RBLNDepthAnythingForDepthEstimationConfig,
+        RBLNDetrForObjectDetection,
+        RBLNDetrForObjectDetectionConfig,
         RBLNDistilBertForQuestionAnswering,
         RBLNDistilBertForQuestionAnsweringConfig,
         RBLNDPTForDepthEstimation,
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma2ForCausalLM,
+        RBLNGemma2ForCausalLMConfig,
+        RBLNGemma2Model,
+        RBLNGemma2ModelConfig,
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
         RBLNGemma3ForConditionalGeneration,
@@ -246,6 +274,8 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGptOssForCausalLM,
+        RBLNGptOssForCausalLMConfig,
         RBLNGroundingDinoDecoder,
         RBLNGroundingDinoDecoderConfig,
         RBLNGroundingDinoEncoder,
@@ -272,10 +302,16 @@ if TYPE_CHECKING:
         RBLNMistralForCausalLMConfig,
         RBLNMistralModel,
         RBLNMistralModelConfig,
+        RBLNMixtralForCausalLM,
+        RBLNMixtralForCausalLMConfig,
         RBLNOPTForCausalLM,
         RBLNOPTForCausalLMConfig,
         RBLNOPTModel,
         RBLNOPTModelConfig,
+        RBLNPaliGemmaForConditionalGeneration,
+        RBLNPaliGemmaForConditionalGenerationConfig,
+        RBLNPaliGemmaModel,
+        RBLNPaliGemmaModelConfig,
         RBLNPegasusForConditionalGeneration,
         RBLNPegasusForConditionalGenerationConfig,
         RBLNPegasusModel,
@@ -290,18 +326,26 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
+        RBLNQwen2_5_VLModel,
+        RBLNQwen2_5_VLModelConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
         RBLNQwen2Model,
         RBLNQwen2ModelConfig,
+        RBLNQwen2MoeForCausalLM,
+        RBLNQwen2MoeForCausalLMConfig,
         RBLNQwen2VisionTransformerPretrainedModel,
         RBLNQwen2VisionTransformerPretrainedModelConfig,
         RBLNQwen2VLForConditionalGeneration,
         RBLNQwen2VLForConditionalGenerationConfig,
+        RBLNQwen2VLModel,
+        RBLNQwen2VLModelConfig,
         RBLNQwen3ForCausalLM,
         RBLNQwen3ForCausalLMConfig,
         RBLNQwen3Model,
         RBLNQwen3ModelConfig,
+        RBLNQwen3MoeForCausalLM,
+        RBLNQwen3MoeForCausalLMConfig,
         RBLNResNetForImageClassification,
         RBLNResNetForImageClassificationConfig,
         RBLNRobertaForMaskedLM,
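
The bulk of the optimum/rbln/__init__.py change above is new public exports: DETR, Gemma2, GPT-OSS, Mixtral, PaliGemma, Qwen2-MoE, Qwen3-MoE, and the Qwen2-VL / Qwen2.5-VL base models, each paired with an RBLN*Config class. As a rough orientation only, the sketch below assumes the new classes follow the same optimum-style from_pretrained(..., export=True) flow as the existing exports; the checkpoint id and the rbln_* keyword arguments are illustrative, not taken from this diff.

# Hedged sketch: assumes the new exports keep the existing optimum-rbln usage pattern.
from optimum.rbln import RBLNGemma2ForCausalLM

model = RBLNGemma2ForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",       # illustrative checkpoint id
    export=True,                  # compile the PyTorch model for the RBLN NPU
    rbln_tensor_parallel_size=1,  # illustrative rbln_* compile option
)
model.save_pretrained("gemma2-rbln")  # keep the compiled artifacts for later reuse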
optimum/rbln/transformers/modeling_attention_utils.py
@@ -1,19 +1,16 @@
 import math
-from collections import Counter, defaultdict
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from collections import defaultdict
+from typing import Optional, Tuple

 import rebel

 from ..utils.logging import get_logger
-from ..utils.runtime_utils import get_available_dram
+from ..utils.runtime_utils import get_available_dram, is_compiler_supports_buffer_resize
 from .models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


 logger = get_logger()

-if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
-

 DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
 DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
@@ -123,7 +120,7 @@ def align_2MB(x: int) -> int:
     return align(x, 2**21)


-def get_alloc_memory_by_key(compiled_models: Dict[str, "rebel.RBLNCompiledModel"]) -> Dict[str, int]:
+def get_alloc_memory_by_key(compiled_models: dict[str, rebel.RBLNCompiledModel]) -> dict[str, int]:
     alloc_memory_by_key = defaultdict(int)
     # Get the actual memory allocation of each node by key
     for compiled_model in compiled_models.values():
@@ -147,239 +144,144 @@ def format_byte_size(nbytes: int) -> str:

 class RBLNDecoderOnlyFlashAttentionMixin:
     @classmethod
-    def get_maximum_num_blocks_by_model(
-        cls,
-        model: "PreTrainedModel",
-        model_config: "PretrainedConfig",
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ) -> int:
-        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-        available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size
-
-        kernel_memory = cls._get_kernel_memory(model, model_config=model_config, rbln_config=rbln_config)
-        buffer = cls._get_buffer(rbln_config)
-
-        remaining_dram = available_dram - kernel_memory - buffer
-        if remaining_dram <= 0:
+    def set_kvcache_num_blocks_after_compilation(
+        cls, compiled_models: dict[str, rebel.RBLNCompiledModel], rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        rbln_config.kvcache_num_blocks = cls.estimate_num_kvcache_blocks(
+            compiled_models=compiled_models, rbln_config=rbln_config
+        )
+        if rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
             raise ValueError(
-                "Insufficient available DRAM after accounting for kernel memory and buffer. "
-                "Cannot allocate any KV cache blocks."
-                f" (Available DRAM: {format_byte_size(available_dram)}, "
-                f"Kernel Memory: {format_byte_size(kernel_memory)}, "
-                f"Buffer: {format_byte_size(buffer)})"
+                "Memory is not enought for full sequence length. "
+                "Please consider decreasing `max_seq_len` to reduce the number of blocks."
             )
-        estimated_num_blocks = cls._estimate_num_blocks(
-            remaining_dram, model_config=model_config, rbln_config=rbln_config
-        )
-
-        return estimated_num_blocks
-
-    @classmethod
-    def _get_kernel_memory(
-        cls,
-        model: "PreTrainedModel",
-        model_config: "PretrainedConfig",
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ) -> int:
-        if model.get_output_embeddings() is None:
-            lm_head_nbytes = 0
-        else:
-            lm_head_nbytes = cls._get_lm_head_memory(model_config, rbln_config)
-
-        layer_nbytes = cls._get_layer_memory(model, model_config, rbln_config)
-        return lm_head_nbytes + layer_nbytes
-
-    @classmethod
-    def _get_lm_head_memory(
-        cls, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ) -> int:
-        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-        vocab_size = model_config.vocab_size
-        hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
-        lm_head_params = align(vocab_size, 64) * hidden_size
-
-        nbytes_per_param = 2  # Assuming lm_head is always not quantized
-        lm_head_memory_in_bytes = (
-            align_2MB(lm_head_params * nbytes_per_param / tensor_parallel_size) * tensor_parallel_size
+        cls.multiply_kv_cache_num_blocks(
+            compiled_models=compiled_models, rbln_config=rbln_config, multiplier=rbln_config.kvcache_num_blocks
         )

-        return lm_head_memory_in_bytes
-
     @classmethod
-    def _get_layer_memory(
+    def estimate_num_kvcache_blocks(
         cls,
-        model: "PreTrainedModel",
-        model_config: "PretrainedConfig",
+        compiled_models: dict[str, rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        available_dram: Optional[int] = None,
     ) -> int:
-        # This is an *APPROXIMATE* calculation based on the number of parameters
-        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+        if available_dram is None:
+            available_dram = get_available_dram(rbln_config.npu)

-        n_model_params = sum(p.numel() for p in model.parameters())
-        embed_token_params = sum(p.numel() for p in model.get_input_embeddings().parameters())
-
-        # Check : `embed_token` is same as `lm_head`
-        if model.get_output_embeddings() is not None:
-            params = n_model_params - 2 * embed_token_params
-        else:
-            params = n_model_params - embed_token_params
-
-        # Assuming all layers have the same number of parameters
-        # and all linear layers are quantized if quantization is enabled (This is not always true)
-        # TODO(jongho): More accurate calculation
-        nbits_per_param = rbln_config.nbits_per_param
-        layer_nbytes = (
-            (align_2MB(params // num_hidden_layers * nbits_per_param // 8 / tensor_parallel_size))
-            * num_hidden_layers
-            * tensor_parallel_size
-        )
-
-        return layer_nbytes
-
-    @classmethod
-    def _get_buffer(cls, rbln_config) -> int:
-        # TODO(jongho): Accurate buffer estimation
-        buffer_per_runtime_per_core = 2**28  # 256MB per runtime
-        num_runtimes = 1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes)
-        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-
-        buffer_per_core = buffer_per_runtime_per_core * num_runtimes
-        buffer = buffer_per_core * tensor_parallel_size
-        return buffer
-
-    @classmethod
-    def get_maximum_num_blocks_by_compiled_model(
-        cls,
-        compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
-        model_config: "PretrainedConfig",
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ) -> int:
-        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-        available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size
-
-        alloc_memory_by_key = get_alloc_memory_by_key(compiled_models)
-        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
-        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
-        used_memory = sum(alloc_memory_by_key.values())
-
-        remaining_dram = available_dram - used_memory
-
-        if remaining_dram <= 0:
+        if "prefill" not in rbln_config.phases:
             logger.warning(
-                "Insufficient available DRAM after accounting for kernel memory and buffer. "
-                "Model cannot allocate any KV cache blocks."
+                "Not estimating number of KV cache blocks since `prefill` phase is not in the `phases` list."
             )
-
-        estimated_num_blocks = cls._estimate_num_blocks(
-            remaining_dram, model_config=model_config, rbln_config=rbln_config
-        )
-
-        return estimated_num_blocks
-
-    @classmethod
-    def _estimate_num_blocks(
-        cls, available_dram: int, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ) -> int:
-        """
-        Estimate the maximum number of KV cache blocks that can be allocated.
-
-        if all of the layers are full attention, the dram_per_block can be calculated simply as follows:
-        num_blocks = available_dram // dram_per_block
-
-        However, if the model contains a mix of full attention and sliding window attention layers,
-        we need to consider the memory occupied by the sliding window attention layers first,
-        since their memory usage is constant regardless of the number of blocks.
-        num_blocks = (available_dram - swa_kv_nbytes) // dram_per_block
-
-        """
-
-        def get_dram_per_block(seq_len: int, num_key_value_heads: int, tensor_parallel_size: int) -> int:
-            nbytes_per_param = 2  # Assuming kv-cache is always not quantized
-            dram_per_block = (
-                seq_len
-                * align(head_dim, 64)
-                * math.ceil(num_key_value_heads / tensor_parallel_size)
-                * nbytes_per_param
-                * tensor_parallel_size
-                * 2
-            )  # *2 for key and value
-
-            return dram_per_block
-
-        num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
-        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-
-        # Consider layer types if available
-        # If layer types are not found, assume all layers are full attention
-        layer_types = getattr(model_config, "layer_types", None)
-        if layer_types:
-            layer_types_dict = Counter(layer_types)
-            num_full_attention = layer_types_dict.pop("full_attention", 0)
-            num_sliding_window_attention = layer_types_dict.pop("sliding_attention", 0)
-            if len(layer_types_dict) > 0:
-                raise ValueError(f"Unknown layer types found in the config: {layer_types_dict.keys()}")
-
-        else:
-            num_full_attention = num_hidden_layers
-            num_sliding_window_attention = 0
-
-        # Reduce available DRAM by sliding window attention kv-cache
-        # Since memory occupation of swa layer is constant regardless of num_blocks
-        swa_kv_nbytes = 0
-        if num_sliding_window_attention > 0:
-            sliding_window = getattr(model_config, "sliding_window", None)
-            if sliding_window is None:
-                logger.warning(
-                    "`sliding_window` is not found in the config while `sliding_attention` layers are present. "
-                    "Assuming maximum sliding window size for estimation."
-                )
-                sliding_window = rbln_config.kvcache_block_size
-
-            swa_kv_nbytes = num_sliding_window_attention * get_dram_per_block(
-                seq_len=sliding_window,
-                num_key_value_heads=num_key_value_heads,
-                tensor_parallel_size=tensor_parallel_size,
+            return 1
+
+        num_node = rbln_config.tensor_parallel_size or 1
+        alloc_per_node_without_dram = [0] * num_node
+
+        for compiled_model in compiled_models.values():
+            for key, alloc_per_node in compiled_model.get_alloc_per_node_by_key().items():
+                if key == "DramTensor":
+                    continue
+
+                if len(alloc_per_node) != num_node:
+                    alloc_per_node += [0] * (num_node - len(alloc_per_node))
+
+                alloc_per_node_without_dram = [a + b for a, b in zip(alloc_per_node_without_dram, alloc_per_node)]
+
+        remaining_dram_at_node: list[int] = [
+            available_dram - without_dramtensor for without_dramtensor in alloc_per_node_without_dram
+        ]
+
+        # kvcache_tensor_sizes[key][node_id][chiplet_id] = alloc_size
+        kvcache_tensor_sizes: dict[str, list[list[int]]] = compiled_models["prefill"].exp_get_dram_tensor_sizes()
+        kvcache_meta_can_resize: dict[str, bool] = {
+            kvcache_meta.name: kvcache_meta.can_resize for kvcache_meta in rbln_config.kvcache_metas
+        }
+
+        def get_updated_kvcache_tensor_sizes(
+            kvcache_tensor_sizes: dict[str, list[list[int]]], multiplier: int
+        ) -> dict[str, list[list[int]]]:
+            # Get the updated KV cache tensor sizes by multiplying the multiplier
+            # with considering attention type (full or sliding), and memory alignment.
+            ret: dict[str, list[list[int]]] = {}
+            for key, sizes_at_node in kvcache_tensor_sizes.items():
+                m = multiplier if kvcache_meta_can_resize[key] else 1
+                ret[key] = [
+                    [align_2MB(size_at_chiplet * m) for size_at_chiplet in sizes_at_node_at_chiplet]
+                    for sizes_at_node_at_chiplet in sizes_at_node
+                ]
+            return ret
+
+        def check_memory_fits(multiplier: int) -> tuple[bool, list[int]]:
+            # Check if the given multiplier fits in memory
+            # Returns (fits: bool, kvcache_tensor_sizes_at_node: list[int])
+            updated_kvcache_tensor_sizes = get_updated_kvcache_tensor_sizes(kvcache_tensor_sizes, multiplier)
+
+            kvcache_tensor_sizes_at_node: list[int] = [0] * num_node
+            for tensor_sizes_at_node in updated_kvcache_tensor_sizes.values():
+                tensor_sizes_at_node: list[list[int]]
+                for node_id, sizes_at_chiplet in enumerate(tensor_sizes_at_node):
+                    sizes_at_chiplet: list[int]
+                    kvcache_tensor_sizes_at_node[node_id] += sum(sizes_at_chiplet)
+
+            fits = all(
+                remaining_dram_at_node[node_id] >= kvcache_tensor_sizes_at_node[node_id] for node_id in range(num_node)
             )
-
-        available_dram -= swa_kv_nbytes
-
-        dram_per_block = num_full_attention * get_dram_per_block(
-            seq_len=rbln_config.kvcache_block_size,
-            num_key_value_heads=num_key_value_heads,
-            tensor_parallel_size=tensor_parallel_size,
+            return fits, kvcache_tensor_sizes_at_node
+
+        # Fast path: try maximum blocks first (most common case)
+        fits, _ = check_memory_fits(rbln_config.num_full_blocks)
+        if fits:
+            # Best case: maximum blocks fit in memory
+            return rbln_config.num_full_blocks
+
+        # Slow path: binary search for optimal multiplier
+        logger.debug(
+            f"[KVCache] Not enough memory for {rbln_config.num_full_blocks} blocks. "
+            f"Searching for optimal multiplier..."
        )

-        if dram_per_block == 0:
-            raise ValueError("DRAM per block is calculated as zero, cannot estimate maximum number of blocks.")
+        left, right = 1, rbln_config.num_full_blocks - 1
+        multiplier = 1  # Default to minimum if no valid multiplier found
+
+        while left <= right:
+            mid = (left + right) // 2
+            fits, kvcache_tensor_sizes_at_node = check_memory_fits(mid)
+
+            if fits:
+                # Memory is sufficient, try larger multiplier
+                multiplier = mid
+                left = mid + 1
+            else:
+                # Memory is insufficient, try smaller multiplier
+                logger.debug(
+                    f"[KVCache] Not enough memory for {mid} blocks. Remaining DRAM: "
+                    f"{[format_byte_size(remaining_dram) for remaining_dram in remaining_dram_at_node]}, "
+                    f"KV cache tensor sizes: {[format_byte_size(size) for size in kvcache_tensor_sizes_at_node]}"
+                )
+                right = mid - 1

-        max_n_blocks = available_dram // dram_per_block
-        return max_n_blocks
+        return multiplier

     @classmethod
-    def maybe_suggest_kvcache_num_blocks(
+    def multiply_kv_cache_num_blocks(
         cls,
-        compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
-        model_config: "PretrainedConfig",
+        compiled_models: dict[str, rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ) -> None:
-        max_num_blocks = cls.get_maximum_num_blocks_by_compiled_model(
-            compiled_models=compiled_models,
-            model_config=model_config,
-            rbln_config=rbln_config,
-        )
+        multiplier: int,
+    ):
+        if not is_compiler_supports_buffer_resize():
+            raise RuntimeError(
+                "The installed version of rebel-compiler does not support automatic kv cache size determination. "
+                "Please upgrade rebel-compiler to a version that supports this feature, "
+                "or explicitly set 'kvcache_num_blocks' in rbln_config to manually specify the cache size."
+            )

-        # Since our estimation logic is not always accurate,
-        # users can set `kvcache_num_blocks` to `max_num_blocks`.
-        # If the memory is not enough, the model will fail to compile.
-        if rbln_config.kvcache_num_blocks < max_num_blocks:
-            logger.warning(
-                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
-                "Our analysis indicates that additional memory is available for more blocks. "
-                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
-                "Please be advised that our memory estimation algorithm has limitations, "
-                "and increasing this value may not guarantee successful model compilation."
+        for compiled_model in compiled_models.values():
+            compiled_model.exp_multiply_buffer_size(
+                {
+                    kvcache_meta.name: multiplier
+                    for kvcache_meta in rbln_config.kvcache_metas
+                    if kvcache_meta.can_resize
+                }
            )
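
The rewritten RBLNDecoderOnlyFlashAttentionMixin above no longer estimates kernel and buffer memory from parameter counts. Instead it reads per-node allocations from the compiled models, subtracts them from the available DRAM, and then searches for the largest KV-cache block multiplier that still fits, resizing the compiled buffers afterwards via multiply_kv_cache_num_blocks. The search itself is a plain binary search over a monotonic "does it fit" predicate; the standalone sketch below illustrates only that idea, with a hypothetical fits(m) callback standing in for check_memory_fits and no dependency on the rebel runtime.

from typing import Callable

def largest_fitting_multiplier(max_blocks: int, fits: Callable[[int], bool]) -> int:
    # Return the largest m in [1, max_blocks] for which fits(m) is True.
    # Sketch of the search used by estimate_num_kvcache_blocks; `fits` is
    # assumed monotonic (if m fits, every smaller m fits too).
    if fits(max_blocks):  # fast path: everything fits
        return max_blocks
    left, right, best = 1, max_blocks - 1, 1
    while left <= right:
        mid = (left + right) // 2
        if fits(mid):      # mid fits, so try a larger multiplier
            best = mid
            left = mid + 1
        else:              # mid does not fit, shrink the search range
            right = mid - 1
    return best

# Toy usage: pretend each block costs 3 units of DRAM and 100 units are free.
print(largest_fitting_multiplier(64, lambda m: m * 3 <= 100))  # prints 33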
optimum/rbln/transformers/modeling_outputs.py
@@ -18,6 +18,8 @@ from typing import Optional, Tuple
 import torch
 from transformers.modeling_outputs import ModelOutput

+from ..configuration_utils import RBLNModelConfig
+

 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
@@ -36,3 +38,26 @@ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
 class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
     last_hidden_states: torch.FloatTensor = None
     params: Tuple[torch.FloatTensor] = None
+
+
+def _validate_output_hidden_states(output_hidden_states: Optional[bool], rbln_config: RBLNModelConfig):
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else rbln_config.output_hidden_states
+    )
+    if output_hidden_states != rbln_config.output_hidden_states:
+        raise ValueError(
+            f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {rbln_config.output_hidden_states} "
+            f"Please compile again with the correct argument."
+        )
+
+    return output_hidden_states
+
+
+def _validate_output_attentions(output_attentions: Optional[bool], rbln_config: RBLNModelConfig):
+    output_attentions = output_attentions if output_attentions is not None else rbln_config.output_attentions
+    if output_attentions != rbln_config.output_attentions:
+        raise ValueError(
+            f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {rbln_config.output_attentions} "
+            f"Please compile again with the correct argument."
+        )
+    return output_attentions
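
The two helpers added above exist because output_hidden_states and output_attentions are fixed at compile time, so a per-call request that differs from the compiled rbln_config cannot be honored and triggers the "Please compile again" error. The snippet below is a minimal sketch of the intended call pattern; MyRBLNModel, self.rbln_config, and self._runtime are hypothetical, and the helpers are internal (underscore-prefixed) rather than a public API.

from typing import Optional

from optimum.rbln.transformers.modeling_outputs import _validate_output_hidden_states

class MyRBLNModel:  # hypothetical wrapper, for illustration only
    def forward(self, input_ids, output_hidden_states: Optional[bool] = None):
        # Fall back to the compiled setting, and raise if the caller asks for
        # something different from what the model was compiled with.
        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
        return self._runtime(input_ids)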