optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/modeling_attention_utils.py (new file)
@@ -0,0 +1,385 @@
+ import math
+ from collections import Counter, defaultdict
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
+ 
+ import rebel
+ 
+ from ..utils.logging import get_logger
+ from ..utils.runtime_utils import get_available_dram
+ from .models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ 
+ 
+ logger = get_logger()
+ 
+ if TYPE_CHECKING:
+     from transformers import PretrainedConfig, PreTrainedModel
+ 
+ 
+ DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
+ DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
+ MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
+ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
+ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
+ MAX_SLIDING_WINDOW_SIZE = 32_768
+ 
+ 
+ def set_default_values(
+     attn_impl: Optional[str] = None,
+     kvcache_partition_len: Optional[int] = None,
+     kvcache_block_size: Optional[int] = None,
+     max_seq_len: Optional[int] = None,
+ ) -> Tuple[str, int, int]:
+     if attn_impl is None:
+         attn_impl = "eager"
+ 
+     if kvcache_partition_len is not None:
+         if attn_impl == "eager":
+             attn_impl = "flash_attn"
+             logger.warning(
+                 "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
+                 "was set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
+                 "`attn_impl` has been automatically switched to 'flash_attn'."
+             )
+ 
+     if kvcache_partition_len is None and attn_impl == "flash_attn":
+         kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+ 
+     if kvcache_block_size is None:
+         if attn_impl == "eager":
+             kvcache_block_size = max_seq_len
+         else:
+             kvcache_block_size = kvcache_partition_len
+ 
+     return attn_impl, kvcache_partition_len, kvcache_block_size
+ 
+ 
+ def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
+     if attn_impl not in ["eager", "flash_attn"]:
+         raise ValueError(f"Unknown `attn_impl`: {attn_impl}. (Available: 'eager', 'flash_attn')")
+ 
+     ## Checking constraints...
+     # Constraint of eager attention:
+     # - `max_seq_len` <= 32k
+ 
+     # Constraints of flash attention:
+     # 1. `max_seq_len` should be a multiple of `partition_len`.
+     # 2. 4k <= `partition_len` <= 32k.
+     # 3. `max_seq_len` should be at least 8k.
+     if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+         raise ValueError(
+             f"`max_seq_len` is set to {max_seq_len}, "
+             f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
+             f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+             " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
+         )
+ 
+     if attn_impl == "flash_attn":
+         if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
+             raise ValueError(
+                 f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}), "
+                 f"and at least twice its value, when using 'flash_attn'. Please adjust either value to meet these requirements."
+             )
+         elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+             raise ValueError(
+                 f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                 f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+                 f"Please provide a valid value within this range."
+             )
+         elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+             raise ValueError(
+                 f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
+                 f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
+                 "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
+             )
+ 
+     if kvcache_block_size is not None:
+         if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
+             raise ValueError(
+                 f"When using 'flash_attn', the `kvcache_block_size` ({kvcache_block_size}) "
+                 f"must always be set equal to the `kvcache_partition_len` ({kvcache_partition_len})."
+             )
+         elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
+             raise ValueError(
+                 f"When using 'eager' attention, the `kvcache_block_size` ({kvcache_block_size}) "
+                 f"must always be set equal to the `max_seq_len` ({max_seq_len})."
+             )
+ 
+ 
+ def validate_sliding_window(rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+     if rbln_config.sliding_window > MAX_SLIDING_WINDOW_SIZE - rbln_config.prefill_chunk_size:
+         raise ValueError(
+             f"Sliding window size ({rbln_config.sliding_window}) must not exceed {MAX_SLIDING_WINDOW_SIZE} - prefill_chunk_size ({MAX_SLIDING_WINDOW_SIZE - rbln_config.prefill_chunk_size})."
+         )
+ 
+     if rbln_config.cache_impl == "sliding_window" and rbln_config.use_attention_mask:
+         raise ValueError("`use_attention_mask` must be set to False when `cache_impl` is set to 'sliding_window'.")
+ 
+ 
+ def align(x: int, nbytes: int) -> int:
+     return int(math.ceil(x / nbytes) * nbytes)
+ 
+ 
+ def align_2MB(x: int) -> int:
+     return align(x, 2**21)
+ 
+ 
+ def get_alloc_memory_by_key(compiled_models: Dict[str, "rebel.RBLNCompiledModel"]) -> Dict[str, int]:
+     alloc_memory_by_key = defaultdict(int)
+     # Get the actual memory allocation of each node by key
+     for compiled_model in compiled_models.values():
+         alloc_per_node_by_key = compiled_model.get_alloc_per_node_by_key()
+         for key, memory_per_node in alloc_per_node_by_key.items():
+             alloc_memory_by_key[key] += sum(memory_per_node)
+ 
+     return alloc_memory_by_key
+ 
+ 
+ def format_byte_size(nbytes: int) -> str:
+     if nbytes < 1024:
+         return f"{nbytes} B"
+     elif nbytes < 1024**2:
+         return f"{nbytes / 1024:.2f} KB"
+     elif nbytes < 1024**3:
+         return f"{nbytes / 1024**2:.2f} MB"
+     else:
+         return f"{nbytes / 1024**3:.2f} GB"
+ 
+ 
+ class RBLNDecoderOnlyFlashAttentionMixin:
+     @classmethod
+     def get_maximum_num_blocks_by_model(
+         cls,
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> int:
+         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
+         available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size
+ 
+         kernel_memory = cls._get_kernel_memory(model, model_config=model_config, rbln_config=rbln_config)
+         buffer = cls._get_buffer(rbln_config)
+ 
+         remaining_dram = available_dram - kernel_memory - buffer
+         if remaining_dram <= 0:
+             raise ValueError(
+                 "Insufficient available DRAM after accounting for kernel memory and buffer. "
+                 "Cannot allocate any KV cache blocks."
+                 f" (Available DRAM: {format_byte_size(available_dram)}, "
+                 f"Kernel Memory: {format_byte_size(kernel_memory)}, "
+                 f"Buffer: {format_byte_size(buffer)})"
+             )
+         estimated_num_blocks = cls._estimate_num_blocks(
+             remaining_dram, model_config=model_config, rbln_config=rbln_config
+         )
+ 
+         return estimated_num_blocks
+ 
+     @classmethod
+     def _get_kernel_memory(
+         cls,
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> int:
+         if model.get_output_embeddings() is None:
+             lm_head_nbytes = 0
+         else:
+             lm_head_nbytes = cls._get_lm_head_memory(model_config, rbln_config)
+ 
+         layer_nbytes = cls._get_layer_memory(model, model_config, rbln_config)
+         return lm_head_nbytes + layer_nbytes
+ 
+     @classmethod
+     def _get_lm_head_memory(
+         cls, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ) -> int:
+         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
+         vocab_size = model_config.vocab_size
+         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+         lm_head_params = align(vocab_size, 64) * hidden_size
+ 
+         nbytes_per_param = 2  # Assuming the lm_head is never quantized
+         lm_head_memory_in_bytes = (
+             align_2MB(lm_head_params * nbytes_per_param / tensor_parallel_size) * tensor_parallel_size
+         )
+ 
+         return lm_head_memory_in_bytes
+ 
+     @classmethod
+     def _get_layer_memory(
+         cls,
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> int:
+         # This is an *APPROXIMATE* calculation based on the number of parameters
+         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
+         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+ 
+         n_model_params = sum(p.numel() for p in model.parameters())
+         embed_token_params = sum(p.numel() for p in model.get_input_embeddings().parameters())
+ 
+         # Check whether `embed_token` is the same as `lm_head`
+         if model.get_output_embeddings() is not None:
+             params = n_model_params - 2 * embed_token_params
+         else:
+             params = n_model_params - embed_token_params
+ 
+         # Assuming all layers have the same number of parameters,
+         # and all linear layers are quantized if quantization is enabled (this is not always true).
+         # TODO(jongho): More accurate calculation
+         nbits_per_param = rbln_config.nbits_per_param
+         layer_nbytes = (
+             (align_2MB(params // num_hidden_layers * nbits_per_param // 8 / tensor_parallel_size))
+             * num_hidden_layers
+             * tensor_parallel_size
+         )
+ 
+         return layer_nbytes
+ 
+     @classmethod
+     def _get_buffer(cls, rbln_config) -> int:
+         # TODO(jongho): Accurate buffer estimation
+         buffer_per_runtime_per_core = 2**28  # 256MB per runtime
+         num_runtimes = 1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes)
+         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
+ 
+         buffer_per_core = buffer_per_runtime_per_core * num_runtimes
+         buffer = buffer_per_core * tensor_parallel_size
+         return buffer
+ 
+     @classmethod
+     def get_maximum_num_blocks_by_compiled_model(
+         cls,
+         compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> int:
+         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
+         available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size
+ 
+         alloc_memory_by_key = get_alloc_memory_by_key(compiled_models)
+         alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache key
+         alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
+         used_memory = sum(alloc_memory_by_key.values())
+ 
+         remaining_dram = available_dram - used_memory
+ 
+         if remaining_dram <= 0:
+             logger.warning(
+                 "Insufficient available DRAM after accounting for kernel memory and buffer. "
+                 "Model cannot allocate any KV cache blocks."
+             )
+ 
+         estimated_num_blocks = cls._estimate_num_blocks(
+             remaining_dram, model_config=model_config, rbln_config=rbln_config
+         )
+ 
+         return estimated_num_blocks
+ 
+     @classmethod
+     def _estimate_num_blocks(
+         cls, available_dram: int, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ) -> int:
+         """
+         Estimate the maximum number of KV cache blocks that can be allocated.
+ 
+         If all of the layers are full attention, the number of blocks can be calculated simply as:
+             num_blocks = available_dram // dram_per_block
+ 
+         However, if the model contains a mix of full attention and sliding window attention layers,
+         we need to account for the memory occupied by the sliding window attention layers first,
+         since their memory usage is constant regardless of the number of blocks:
+             num_blocks = (available_dram - swa_kv_nbytes) // dram_per_block
+         """
+ 
+         def get_dram_per_block(seq_len: int, num_key_value_heads: int, tensor_parallel_size: int) -> int:
+             nbytes_per_param = 2  # Assuming the kv-cache is never quantized
+             dram_per_block = (
+                 seq_len
+                 * align(head_dim, 64)
+                 * math.ceil(num_key_value_heads / tensor_parallel_size)
+                 * nbytes_per_param
+                 * tensor_parallel_size
+                 * 2
+             )  # *2 for key and value
+ 
+             return dram_per_block
+ 
+         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+         head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
+         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+         tensor_parallel_size = rbln_config.tensor_parallel_size or 1
+ 
+         # Consider layer types if available.
+         # If layer types are not found, assume all layers are full attention.
+         layer_types = getattr(model_config, "layer_types", None)
+         if layer_types:
+             layer_types_dict = Counter(layer_types)
+             num_full_attention = layer_types_dict.pop("full_attention", 0)
+             num_sliding_window_attention = layer_types_dict.pop("sliding_attention", 0)
+             if len(layer_types_dict) > 0:
+                 raise ValueError(f"Unknown layer types found in the config: {layer_types_dict.keys()}")
+ 
+         else:
+             num_full_attention = num_hidden_layers
+             num_sliding_window_attention = 0
+ 
+         # Reduce available DRAM by the sliding window attention kv-cache,
+         # since the memory occupied by a sliding window layer is constant regardless of num_blocks.
+         swa_kv_nbytes = 0
+         if num_sliding_window_attention > 0:
+             sliding_window = getattr(model_config, "sliding_window", None)
+             if sliding_window is None:
+                 logger.warning(
+                     "`sliding_window` was not found in the config although `sliding_attention` layers are present. "
+                     "Assuming the maximum sliding window size for estimation."
+                 )
+                 sliding_window = rbln_config.kvcache_block_size
+ 
+             swa_kv_nbytes = num_sliding_window_attention * get_dram_per_block(
+                 seq_len=sliding_window,
+                 num_key_value_heads=num_key_value_heads,
+                 tensor_parallel_size=tensor_parallel_size,
+             )
+ 
+         available_dram -= swa_kv_nbytes
+ 
+         dram_per_block = num_full_attention * get_dram_per_block(
+             seq_len=rbln_config.kvcache_block_size,
+             num_key_value_heads=num_key_value_heads,
+             tensor_parallel_size=tensor_parallel_size,
+         )
+ 
+         if dram_per_block == 0:
+             raise ValueError("DRAM per block is calculated as zero; cannot estimate the maximum number of blocks.")
+ 
+         max_n_blocks = available_dram // dram_per_block
+         return max_n_blocks
+ 
+     @classmethod
+     def maybe_suggest_kvcache_num_blocks(
+         cls,
+         compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> None:
+         max_num_blocks = cls.get_maximum_num_blocks_by_compiled_model(
+             compiled_models=compiled_models,
+             model_config=model_config,
+             rbln_config=rbln_config,
+         )
+ 
+         # Since our estimation logic is not always accurate,
+         # users can set `kvcache_num_blocks` to `max_num_blocks`.
+         # If the memory is not enough, the model will fail to compile.
+         if rbln_config.kvcache_num_blocks < max_num_blocks:
+             logger.warning(
+                 f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                 "Our analysis indicates that additional memory is available for more blocks. "
+                 f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                 "Please be advised that our memory estimation algorithm has limitations, "
+                 "and increasing this value may not guarantee successful model compilation."
+             )
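
Taken together, `set_default_values` and `validate_attention_method` resolve and then check the attention configuration. A minimal sketch of how the two compose (values are illustrative; the import path follows the new file's location in the table above):

from optimum.rbln.transformers.modeling_attention_utils import (
    set_default_values,
    validate_attention_method,
)

# With only max_seq_len given, attn_impl defaults to "eager" and the
# KV cache block size collapses to the full sequence length.
attn_impl, partition_len, block_size = set_default_values(max_seq_len=8_192)
assert (attn_impl, partition_len, block_size) == ("eager", None, 8_192)

# Passing kvcache_partition_len upgrades "eager" to "flash_attn" (with a
# warning) and reuses the partition length as the block size.
attn_impl, partition_len, block_size = set_default_values(
    kvcache_partition_len=16_384, max_seq_len=32_768
)
assert (attn_impl, partition_len, block_size) == ("flash_attn", 16_384, 16_384)

# Passes: 32_768 is a multiple of (and at least twice) 16_384, which in
# turn lies inside the supported 4k-32k partition range.
validate_attention_method(attn_impl, partition_len, block_size, max_seq_len=32_768)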
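
The per-block memory arithmetic in `_estimate_num_blocks` can be checked by hand. A worked example under hypothetical model dimensions (32 full-attention layers, 8 KV heads, `head_dim` 128, 16,384-token blocks, fp16 cache, no tensor parallelism):

import math

def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)

# Hypothetical dimensions; real values come from the model config.
seq_len, head_dim, num_kv_heads, tp, num_layers = 16_384, 128, 8, 1, 32

per_layer = (
    seq_len * align(head_dim, 64) * math.ceil(num_kv_heads / tp) * 2 * tp * 2
)  # 2 bytes per fp16 param, *2 for key and value -> 67_108_864 (64 MiB)

dram_per_block = num_layers * per_layer  # 2_147_483_648 (2 GiB per block)

# With, say, 12 GiB of DRAM left over after weights and buffers:
available_dram = 12 * 2**30
print(available_dram // dram_per_block)  # 6 KV cache blocks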
optimum/rbln/transformers/modeling_generic.py
@@ -23,9 +23,9 @@ different model architectures.
  import inspect
  from typing import TYPE_CHECKING, Optional, Union
  
+ from torch import nn
  from transformers import (
      AutoModel,
-     AutoModelForAudioClassification,
      AutoModelForDepthEstimation,
      AutoModelForImageClassification,
      AutoModelForMaskedLM,
@@ -34,17 +34,13 @@ from transformers import (
      AutoModelForTextEncoding,
      PretrainedConfig,
  )
- from transformers.modeling_outputs import (
-     BaseModelOutput,
-     QuestionAnsweringModelOutput,
- )
+ from transformers.modeling_outputs import BaseModelOutput, QuestionAnsweringModelOutput
  
  from ..configuration_utils import RBLNCompileConfig
  from ..modeling import RBLNModel
  from ..utils.logging import get_logger
  from .configuration_generic import (
      RBLNImageModelConfig,
-     RBLNModelForAudioClassificationConfig,
      RBLNTransformerEncoderConfig,
  )
@@ -60,6 +56,28 @@ class RBLNTransformerEncoder(RBLNModel):
      rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
      rbln_dtype = "int64"
  
+     @classmethod
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
+         class TransformerEncoderWrapper(nn.Module):
+             # Parameters to disable for RBLN compilation
+             DISABLED_PARAMS = {"return_dict", "use_cache"}
+ 
+             def __init__(self, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig):
+                 super().__init__()
+                 self.model = model
+                 self.rbln_config = rbln_config
+                 self._forward_signature = inspect.signature(model.forward)
+ 
+             def forward(self, *args, **kwargs):
+                 # Disable parameters that are not compatible with RBLN compilation
+                 for param_name in self.DISABLED_PARAMS:
+                     if param_name in self._forward_signature.parameters:
+                         kwargs[param_name] = False
+ 
+                 return self.model(*args, **kwargs)
+ 
+         return TransformerEncoderWrapper(model, rbln_config).eval()
+ 
      @classmethod
      def _update_rbln_config(
          cls,
@@ -130,10 +148,18 @@ class RBLNTransformerEncoder(RBLNModel):
              "This is an internal error. Please report it to the developers."
          )
  
-         input_info = [
-             (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
-             for model_input_name in rbln_config.model_input_names
-         ]
+         if rbln_config.model_input_shapes is None:
+             input_info = [
+                 (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
+                 for model_input_name in rbln_config.model_input_names
+             ]
+         else:
+             input_info = [
+                 (model_input_name, model_input_shape, cls.rbln_dtype)
+                 for model_input_name, model_input_shape in zip(
+                     rbln_config.model_input_names, rbln_config.model_input_shapes
+                 )
+             ]
  
          rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
          return rbln_config
@@ -203,7 +229,6 @@ class RBLNModelForQuestionAnswering(RBLNTransformerEncoder):
  
      def _prepare_output(self, output, return_dict):
          # Prepare QuestionAnswering specific output format.
- 
          start_logits, end_logits = output
  
          if not return_dict:
@@ -240,58 +265,16 @@ class RBLNModelForImageClassification(RBLNImageModel):
  class RBLNModelForDepthEstimation(RBLNImageModel):
      auto_model_class = AutoModelForDepthEstimation
  
- 
- class RBLNModelForAudioClassification(RBLNModel):
-     """
-     This is a generic model class that will be instantiated as one of the model classes of the library (with a audio classification head) when created with the from_pretrained() class method
-     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
- 
-     A class to convert and run pre-trained transformers based AudioClassification models on RBLN devices.
-     It implements the methods to convert a pre-trained transformers AudioClassification model into a RBLN transformer model by:
-     - transferring the checkpoint weights of the original into an optimized RBLN graph,
-     - compiling the resulting graph using the RBLN compiler.
- 
-     Currently, this model class only supports the 'AST' model from the transformers library. Future updates may include support for additional model types.
-     """
- 
-     auto_model_class = AutoModelForAudioClassification
- 
      @classmethod
-     def _update_rbln_config(
-         cls,
-         preprocessors: "AutoFeatureExtractor" = None,
-         model: Optional["PreTrainedModel"] = None,
-         model_config: "PretrainedConfig" = None,
-         rbln_config: Optional[RBLNModelForAudioClassificationConfig] = None,
-     ) -> RBLNModelForAudioClassificationConfig:
-         if rbln_config.num_mel_bins is None:
-             rbln_config.num_mel_bins = getattr(model_config, "num_mel_bins", None)
-             if rbln_config.num_mel_bins is None:
-                 for feature_extractor in preprocessors:
-                     if hasattr(feature_extractor, "num_mel_bins"):
-                         rbln_config.num_mel_bins = feature_extractor.num_mel_bins
-                         break
- 
-         if rbln_config.num_mel_bins is None:
-             raise ValueError("`num_mel_bins` should be specified!")
- 
-         if rbln_config.max_length is None:
-             rbln_config.max_length = getattr(model_config, "max_length", None)
-             for feature_extractor in preprocessors:
-                 if hasattr(feature_extractor, "max_length"):
-                     rbln_config.max_length = feature_extractor.max_length
-                     break
- 
-         if rbln_config.max_length is None:
-             raise ValueError("`max_length` should be specified!")
- 
-         input_info = [
-             (
-                 "input_values",
-                 [rbln_config.batch_size, rbln_config.max_length, rbln_config.num_mel_bins],
-                 "float32",
-             ),
-         ]
- 
-         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
-         return rbln_config
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
+         class ImageModelWrapper(nn.Module):
+             def __init__(self, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
+                 super().__init__()
+                 self.model = model
+                 self.rbln_config = rbln_config
+ 
+             def forward(self, *args, **kwargs):
+                 output = self.model(*args, return_dict=True, **kwargs)
+                 return output.predicted_depth
+ 
+         return ImageModelWrapper(model, rbln_config).eval()
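
The `_wrap_model_if_needed` additions above follow one pattern: wrap the Hugging Face module in a thin `nn.Module` that pins trace-hostile keyword arguments to fixed values before compilation. A standalone sketch of the same idea (the checkpoint name is illustrative, not part of this package):

import inspect

import torch
from torch import nn
from transformers import AutoModel

class TraceFriendlyWrapper(nn.Module):
    # Keyword arguments forced off so the traced graph has a static signature
    DISABLED_PARAMS = {"return_dict", "use_cache"}

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self._sig = inspect.signature(model.forward)

    def forward(self, *args, **kwargs):
        # Only override kwargs the wrapped forward actually accepts
        for name in self.DISABLED_PARAMS:
            if name in self._sig.parameters:
                kwargs[name] = False
        return self.model(*args, **kwargs)

model = AutoModel.from_pretrained("bert-base-uncased")
wrapped = TraceFriendlyWrapper(model).eval()
with torch.no_grad():
    outputs = wrapped(input_ids=torch.ones(1, 8, dtype=torch.int64))  # plain tuple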
optimum/rbln/transformers/modeling_outputs.py (new file)
@@ -0,0 +1,37 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+ 
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+ 
+ # http://www.apache.org/licenses/LICENSE-2.0
+ 
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+ 
+ import torch
+ from transformers.modeling_outputs import ModelOutput
+ 
+ 
+ @dataclass
+ class RBLNDecoderOnlyOutput(ModelOutput):
+     logits: torch.FloatTensor = None
+     generate_idx: torch.Tensor = None
+     padded_cache_lengths: int = None
+ 
+ 
+ @dataclass
+ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
+     attention_mask: Optional[torch.Tensor] = None
+ 
+ 
+ @dataclass
+ class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
+     last_hidden_states: torch.FloatTensor = None
+     params: Tuple[torch.FloatTensor] = None
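
These are standard `transformers`-style `ModelOutput` dataclasses, so fields are reachable by attribute, by key, or via `to_tuple()`. A minimal sketch (shapes are illustrative):

import torch
from optimum.rbln.transformers.modeling_outputs import RBLNDecoderOnlyOutput

output = RBLNDecoderOnlyOutput(
    logits=torch.zeros(1, 1, 32_000),  # (batch, seq, vocab)
    generate_idx=torch.tensor([[8]]),  # next decode position per sequence
    padded_cache_lengths=0,
)
print(output.logits.shape)             # torch.Size([1, 1, 32000])
logits, generate_idx, padded = output.to_tuple()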