optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -118,30 +118,3 @@ class RBLNModelForImageClassificationConfig(RBLNImageModelConfig):
 
 class RBLNModelForDepthEstimationConfig(RBLNImageModelConfig):
     pass
-
-
-class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
-    def __init__(
-        self,
-        batch_size: Optional[int] = None,
-        max_length: Optional[int] = None,
-        num_mel_bins: Optional[int] = None,
-        **kwargs: Any,
-    ):
-        """
-        Args:
-            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
-            max_length (Optional[int]): Maximum length of the audio input in time dimension.
-            num_mel_bins (Optional[int]): Number of Mel frequency bins for audio processing.
-            kwargs: Additional arguments passed to the parent RBLNModelConfig.
-
-        Raises:
-            ValueError: If batch_size is not a positive integer.
-        """
-        super().__init__(**kwargs)
-        self.batch_size = batch_size or 1
-        if not isinstance(self.batch_size, int) or self.batch_size < 0:
-            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
-
-        self.max_length = max_length
-        self.num_mel_bins = num_mel_bins
@@ -1,19 +1,16 @@
 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from collections import defaultdict
+from typing import Optional, Tuple
 
-from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
-    RBLNDecoderOnlyModelForCausalLMConfig,
-)
+import rebel
 
 from ..utils.logging import get_logger
+from ..utils.runtime_utils import get_available_dram, is_compiler_supports_buffer_resize
+from .models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
 logger = get_logger()
 
-if TYPE_CHECKING:
-    from rebel import RBLNCompiledModel
-    from transformers import PretrainedConfig
-
 
 DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
 DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
@@ -115,138 +112,170 @@ def validate_sliding_window(rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         raise ValueError("`use_attention_mask` must be set to False when `cache_impl` is set to 'sliding_window'.")
 
 
+def align(x: int, nbytes: int) -> int:
+    return int(math.ceil(x / nbytes) * nbytes)
+
+
+def align_2MB(x: int) -> int:
+    return align(x, 2**21)
+
+
+def get_alloc_memory_by_key(compiled_models: dict[str, rebel.RBLNCompiledModel]) -> dict[str, int]:
+    alloc_memory_by_key = defaultdict(int)
+    # Get the actual memory allocation of each node by key
+    for compiled_model in compiled_models.values():
+        alloc_per_node_by_key = compiled_model.get_alloc_per_node_by_key()
+        for key, memory_per_node in alloc_per_node_by_key.items():
+            alloc_memory_by_key[key] += sum(memory_per_node)
+
+    return alloc_memory_by_key
+
+
+def format_byte_size(nbytes: int) -> str:
+    if nbytes < 1024:
+        return f"{nbytes} B"
+    elif nbytes < 1024**2:
+        return f"{nbytes / 1024:.2f} KB"
+    elif nbytes < 1024**3:
+        return f"{nbytes / 1024**2:.2f} MB"
+    else:
+        return f"{nbytes / 1024**3:.2f} GB"
+
+
 class RBLNDecoderOnlyFlashAttentionMixin:
     @classmethod
-    def get_maximum_num_blocks(
+    def set_kvcache_num_blocks_after_compilation(
+        cls, compiled_models: dict[str, rebel.RBLNCompiledModel], rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        rbln_config.kvcache_num_blocks = cls.estimate_num_kvcache_blocks(
+            compiled_models=compiled_models, rbln_config=rbln_config
+        )
+        if rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
+            raise ValueError(
+                "Memory is not enought for full sequence length. "
+                "Please consider decreasing `max_seq_len` to reduce the number of blocks."
+            )
+        cls.multiply_kv_cache_num_blocks(
+            compiled_models=compiled_models, rbln_config=rbln_config, multiplier=rbln_config.kvcache_num_blocks
+        )
+
+    @classmethod
+    def estimate_num_kvcache_blocks(
         cls,
-        config: "PretrainedConfig",
-        tensor_parallel_size: int,
-        kvcache_block_size: int,
-        nbits_per_param: Optional[int] = None,
-        n_model_params: Optional[int] = None,
-        kernel_size: Optional[int] = None,
-        buffer: Optional[int] = None,
-        num_runtimes: int = 2,
+        compiled_models: dict[str, rebel.RBLNCompiledModel],
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        available_dram: Optional[int] = None,
     ) -> int:
-        # We are finding max_n_blocks(x) that satisfies the following equation:
-
-        # available_dram - kernel_size - buffer
-        #   - num_layers * 2 * tensor_parallel_size
-        #   * align_2MB(
-        #       x
-        #       * block_size
-        #       * align_64(head_dim)
-        #       * math.ceil(num_key_value_heads / tensor_parallel_size)
-        #       * 2
-        #   ) > 0
-
-        # This inequality can be rewritten as follows:
-
-        #   a - c * align_2MB(b * x) > 0
-        # where
-        #   a = available_dram - kernel_size - buffer
-        #   b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-        #   c = num_layers * 2 * tensor_parallel_size
-
-        # We can rewrite the inequality as follows:
-        #   k > align_2MB(b*x)
-        # where
-        #   k = a / c
-
-        # After that, we can derive the following equation:
-        #   x = floor(2**21 / b * floor((k - 1) / 2**21))
-
-        def align(x: int, nbytes: int) -> int:
-            return int(math.ceil(x / nbytes) * nbytes)
-
-        def align_2MB(x: int) -> int:
-            return align(x, 2**21)
-
-        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
-        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
-        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
-        vocab_size = config.vocab_size
-        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
-        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
-        # TODO(jongho): Update if target npu is REBEL.
-        ATOM_DRAM_NBYTES = 16 * 2**30
-        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
-        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
-        if kernel_size is None:
-            if n_model_params is None:
-                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
-            # Get estimated kernel size (approximated)
-            lm_heads_params = align(vocab_size, 64) * hidden_size
-            lm_heads_nbytes = (
-                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+        if available_dram is None:
+            available_dram = get_available_dram(rbln_config.npu)
+
+        if "prefill" not in rbln_config.phases:
+            logger.warning(
+                "Not estimating number of KV cache blocks since `prefill` phase is not in the `phases` list."
             )
-            params = n_model_params - lm_heads_params
-            layer_nbytes = (
-                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-                * num_layers
-                * tensor_parallel_size
+            return 1
+
+        num_node = rbln_config.tensor_parallel_size or 1
+        alloc_per_node_without_dram = [0] * num_node
+
+        for compiled_model in compiled_models.values():
+            for key, alloc_per_node in compiled_model.get_alloc_per_node_by_key().items():
+                if key == "DramTensor":
+                    continue
+
+                if len(alloc_per_node) != num_node:
+                    alloc_per_node += [0] * (num_node - len(alloc_per_node))
+
+                alloc_per_node_without_dram = [a + b for a, b in zip(alloc_per_node_without_dram, alloc_per_node)]
+
+        remaining_dram_at_node: list[int] = [
+            available_dram - without_dramtensor for without_dramtensor in alloc_per_node_without_dram
+        ]
+
+        kvcache_tensor_sizes: dict[str, list[int]] = compiled_models["prefill"].exp_get_dram_tensor_sizes()
+        kvcache_meta_can_resize: dict[str, bool] = {
+            kvcache_meta.name: kvcache_meta.can_resize for kvcache_meta in rbln_config.kvcache_metas
+        }
+
+        def get_updated_kvcache_tensor_sizes(
+            kvcache_tensor_sizes: dict[str, list[int]], multiplier: int
+        ) -> dict[str, list[int]]:
+            # Get the updated KV cache tensor sizes by multiplying the multiplier
+            # with considering attention type (full or sliding), and memory alignment.
+            ret = {}
+            for key, sizes in kvcache_tensor_sizes.items():
+                m = multiplier if kvcache_meta_can_resize[key] else 1
+                ret[key] = [align_2MB(size * m) for size in sizes]
+            return ret
+
+        def check_memory_fits(multiplier: int) -> tuple[bool, list[int]]:
+            # Check if the given multiplier fits in memory
+            # Returns (fits: bool, kvcache_tensor_sizes_at_node: list[int])
+            updated_kvcache_tensor_sizes = get_updated_kvcache_tensor_sizes(kvcache_tensor_sizes, multiplier)
+
+            kvcache_tensor_sizes_at_node: list[int] = [0] * num_node
+            for tensor_sizes in updated_kvcache_tensor_sizes.values():
+                for node_id, size in enumerate(tensor_sizes):
+                    kvcache_tensor_sizes_at_node[node_id] += size
+
+            fits = all(
+                remaining_dram_at_node[node_id] >= kvcache_tensor_sizes_at_node[node_id] for node_id in range(num_node)
             )
-            kernel_size = layer_nbytes + lm_heads_nbytes
-        elif n_model_params is not None:
-            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
+            return fits, kvcache_tensor_sizes_at_node
+
+        # Fast path: try maximum blocks first (most common case)
+        fits, _ = check_memory_fits(rbln_config.num_full_blocks)
+        if fits:
+            # Best case: maximum blocks fit in memory
+            return rbln_config.num_full_blocks
+
+        # Slow path: binary search for optimal multiplier
+        logger.debug(
+            f"[KVCache] Not enough memory for {rbln_config.num_full_blocks} blocks. "
+            f"Searching for optimal multiplier..."
+        )
 
-        available_dram -= kernel_size
+        left, right = 1, rbln_config.num_full_blocks - 1
+        multiplier = 1  # Default to minimum if no valid multiplier found
 
-        if buffer is None:
-            # TODO: Accurate buffer estimation
-            buffer_per_runtime_per_core = 2**28  # 256MB per runtime
-            buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
-            buffer = buffer_per_core * tensor_parallel_size
-        available_dram -= buffer
+        while left <= right:
+            mid = (left + right) // 2
+            fits, kvcache_tensor_sizes_at_node = check_memory_fits(mid)
 
-        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-        c = num_layers * 2 * tensor_parallel_size
-        k = available_dram / c
-        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+            if fits:
+                # Memory is sufficient, try larger multiplier
+                multiplier = mid
+                left = mid + 1
+            else:
+                # Memory is insufficient, try smaller multiplier
+                logger.debug(
+                    f"[KVCache] Not enough memory for {mid} blocks. Remaining DRAM: "
+                    f"{[format_byte_size(remaining_dram) for remaining_dram in remaining_dram_at_node]}, "
+                    f"KV cache tensor sizes: {[format_byte_size(size) for size in kvcache_tensor_sizes_at_node]}"
+                )
+                right = mid - 1
 
-        return max_n_blocks
+        return multiplier
 
     @classmethod
-    def maybe_suggest_kvcache_num_blocks(
+    def multiply_kv_cache_num_blocks(
         cls,
-        compiled_models: Dict[str, "RBLNCompiledModel"],
-        model_config: "PretrainedConfig",
+        compiled_models: dict[str, rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ) -> None:
-        # Get the actual memory allocation of each node by key
-        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
-        alloc_memory_by_key: Dict[str, int] = {
-            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
-        }
-        for batch_size in rbln_config.decoder_batch_sizes:
-            for key, memory_per_node in (
-                compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
-            ):
-                alloc_memory_by_key[key] += sum(memory_per_node)
-        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
-        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
-        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
-
-        # Get the maximum number of blocks that can be allocated
-        buffer = sum(alloc_memory_by_key.values())
-        max_num_blocks = cls.get_maximum_num_blocks(
-            config=model_config,
-            tensor_parallel_size=rbln_config.tensor_parallel_size,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kernel_size=kernel_size,
-            buffer=buffer,
-        )
+        multiplier: int,
+    ):
+        if not is_compiler_supports_buffer_resize():
+            raise RuntimeError(
+                "The installed version of rebel-compiler does not support automatic kv cache size determination. "
+                "Please upgrade rebel-compiler to a version that supports this feature, "
+                "or explicitly set 'kvcache_num_blocks' in rbln_config to manually specify the cache size."
+            )
 
-        # Since our estimation logic is not always accurate,
-        # users can set `kvcache_num_blocks` to `max_num_blocks`.
-        # If the memory is not enough, the model will fail to compile.
-        if rbln_config.kvcache_num_blocks < max_num_blocks:
-            logger.warning(
-                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
-                "Our analysis indicates that additional memory is available for more blocks. "
-                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
-                "Please be advised that our memory estimation algorithm has limitations, "
-                "and increasing this value may not guarantee successful model compilation."
+        for compiled_model in compiled_models.values():
+            compiled_model.exp_multiply_buffer_size(
+                {
+                    kvcache_meta.name: multiplier
+                    for kvcache_meta in rbln_config.kvcache_metas
+                    if kvcache_meta.can_resize
+                }
             )
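This hunk replaces the deleted closed-form DRAM estimate (the `max_n_blocks` derivation in comments) with a measure-then-search approach: per-node allocations are read back from the compiled artifacts, subtracted from the available DRAM, and the largest KV cache block multiplier whose 2 MB-aligned tensors still fit on every node is found by binary search. Below is a minimal, self-contained sketch of that search using hypothetical sizes; it collapses the per-key tensor maps to a single tensor per node, and the real code additionally pins non-resizable sliding-window caches at a multiplier of 1.

```python
import math


def align_2mb(nbytes: int) -> int:
    # Round up to the 2 MB (2**21-byte) DRAM allocation granularity used above.
    return math.ceil(nbytes / 2**21) * 2**21


def max_fitting_multiplier(base_sizes: list[int], remaining_dram: list[int], upper: int) -> int:
    # Largest m such that every node's 2 MB-aligned KV cache tensor fits its remaining DRAM.
    def fits(m: int) -> bool:
        return all(align_2mb(size * m) <= dram for size, dram in zip(base_sizes, remaining_dram))

    if fits(upper):  # fast path: the full block count fits
        return upper

    left, right, best = 1, upper - 1, 1
    while left <= right:
        mid = (left + right) // 2
        if fits(mid):
            best, left = mid, mid + 1  # fits: try a larger multiplier
        else:
            right = mid - 1  # does not fit: try a smaller one
    return best


# Hypothetical numbers: two nodes, a 3 MB KV cache tensor per block on each,
# and roughly 1 GiB of DRAM left per node after weights and buffers.
print(max_fitting_multiplier([3 * 2**20, 3 * 2**20], [2**30, 2**30], upper=1024))  # 341
```

The fast path mirrors `check_memory_fits(rbln_config.num_full_blocks)` above: in the common case the full block count fits and the binary search never runs.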
@@ -26,7 +26,6 @@ from typing import TYPE_CHECKING, Optional, Union
 from torch import nn
 from transformers import (
     AutoModel,
-    AutoModelForAudioClassification,
     AutoModelForDepthEstimation,
     AutoModelForImageClassification,
     AutoModelForMaskedLM,
@@ -42,7 +41,6 @@ from ..modeling import RBLNModel
 from ..utils.logging import get_logger
 from .configuration_generic import (
     RBLNImageModelConfig,
-    RBLNModelForAudioClassificationConfig,
     RBLNTransformerEncoderConfig,
 )
 
@@ -59,7 +57,7 @@ class RBLNTransformerEncoder(RBLNModel):
     rbln_dtype = "int64"
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
         class TransformerEncoderWrapper(nn.Module):
             # Parameters to disable for RBLN compilation
             DISABLED_PARAMS = {"return_dict", "use_cache"}
@@ -268,7 +266,7 @@ class RBLNModelForDepthEstimation(RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
         class ImageModelWrapper(nn.Module):
             def __init__(self, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
                 super().__init__()
@@ -280,60 +278,3 @@ class RBLNModelForDepthEstimation(RBLNImageModel):
                 return output.predicted_depth
 
         return ImageModelWrapper(model, rbln_config).eval()
-
-
-class RBLNModelForAudioClassification(RBLNModel):
-    """
-    This is a generic model class that will be instantiated as one of the model classes of the library (with a audio classification head) when created with the from_pretrained() class method
-    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    A class to convert and run pre-trained transformers based AudioClassification models on RBLN devices.
-    It implements the methods to convert a pre-trained transformers AudioClassification model into a RBLN transformer model by:
-
-    - transferring the checkpoint weights of the original into an optimized RBLN graph,
-    - compiling the resulting graph using the RBLN compiler.
-
-    Currently, this model class only supports the 'AST' model from the transformers library. Future updates may include support for additional model types.
-    """
-
-    auto_model_class = AutoModelForAudioClassification
-
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: "AutoFeatureExtractor" = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: "PretrainedConfig" = None,
-        rbln_config: Optional[RBLNModelForAudioClassificationConfig] = None,
-    ) -> RBLNModelForAudioClassificationConfig:
-        if rbln_config.num_mel_bins is None:
-            rbln_config.num_mel_bins = getattr(model_config, "num_mel_bins", None)
-            if rbln_config.num_mel_bins is None:
-                for feature_extractor in preprocessors:
-                    if hasattr(feature_extractor, "num_mel_bins"):
-                        rbln_config.num_mel_bins = feature_extractor.num_mel_bins
-                        break
-
-        if rbln_config.num_mel_bins is None:
-            raise ValueError("`num_mel_bins` should be specified!")
-
-        if rbln_config.max_length is None:
-            rbln_config.max_length = getattr(model_config, "max_length", None)
-            for feature_extractor in preprocessors:
-                if hasattr(feature_extractor, "max_length"):
-                    rbln_config.max_length = feature_extractor.max_length
-                    break
-
-        if rbln_config.max_length is None:
-            raise ValueError("`max_length` should be specified!")
-
-        input_info = [
-            (
-                "input_values",
-                [rbln_config.batch_size, rbln_config.max_length, rbln_config.num_mel_bins],
-                "float32",
-            ),
-        ]
-
-        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
-        return rbln_config
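For reference, the deleted `_update_rbln_config` resolved its static input shape with a config-then-preprocessor fallback before building `input_info`. A simplified sketch of the `num_mel_bins` resolution order follows (the `max_length` path differed slightly, letting the feature extractor override the model config); the function name here is illustrative, not part of the remaining API:

```python
from typing import Any, Iterable, Optional


def resolve_param(name: str, explicit: Optional[int], model_config: Any, preprocessors: Iterable[Any]) -> int:
    # Fallback order used by the removed audio-classification path:
    # explicit rbln_config value -> HF model config -> first feature extractor defining it.
    value = explicit if explicit is not None else getattr(model_config, name, None)
    if value is None:
        for feature_extractor in preprocessors:
            if hasattr(feature_extractor, name):
                value = getattr(feature_extractor, name)
                break
    if value is None:
        raise ValueError(f"`{name}` should be specified!")
    return value


# e.g. num_mel_bins would come from the AST feature extractor when absent elsewhere:
# resolve_param("num_mel_bins", None, model_config, [feature_extractor])
```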
@@ -18,12 +18,15 @@ from typing import Optional, Tuple
 import torch
 from transformers.modeling_outputs import ModelOutput
 
+from ..configuration_utils import RBLNModelConfig
+
 
 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
     generate_idx: torch.Tensor = None
     padded_cache_lengths: int = None
+    hidden_states: Tuple[torch.FloatTensor] = None
 
 
 @dataclass
@@ -35,3 +38,26 @@ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
 class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
     last_hidden_states: torch.FloatTensor = None
     params: Tuple[torch.FloatTensor] = None
+
+
+def _validate_output_hidden_states(output_hidden_states: Optional[bool], rbln_config: RBLNModelConfig):
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else rbln_config.output_hidden_states
+    )
+    if output_hidden_states != rbln_config.output_hidden_states:
+        raise ValueError(
+            f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {rbln_config.output_hidden_states} "
+            f"Please compile again with the correct argument."
+        )
+
+    return output_hidden_states
+
+
+def _validate_output_attentions(output_attentions: Optional[bool], rbln_config: RBLNModelConfig):
+    output_attentions = output_attentions if output_attentions is not None else rbln_config.output_attentions
+    if output_attentions != rbln_config.output_attentions:
+        raise ValueError(
+            f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {rbln_config.output_attentions} "
+            f"Please compile again with the correct argument."
+        )
+    return output_attentions
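The two validators added above exist because an RBLN graph fixes its output set at compile time, so a runtime flag can only echo the compiled value. A minimal usage sketch, assuming the two private helpers above are in scope and standing in a plain namespace for the real `RBLNModelConfig`:

```python
from types import SimpleNamespace

# Assumes the helpers above are importable, e.g.:
# from optimum.rbln.transformers.modeling_outputs import (
#     _validate_output_attentions, _validate_output_hidden_states,
# )

# Hypothetical stand-in for a compiled model's rbln_config.
rbln_config = SimpleNamespace(output_hidden_states=False, output_attentions=False)


def forward(output_hidden_states=None, output_attentions=None):
    # None falls back to the compile-time setting; an explicit mismatch raises.
    output_hidden_states = _validate_output_hidden_states(output_hidden_states, rbln_config)
    output_attentions = _validate_output_attentions(output_attentions, rbln_config)
    return output_hidden_states, output_attentions


forward()                           # OK: (False, False), the compiled defaults
forward(output_hidden_states=True)  # ValueError: recompile with output_hidden_states=True
```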