optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -78,6 +78,10 @@ _import_structure = {
  "RBLNExaoneForCausalLMConfig",
  "RBLNGemmaModel",
  "RBLNGemmaModelConfig",
+ "RBLNGemma2ForCausalLM",
+ "RBLNGemma2ForCausalLMConfig",
+ "RBLNGemma2Model",
+ "RBLNGemma2ModelConfig",
  "RBLNGemma3ForCausalLM",
  "RBLNGemma3ForCausalLMConfig",
  "RBLNGemma3ForConditionalGeneration",
@@ -88,6 +92,8 @@ _import_structure = {
  "RBLNGPT2LMHeadModelConfig",
  "RBLNGPT2Model",
  "RBLNGPT2ModelConfig",
+ "RBLNGptOssForCausalLM",
+ "RBLNGptOssForCausalLMConfig",
  "RBLNGroundingDinoDecoder",
  "RBLNGroundingDinoDecoderConfig",
  "RBLNGroundingDinoForObjectDetection",
@@ -110,6 +116,10 @@ _import_structure = {
  "RBLNPegasusForConditionalGenerationConfig",
  "RBLNPegasusModel",
  "RBLNPegasusModelConfig",
+ "RBLNPaliGemmaForConditionalGeneration",
+ "RBLNPaliGemmaForConditionalGenerationConfig",
+ "RBLNPaliGemmaModel",
+ "RBLNPaliGemmaModelConfig",
  "RBLNLlavaNextForConditionalGeneration",
  "RBLNLlavaNextForConditionalGenerationConfig",
  "RBLNLoRAAdapterConfig",
@@ -134,14 +144,22 @@ _import_structure = {
  "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
  "RBLNQwen2_5_VLForConditionalGeneration",
  "RBLNQwen2_5_VLForConditionalGenerationConfig",
+ "RBLNQwen2_5_VLModel",
+ "RBLNQwen2_5_VLModelConfig",
  "RBLNQwen2VisionTransformerPretrainedModel",
  "RBLNQwen2VisionTransformerPretrainedModelConfig",
  "RBLNQwen2VLForConditionalGeneration",
  "RBLNQwen2VLForConditionalGenerationConfig",
+ "RBLNQwen2VLModel",
+ "RBLNQwen2VLModelConfig",
  "RBLNQwen2Model",
  "RBLNQwen2ModelConfig",
  "RBLNQwen2ForCausalLM",
  "RBLNQwen2ForCausalLMConfig",
+ "RBLNQwen2MoeForCausalLM",
+ "RBLNQwen2MoeForCausalLMConfig",
+ "RBLNQwen3MoeForCausalLM",
+ "RBLNQwen3MoeForCausalLMConfig",
  "RBLNQwen3ForCausalLM",
  "RBLNQwen3ForCausalLMConfig",
  "RBLNQwen3Model",
@@ -234,6 +252,10 @@ if TYPE_CHECKING:
  RBLNDPTForDepthEstimationConfig,
  RBLNExaoneForCausalLM,
  RBLNExaoneForCausalLMConfig,
+ RBLNGemma2ForCausalLM,
+ RBLNGemma2ForCausalLMConfig,
+ RBLNGemma2Model,
+ RBLNGemma2ModelConfig,
  RBLNGemma3ForCausalLM,
  RBLNGemma3ForCausalLMConfig,
  RBLNGemma3ForConditionalGeneration,
@@ -246,6 +268,8 @@ if TYPE_CHECKING:
  RBLNGPT2LMHeadModelConfig,
  RBLNGPT2Model,
  RBLNGPT2ModelConfig,
+ RBLNGptOssForCausalLM,
+ RBLNGptOssForCausalLMConfig,
  RBLNGroundingDinoDecoder,
  RBLNGroundingDinoDecoderConfig,
  RBLNGroundingDinoEncoder,
@@ -276,6 +300,10 @@ if TYPE_CHECKING:
  RBLNOPTForCausalLMConfig,
  RBLNOPTModel,
  RBLNOPTModelConfig,
+ RBLNPaliGemmaForConditionalGeneration,
+ RBLNPaliGemmaForConditionalGenerationConfig,
+ RBLNPaliGemmaModel,
+ RBLNPaliGemmaModelConfig,
  RBLNPegasusForConditionalGeneration,
  RBLNPegasusForConditionalGenerationConfig,
  RBLNPegasusModel,
@@ -290,18 +318,26 @@ if TYPE_CHECKING:
  RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
  RBLNQwen2_5_VLForConditionalGeneration,
  RBLNQwen2_5_VLForConditionalGenerationConfig,
+ RBLNQwen2_5_VLModel,
+ RBLNQwen2_5_VLModelConfig,
  RBLNQwen2ForCausalLM,
  RBLNQwen2ForCausalLMConfig,
  RBLNQwen2Model,
  RBLNQwen2ModelConfig,
+ RBLNQwen2MoeForCausalLM,
+ RBLNQwen2MoeForCausalLMConfig,
  RBLNQwen2VisionTransformerPretrainedModel,
  RBLNQwen2VisionTransformerPretrainedModelConfig,
  RBLNQwen2VLForConditionalGeneration,
  RBLNQwen2VLForConditionalGenerationConfig,
+ RBLNQwen2VLModel,
+ RBLNQwen2VLModelConfig,
  RBLNQwen3ForCausalLM,
  RBLNQwen3ForCausalLMConfig,
  RBLNQwen3Model,
  RBLNQwen3ModelConfig,
+ RBLNQwen3MoeForCausalLM,
+ RBLNQwen3MoeForCausalLMConfig,
  RBLNResNetForImageClassification,
  RBLNResNetForImageClassificationConfig,
  RBLNRobertaForMaskedLM,
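
The hunks above register the newly supported architectures (Gemma-2, GPT-OSS, PaliGemma, Qwen2-MoE, Qwen3-MoE, and the plain Qwen2-VL / Qwen2.5-VL model variants) in the package's public API. They are expected to be used like the existing RBLN causal-LM classes; the sketch below is illustrative only, and the checkpoint id and `rbln_*` keyword names are assumptions based on optimum-rbln's usual `from_pretrained(..., export=True)` pattern, not something confirmed by this diff.

from transformers import AutoTokenizer
from optimum.rbln import RBLNGemma2ForCausalLM

# Compile a Hugging Face checkpoint for the RBLN NPU and run a short generation.
# export=True triggers compilation; the rbln_* kwargs are illustrative compile options.
model = RBLNGemma2ForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",        # illustrative checkpoint id
    export=True,
    rbln_max_seq_len=8192,         # assumed compile-time option
    rbln_tensor_parallel_size=1,   # assumed compile-time option
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
inputs = tokenizer("The RBLN runtime is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
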
@@ -1,19 +1,16 @@
  import math
- from collections import Counter, defaultdict
- from typing import TYPE_CHECKING, Dict, Optional, Tuple
+ from collections import defaultdict
+ from typing import Optional, Tuple

  import rebel

  from ..utils.logging import get_logger
- from ..utils.runtime_utils import get_available_dram
+ from ..utils.runtime_utils import get_available_dram, is_compiler_supports_buffer_resize
  from .models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


  logger = get_logger()

- if TYPE_CHECKING:
-     from transformers import PretrainedConfig, PreTrainedModel
-

  DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
  DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
@@ -123,7 +120,7 @@ def align_2MB(x: int) -> int:
      return align(x, 2**21)


- def get_alloc_memory_by_key(compiled_models: Dict[str, "rebel.RBLNCompiledModel"]) -> Dict[str, int]:
+ def get_alloc_memory_by_key(compiled_models: dict[str, rebel.RBLNCompiledModel]) -> dict[str, int]:
      alloc_memory_by_key = defaultdict(int)
      # Get the actual memory allocation of each node by key
      for compiled_model in compiled_models.values():
@@ -147,239 +144,138 @@ def format_byte_size(nbytes: int) -> str:

  class RBLNDecoderOnlyFlashAttentionMixin:
      @classmethod
-     def get_maximum_num_blocks_by_model(
-         cls,
-         model: "PreTrainedModel",
-         model_config: "PretrainedConfig",
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ) -> int:
-         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-         available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size
-
-         kernel_memory = cls._get_kernel_memory(model, model_config=model_config, rbln_config=rbln_config)
-         buffer = cls._get_buffer(rbln_config)
-
-         remaining_dram = available_dram - kernel_memory - buffer
-         if remaining_dram <= 0:
+     def set_kvcache_num_blocks_after_compilation(
+         cls, compiled_models: dict[str, rebel.RBLNCompiledModel], rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
+         rbln_config.kvcache_num_blocks = cls.estimate_num_kvcache_blocks(
+             compiled_models=compiled_models, rbln_config=rbln_config
+         )
+         if rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
              raise ValueError(
-                 "Insufficient available DRAM after accounting for kernel memory and buffer. "
-                 "Cannot allocate any KV cache blocks."
-                 f" (Available DRAM: {format_byte_size(available_dram)}, "
-                 f"Kernel Memory: {format_byte_size(kernel_memory)}, "
-                 f"Buffer: {format_byte_size(buffer)})"
+                 "Memory is not enought for full sequence length. "
+                 "Please consider decreasing `max_seq_len` to reduce the number of blocks."
              )
-         estimated_num_blocks = cls._estimate_num_blocks(
-             remaining_dram, model_config=model_config, rbln_config=rbln_config
-         )
-
-         return estimated_num_blocks
-
-     @classmethod
-     def _get_kernel_memory(
-         cls,
-         model: "PreTrainedModel",
-         model_config: "PretrainedConfig",
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ) -> int:
-         if model.get_output_embeddings() is None:
-             lm_head_nbytes = 0
-         else:
-             lm_head_nbytes = cls._get_lm_head_memory(model_config, rbln_config)
-
-         layer_nbytes = cls._get_layer_memory(model, model_config, rbln_config)
-         return lm_head_nbytes + layer_nbytes
-
-     @classmethod
-     def _get_lm_head_memory(
-         cls, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-     ) -> int:
-         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-         vocab_size = model_config.vocab_size
-         hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
-         lm_head_params = align(vocab_size, 64) * hidden_size
-
-         nbytes_per_param = 2 # Assuming lm_head is always not quantized
-         lm_head_memory_in_bytes = (
-             align_2MB(lm_head_params * nbytes_per_param / tensor_parallel_size) * tensor_parallel_size
+         cls.multiply_kv_cache_num_blocks(
+             compiled_models=compiled_models, rbln_config=rbln_config, multiplier=rbln_config.kvcache_num_blocks
          )

-         return lm_head_memory_in_bytes
-
      @classmethod
-     def _get_layer_memory(
+     def estimate_num_kvcache_blocks(
          cls,
-         model: "PreTrainedModel",
-         model_config: "PretrainedConfig",
+         compiled_models: dict[str, rebel.RBLNCompiledModel],
          rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+         available_dram: Optional[int] = None,
      ) -> int:
-         # This is an *APPROXIMATE* calculation based on the number of parameters
-         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-         num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+         if available_dram is None:
+             available_dram = get_available_dram(rbln_config.npu)

-         n_model_params = sum(p.numel() for p in model.parameters())
-         embed_token_params = sum(p.numel() for p in model.get_input_embeddings().parameters())
-
-         # Check : `embed_token` is same as `lm_head`
-         if model.get_output_embeddings() is not None:
-             params = n_model_params - 2 * embed_token_params
-         else:
-             params = n_model_params - embed_token_params
-
-         # Assuming all layers have the same number of parameters
-         # and all linear layers are quantized if quantization is enabled (This is not always true)
-         # TODO(jongho): More accurate calculation
-         nbits_per_param = rbln_config.nbits_per_param
-         layer_nbytes = (
-             (align_2MB(params // num_hidden_layers * nbits_per_param // 8 / tensor_parallel_size))
-             * num_hidden_layers
-             * tensor_parallel_size
-         )
-
-         return layer_nbytes
-
-     @classmethod
-     def _get_buffer(cls, rbln_config) -> int:
-         # TODO(jongho): Accurate buffer estimation
-         buffer_per_runtime_per_core = 2**28 # 256MB per runtime
-         num_runtimes = 1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes)
-         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-
-         buffer_per_core = buffer_per_runtime_per_core * num_runtimes
-         buffer = buffer_per_core * tensor_parallel_size
-         return buffer
-
-     @classmethod
-     def get_maximum_num_blocks_by_compiled_model(
-         cls,
-         compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
-         model_config: "PretrainedConfig",
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ) -> int:
-         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-         available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size
-
-         alloc_memory_by_key = get_alloc_memory_by_key(compiled_models)
-         alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
-         alloc_memory_by_key.pop("DramTensor", None) # kv-cache
-         used_memory = sum(alloc_memory_by_key.values())
-
-         remaining_dram = available_dram - used_memory
-
-         if remaining_dram <= 0:
+         if "prefill" not in rbln_config.phases:
              logger.warning(
-                 "Insufficient available DRAM after accounting for kernel memory and buffer. "
-                 "Model cannot allocate any KV cache blocks."
+                 "Not estimating number of KV cache blocks since `prefill` phase is not in the `phases` list."
              )
-
-         estimated_num_blocks = cls._estimate_num_blocks(
-             remaining_dram, model_config=model_config, rbln_config=rbln_config
-         )
-
-         return estimated_num_blocks
-
-     @classmethod
-     def _estimate_num_blocks(
-         cls, available_dram: int, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-     ) -> int:
-         """
-         Estimate the maximum number of KV cache blocks that can be allocated.
-
-         if all of the layers are full attention, the dram_per_block can be calculated simply as follows:
-             num_blocks = available_dram // dram_per_block
-
-         However, if the model contains a mix of full attention and sliding window attention layers,
-         we need to consider the memory occupied by the sliding window attention layers first,
-         since their memory usage is constant regardless of the number of blocks.
-             num_blocks = (available_dram - swa_kv_nbytes) // dram_per_block
-
-         """
-
-         def get_dram_per_block(seq_len: int, num_key_value_heads: int, tensor_parallel_size: int) -> int:
-             nbytes_per_param = 2 # Assuming kv-cache is always not quantized
-             dram_per_block = (
-                 seq_len
-                 * align(head_dim, 64)
-                 * math.ceil(num_key_value_heads / tensor_parallel_size)
-                 * nbytes_per_param
-                 * tensor_parallel_size
-                 * 2
-             ) # *2 for key and value
-
-             return dram_per_block
-
-         num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
-         head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
-         num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
-         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
-
-         # Consider layer types if available
-         # If layer types are not found, assume all layers are full attention
-         layer_types = getattr(model_config, "layer_types", None)
-         if layer_types:
-             layer_types_dict = Counter(layer_types)
-             num_full_attention = layer_types_dict.pop("full_attention", 0)
-             num_sliding_window_attention = layer_types_dict.pop("sliding_attention", 0)
-             if len(layer_types_dict) > 0:
-                 raise ValueError(f"Unknown layer types found in the config: {layer_types_dict.keys()}")
-
-         else:
-             num_full_attention = num_hidden_layers
-             num_sliding_window_attention = 0
-
-         # Reduce available DRAM by sliding window attention kv-cache
-         # Since memory occupation of swa layer is constant regardless of num_blocks
-         swa_kv_nbytes = 0
-         if num_sliding_window_attention > 0:
-             sliding_window = getattr(model_config, "sliding_window", None)
-             if sliding_window is None:
-                 logger.warning(
-                     "`sliding_window` is not found in the config while `sliding_attention` layers are present. "
-                     "Assuming maximum sliding window size for estimation."
-                 )
-                 sliding_window = rbln_config.kvcache_block_size
-
-             swa_kv_nbytes = num_sliding_window_attention * get_dram_per_block(
-                 seq_len=sliding_window,
-                 num_key_value_heads=num_key_value_heads,
-                 tensor_parallel_size=tensor_parallel_size,
+             return 1
+
+         num_node = rbln_config.tensor_parallel_size or 1
+         alloc_per_node_without_dram = [0] * num_node
+
+         for compiled_model in compiled_models.values():
+             for key, alloc_per_node in compiled_model.get_alloc_per_node_by_key().items():
+                 if key == "DramTensor":
+                     continue
+
+                 if len(alloc_per_node) != num_node:
+                     alloc_per_node += [0] * (num_node - len(alloc_per_node))
+
+                 alloc_per_node_without_dram = [a + b for a, b in zip(alloc_per_node_without_dram, alloc_per_node)]
+
+         remaining_dram_at_node: list[int] = [
+             available_dram - without_dramtensor for without_dramtensor in alloc_per_node_without_dram
+         ]
+
+         kvcache_tensor_sizes: dict[str, list[int]] = compiled_models["prefill"].exp_get_dram_tensor_sizes()
+         kvcache_meta_can_resize: dict[str, bool] = {
+             kvcache_meta.name: kvcache_meta.can_resize for kvcache_meta in rbln_config.kvcache_metas
+         }
+
+         def get_updated_kvcache_tensor_sizes(
+             kvcache_tensor_sizes: dict[str, list[int]], multiplier: int
+         ) -> dict[str, list[int]]:
+             # Get the updated KV cache tensor sizes by multiplying the multiplier
+             # with considering attention type (full or sliding), and memory alignment.
+             ret = {}
+             for key, sizes in kvcache_tensor_sizes.items():
+                 m = multiplier if kvcache_meta_can_resize[key] else 1
+                 ret[key] = [align_2MB(size * m) for size in sizes]
+             return ret
+
+         def check_memory_fits(multiplier: int) -> tuple[bool, list[int]]:
+             # Check if the given multiplier fits in memory
+             # Returns (fits: bool, kvcache_tensor_sizes_at_node: list[int])
+             updated_kvcache_tensor_sizes = get_updated_kvcache_tensor_sizes(kvcache_tensor_sizes, multiplier)
+
+             kvcache_tensor_sizes_at_node: list[int] = [0] * num_node
+             for tensor_sizes in updated_kvcache_tensor_sizes.values():
+                 for node_id, size in enumerate(tensor_sizes):
+                     kvcache_tensor_sizes_at_node[node_id] += size
+
+             fits = all(
+                 remaining_dram_at_node[node_id] >= kvcache_tensor_sizes_at_node[node_id] for node_id in range(num_node)
              )
-
-         available_dram -= swa_kv_nbytes
-
-         dram_per_block = num_full_attention * get_dram_per_block(
-             seq_len=rbln_config.kvcache_block_size,
-             num_key_value_heads=num_key_value_heads,
-             tensor_parallel_size=tensor_parallel_size,
+             return fits, kvcache_tensor_sizes_at_node
+
+         # Fast path: try maximum blocks first (most common case)
+         fits, _ = check_memory_fits(rbln_config.num_full_blocks)
+         if fits:
+             # Best case: maximum blocks fit in memory
+             return rbln_config.num_full_blocks
+
+         # Slow path: binary search for optimal multiplier
+         logger.debug(
+             f"[KVCache] Not enough memory for {rbln_config.num_full_blocks} blocks. "
+             f"Searching for optimal multiplier..."
          )

-         if dram_per_block == 0:
-             raise ValueError("DRAM per block is calculated as zero, cannot estimate maximum number of blocks.")
+         left, right = 1, rbln_config.num_full_blocks - 1
+         multiplier = 1 # Default to minimum if no valid multiplier found
+
+         while left <= right:
+             mid = (left + right) // 2
+             fits, kvcache_tensor_sizes_at_node = check_memory_fits(mid)
+
+             if fits:
+                 # Memory is sufficient, try larger multiplier
+                 multiplier = mid
+                 left = mid + 1
+             else:
+                 # Memory is insufficient, try smaller multiplier
+                 logger.debug(
+                     f"[KVCache] Not enough memory for {mid} blocks. Remaining DRAM: "
+                     f"{[format_byte_size(remaining_dram) for remaining_dram in remaining_dram_at_node]}, "
+                     f"KV cache tensor sizes: {[format_byte_size(size) for size in kvcache_tensor_sizes_at_node]}"
+                 )
+                 right = mid - 1

-         max_n_blocks = available_dram // dram_per_block
-         return max_n_blocks
+         return multiplier

      @classmethod
-     def maybe_suggest_kvcache_num_blocks(
+     def multiply_kv_cache_num_blocks(
          cls,
-         compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
-         model_config: "PretrainedConfig",
+         compiled_models: dict[str, rebel.RBLNCompiledModel],
          rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ) -> None:
-         max_num_blocks = cls.get_maximum_num_blocks_by_compiled_model(
-             compiled_models=compiled_models,
-             model_config=model_config,
-             rbln_config=rbln_config,
-         )
+         multiplier: int,
+     ):
+         if not is_compiler_supports_buffer_resize():
+             raise RuntimeError(
+                 "The installed version of rebel-compiler does not support automatic kv cache size determination. "
+                 "Please upgrade rebel-compiler to a version that supports this feature, "
+                 "or explicitly set 'kvcache_num_blocks' in rbln_config to manually specify the cache size."
+             )

-         # Since our estimation logic is not always accurate,
-         # users can set `kvcache_num_blocks` to `max_num_blocks`.
-         # If the memory is not enough, the model will fail to compile.
-         if rbln_config.kvcache_num_blocks < max_num_blocks:
-             logger.warning(
-                 f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
-                 "Our analysis indicates that additional memory is available for more blocks. "
-                 f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
-                 "Please be advised that our memory estimation algorithm has limitations, "
-                 "and increasing this value may not guarantee successful model compilation."
+         for compiled_model in compiled_models.values():
+             compiled_model.exp_multiply_buffer_size(
+                 {
+                     kvcache_meta.name: multiplier
+                     for kvcache_meta in rbln_config.kvcache_metas
+                     if kvcache_meta.can_resize
+                 }
              )
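
The rewritten mixin above no longer estimates KV cache capacity from parameter counts. Instead it reads per-node allocations from the compiled models, tries the full block count first, and otherwise binary-searches for the largest block multiplier whose (2 MB-aligned) KV cache tensors still fit in the remaining DRAM on every node. A stripped-down sketch of that search pattern, using a stand-in `fits` predicate in place of the real compiled-model queries:

def largest_feasible(upper: int, fits) -> int:
    # Return the largest value in [1, upper] for which fits(value) is True,
    # assuming fits is monotone (True for small values, False beyond a threshold).
    if fits(upper):
        return upper  # fast path: the maximum already fits
    best, left, right = 1, 1, upper - 1
    while left <= right:
        mid = (left + right) // 2
        if fits(mid):
            best, left = mid, mid + 1   # fits: try a larger multiplier
        else:
            right = mid - 1             # too big: try a smaller multiplier
    return best

# Example: pretend only 300 blocks worth of KV cache fit in DRAM.
print(largest_feasible(1024, lambda n: n <= 300))  # -> 300
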
@@ -18,6 +18,8 @@ from typing import Optional, Tuple
  import torch
  from transformers.modeling_outputs import ModelOutput

+ from ..configuration_utils import RBLNModelConfig
+

  @dataclass
  class RBLNDecoderOnlyOutput(ModelOutput):
@@ -36,3 +38,26 @@ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
  class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
      last_hidden_states: torch.FloatTensor = None
      params: Tuple[torch.FloatTensor] = None
+
+
+ def _validate_output_hidden_states(output_hidden_states: Optional[bool], rbln_config: RBLNModelConfig):
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else rbln_config.output_hidden_states
+     )
+     if output_hidden_states != rbln_config.output_hidden_states:
+         raise ValueError(
+             f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {rbln_config.output_hidden_states} "
+             f"Please compile again with the correct argument."
+         )
+
+     return output_hidden_states
+
+
+ def _validate_output_attentions(output_attentions: Optional[bool], rbln_config: RBLNModelConfig):
+     output_attentions = output_attentions if output_attentions is not None else rbln_config.output_attentions
+     if output_attentions != rbln_config.output_attentions:
+         raise ValueError(
+             f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {rbln_config.output_attentions} "
+             f"Please compile again with the correct argument."
+         )
+     return output_attentions
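
The new helpers resolve a per-call `output_hidden_states` / `output_attentions` flag against the value the model was compiled with and reject mismatches. A minimal usage sketch, assuming optimum-rbln 0.9.5a4 is installed and importing the private helpers directly just for illustration; the config object here is a stand-in with only the two attributes the helpers read:

from types import SimpleNamespace

from optimum.rbln.transformers.modeling_outputs import (
    _validate_output_attentions,
    _validate_output_hidden_states,
)

# Stand-in for an RBLNModelConfig compiled with hidden states on and attentions off.
rbln_config = SimpleNamespace(output_hidden_states=True, output_attentions=False)

# Passing None falls back to the compiled setting; an explicit mismatch raises ValueError.
assert _validate_output_hidden_states(None, rbln_config) is True
assert _validate_output_attentions(False, rbln_config) is False
try:
    _validate_output_attentions(True, rbln_config)  # compiled with False
except ValueError as err:
    print(err)
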