optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -13,7 +13,7 @@
  # limitations under the License.

  import math
- from typing import List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

  import torch
  from torch import nn
@@ -21,106 +21,16 @@ from transformers import PretrainedConfig, PreTrainedModel

  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from .configuration_decoderonly import CacheImplType
+ from ...utils.rbln_quantization import RBLNQuantizationConfig
+ from .configuration_lora import RBLNLoRAConfig
+ from .lora_architecture import LoRALinear


- logger = logging.get_logger(__name__)
-
- DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
- DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
- MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
- MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
- MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
- MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
- def set_default_values(
- attn_impl: Optional[str] = None,
- kvcache_partition_len: Optional[int] = None,
- kvcache_block_size: Optional[int] = None,
- max_seq_len: Optional[int] = None,
- ) -> Tuple[str, int, int]:
- if attn_impl is None:
- attn_impl = "eager"
-
- if kvcache_partition_len is not None:
- if attn_impl == "eager":
- attn_impl = "flash_attn"
- logger.warning(
- "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
- "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
- "`attn_impl` has been automatically switched to 'flash_attn'."
- )
-
- if kvcache_partition_len is None and attn_impl == "flash_attn":
- kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
- if kvcache_block_size is None:
- if attn_impl == "eager":
- kvcache_block_size = max_seq_len
- else:
- kvcache_block_size = kvcache_partition_len
-
- return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
- def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
- if attn_impl not in ["eager", "flash_attn"]:
- raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
-
- ## Checking Constraints...
- # Constraint of eager attention:
- # - `max_seq_len` <= 32k
-
- # Constraints of flash attention:
- # 1. `max_seq_len` should be multiple of `partition_len`.
- # 2. 4k <= `partition_len` <= 32k.
- # 3. `max_seq_len` should be larger then 8k.
- if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
- raise ValueError(
- f"`max_seq_len` is set to {max_seq_len}, "
- f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
- f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
- " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
- )
-
- if attn_impl == "flash_attn":
- if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
- raise ValueError(
- f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
- f"when using 'flash_attn'. Please adjust either value to meet this requirement."
- )
- elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
- raise ValueError(
- f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
- f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
- f"Please provide a valid value within this range."
- )
- elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
- raise ValueError(
- f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
- f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
- "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
- )
-
- if kvcache_block_size is not None:
- if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
- raise ValueError(
- f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
- f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
- )
- elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
- raise ValueError(
- f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
- f"must always be set equal to the `max_seq_len` {max_seq_len}."
- )
+ if TYPE_CHECKING:
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig


- def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
- if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
- raise ValueError(
- f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
- )
+ logger = logging.get_logger(__name__)


  class DecoderOnlyWrapper(nn.Module):
@@ -137,40 +47,22 @@ class DecoderOnlyWrapper(nn.Module):
  - Wrapper should not contain neural network graph operations (including memory view handling)

  Args:
- causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
- max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
+ model (PreTrainedModel): The Huggingface causal language model to wrap
+ rbln_config: The RBLN model configuration containing all necessary parameters
  use_rotary_emb (bool): Whether to use rotary position embeddings
- attn_impl (str): The attention implementation to use.
- - "eager": Uses the standard attention.
- - "flash_attn": Uses flash attention. When set,
- the key/value cache is partitioned into chunks of length
- `kvcache_partition_len`.
- kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
- This is only relevant if `attn_impl` is set to "flash_attn`
  """

  _use_learned_pos_emb = False

- def __init__(
- self,
- causal_lm: PreTrainedModel,
- max_seq_len: int,
- use_rotary_emb: bool,
- attn_impl: str,
- cache_impl: CacheImplType,
- use_inputs_embeds: bool,
- use_attention_mask: bool,
- use_position_ids: bool,
- kvcache_partition_len: Optional[int] = None,
- kvcache_block_size: Optional[int] = None,
- sliding_window: Optional[int] = None,
- sliding_window_layers: Optional[List[int]] = None,
- ):
+ def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
  super().__init__()
- self.config = causal_lm.config
+ self.quantization = rbln_config.quantization
+ self.config = model.config
+ self.is_causal_lm = getattr(model, "lm_head", None) is not None
+ self.rbln_config = rbln_config

  if use_rotary_emb:
- rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+ rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
  if isinstance(rotary_embs, tuple):
  self.rotary_emb_global, self.rotary_emb_local = rotary_embs
  else:
@@ -178,43 +70,27 @@ class DecoderOnlyWrapper(nn.Module):
  else:
  self.rotary_emb = None

- self.attn_impl = attn_impl
- self.kvcache_block_size = kvcache_block_size
- self.use_attention_mask = use_attention_mask
- self.use_position_ids = use_position_ids
- self.use_inputs_embeds = use_inputs_embeds
- self.sliding_window_layers = sliding_window_layers
- self.cache_impl = cache_impl
- self.sliding_window = sliding_window
-
- if self.attn_impl == "flash_attn":
- self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
- elif self.attn_impl == "eager":
- self.kvcache_partition_len = None
- else:
- raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
-
- if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+ if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
  raise ValueError(
- f"kvcache_partition_len({kvcache_partition_len}) should be lower"
- f" or equal to max_seq_len({max_seq_len})!"
+ f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
+ f" or equal to max_seq_len({rbln_config.max_seq_len})!"
  )

- self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
+ self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
  self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
  self._phase = "prefill"

  def get_rotary_emb(self, max_seq_len):
  return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

- def get_decoder_layers(self, causal_lm: PreTrainedModel):
- return causal_lm.model.layers
+ def get_decoder_layers(self, model: PreTrainedModel):
+ return model.model.layers if self.is_causal_lm else model.layers

  def get_attn_layer(self, layer: nn.Module):
  return layer.self_attn

- def get_model_layer(self, causal_lm: PreTrainedModel):
- return causal_lm.model
+ def get_model_layer(self, model: PreTrainedModel):
+ return model.model if self.is_causal_lm else model

  def get_rbln_attn_class(self):
  return DecoderOnlyAttention
@@ -228,35 +104,28 @@ class DecoderOnlyWrapper(nn.Module):
  def get_rbln_causal_lm_class(self):
  return DecoderOnlyForCausalLM

- def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
+ def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
  new_layers = []
- for layer_idx, layer in enumerate(self.get_decoder_layers(causal_lm)):
- is_sliding = layer_idx in self.sliding_window_layers
+ for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+ is_sliding = layer_idx in self.rbln_config.sliding_window_layers
  new_self_attn = self.get_rbln_attn_class()(
- self.get_attn_layer(layer),
- self.use_attention_mask if not is_sliding else True,
- self.use_position_ids,
- kvcache_block_size=self.sliding_window
- if layer_idx in self.sliding_window_layers
- else self.kvcache_block_size,
- is_sliding=is_sliding,
- attn_impl=self.attn_impl if not is_sliding else "eager",
- kvcache_partition_len=self.kvcache_partition_len,
+ self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
  )
- new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+ new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
  new_layers.append(new_layer)

  new_model = self.get_rbln_model_class()(
- self.get_model_layer(causal_lm),
+ self.get_model_layer(model),
  new_layers,
- partition_len=self.kvcache_partition_len,
- max_seq_len=max_seq_len,
- kvcache_block_size=self.kvcache_block_size,
+ self.rbln_config,
  use_learned_pos_emb=self.__class__._use_learned_pos_emb,
- sliding_window_layers=self.sliding_window_layers,
  )
- new_causal_lm = self.get_rbln_causal_lm_class()(causal_lm, new_model)
- return new_causal_lm
+
+ if self.is_causal_lm:
+ new_model = self.get_rbln_causal_lm_class()(model, new_model)
+ return new_model
+ else:
+ return new_model

  @property
  def phase(self) -> str:
@@ -265,18 +134,24 @@ class DecoderOnlyWrapper(nn.Module):
  @phase.setter
  def phase(self, phase: str):
  self._phase = phase
- self.causal_lm.phase = phase
+ self.model.phase = phase

  def prepare_forward_args(self, *args):
  args = list(args)
- input_ids = None if self.use_inputs_embeds else args.pop(0)
- inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+ input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+ inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
  cache_position = args.pop(0)
- global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
- local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
- query_position = args.pop(0) if "prefill" in self.phase else None
- attention_mask = args.pop(0) if self.use_attention_mask else None
- position_ids = args.pop(0) if self.use_position_ids else None
+ global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
+ local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
+ query_position = (
+ args.pop(0)
+ # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+ if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+ else None
+ )
+ attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+ position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+ lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
  past_key_values = args

  if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -308,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
  query_position,
  attention_mask,
  position_ids,
+ lora_int_id,
  past_key_values,
  rotary_emb,
  )
@@ -322,11 +198,12 @@ class DecoderOnlyWrapper(nn.Module):
  query_position,
  attention_mask,
  position_ids,
+ lora_int_id,
  past_key_values,
  rotary_emb,
  ) = self.prepare_forward_args(*args)

- logit = self.causal_lm(
+ logit = self.model(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -337,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
  rotary_emb=rotary_emb,
  global_block_tables=global_block_tables,
  local_block_tables=local_block_tables,
+ lora_int_id=lora_int_id,
  )

  return logit
@@ -393,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  rotary_emb: nn.Module = None,
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  # outputs
  hidden_states = self.model(
@@ -406,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  rotary_emb=rotary_emb,
  global_block_tables=global_block_tables,
  local_block_tables=local_block_tables,
+ lora_int_id=lora_int_id,
  )

  if "prefill" in self.phase:
@@ -428,6 +308,8 @@ class DecoderOnlyModel(nn.Module):
  Args:
  model: Original Huggingface model to adapt
  layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+ rbln_config: RBLN model configuration
+ use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)

  Attributes:
  _original_mod: Reference to original Huggingface model
@@ -439,21 +321,19 @@ class DecoderOnlyModel(nn.Module):
  self,
  model,
  layers: List["DecoderOnlyLayer"],
- partition_len=None,
- max_seq_len=None,
- kvcache_block_size=None,
+ rbln_config: "RBLNDecoderOnlyModelConfig",
  use_learned_pos_emb=None,
- sliding_window_layers=None,
  ):
  super().__init__()
  self._original_mod = model
  self.layers = nn.ModuleList(layers)
+ self.rbln_config = rbln_config
  self._phase = "prefill"
- self.partition_len = partition_len
- self.kvcache_block_size = kvcache_block_size
- self.max_seq_len = max_seq_len
+ self.partition_len = rbln_config.kvcache_partition_len
+ self.kvcache_block_size = rbln_config.kvcache_block_size
+ self.max_seq_len = rbln_config.max_seq_len
  self.use_learned_pos_emb = use_learned_pos_emb
- self.sliding_window_layers = sliding_window_layers
+ self.sliding_window_layers = rbln_config.sliding_window_layers

  @property
  def phase(self):
@@ -517,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
  rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  # retrieve input_ids and inputs_embeds
  if (input_ids is None) ^ (inputs_embeds is not None):
@@ -589,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
  cos=cos,
  sin=sin,
  block_tables=local_block_tables if is_sliding else global_block_tables,
+ lora_int_id=lora_int_id,
  )

  hidden_states = self.get_last_layernorm()(hidden_states)
@@ -620,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
  phase: Current operation phase ("prefill" or "decode")
  """

- def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+ def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
  super().__init__()
  self._original_mod = layer
  self.self_attn = self_attn
  self._phase = "prefill"
+ self.lora_config = lora_config
+
+ # Replace target Linear modules in MLP with LoRALinear if configured
+ if self.lora_config:
+ mlp = self.get_mlp()
+ for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+ if hasattr(mlp, proj_name):
+ original_linear = getattr(mlp, proj_name)
+ if isinstance(original_linear, nn.Linear):
+ lora_linear = LoRALinear(
+ original_linear=original_linear,
+ lora_config=self.lora_config,
+ projection_name=proj_name,
+ layer_idx=self.self_attn.layer_idx,
+ )
+ setattr(mlp, proj_name, lora_linear)

  @property
  def phase(self):
@@ -641,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
  def get_post_attention_layernorm(self) -> nn.LayerNorm:
  return self._original_mod.post_attention_layernorm

+ def get_mlp(self) -> nn.Module:
+ return self._original_mod.mlp
+
+ def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+ mlp = self.get_mlp()
+ if self.lora_config and lora_int_id is not None:
+ gate = mlp.gate_proj(hidden_states, lora_int_id)
+ up = mlp.up_proj(hidden_states, lora_int_id)
+ act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+ if act_fn is None:
+ gate = torch.nn.functional.silu(gate)
+ else:
+ gate = act_fn(gate)
+ fused = gate * up
+ hidden_states = mlp.down_proj(fused, lora_int_id)
+ else:
+ hidden_states = mlp(hidden_states)
+ return hidden_states
+
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -650,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
  block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  residual = hidden_states
  hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -662,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
  cos=cos,
  sin=sin,
  block_tables=block_tables,
+ lora_int_id=lora_int_id,
  )
  hidden_states = residual + hidden_states

  # Fully Connected
  residual = hidden_states
  hidden_states = self.get_post_attention_layernorm()(hidden_states)
- hidden_states = self._original_mod.mlp(hidden_states)
+ hidden_states = self.forward_mlp(hidden_states, lora_int_id)
  hidden_states = residual + hidden_states

  return hidden_states
@@ -683,32 +602,27 @@ class DecoderOnlyAttention(nn.Module):

  Args:
  self_attn: Original attention module from the base model
- use_attention_mask: Whether to use attention mask
- use_position_ids: Whether to use position ids
- kvcache_block_size: Block size for KV cache
+ rbln_config: RBLN model configuration containing attention parameters
  is_sliding: Whether this is sliding window attention
- attn_impl: Attention implementation type ("eager" or "flash_attn")
  """

  def __init__(
  self,
  self_attn,
- use_attention_mask,
- use_position_ids,
- kvcache_block_size,
+ rbln_config: "RBLNDecoderOnlyModelConfig",
  is_sliding=False,
- attn_impl="eager",
- kvcache_partition_len=None,
  ):
  super().__init__()
  self._original_mod = self_attn
+ self.rbln_config = rbln_config
  self.layer_idx = self_attn.layer_idx
  self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
  self._original_mod.config, "num_attention_heads"
  )
  self.head_dim = self._original_mod.head_dim
  self._phase = "prefill"
- self.scale = torch.tensor(self.get_attn_scale())
+ self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+ self.quantization = rbln_config.quantization

  if hasattr(self._original_mod, "num_key_value_heads"):
  self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -717,16 +631,29 @@ class DecoderOnlyAttention(nn.Module):
  else:
  self.num_key_value_heads = self.num_heads

- self.use_attention_mask = use_attention_mask
- self.use_position_ids = use_position_ids
+ self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
+ self.use_position_ids = rbln_config.use_position_ids
  self.is_sliding = is_sliding
- self.attn_impl = attn_impl
- self.kvcache_partition_len = kvcache_partition_len
+ self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
+ self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
+ self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+ self.lora_config = rbln_config.lora_config

  setattr(self, self.get_attention_name(), self.create_attention_op())
- self.kvcache_block_size = kvcache_block_size
  self.__post_init__()

+ def _init_lora_weights(self):
+ """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+ for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+ original_linear = getattr(self._original_mod, proj_name)
+ lora_linear = LoRALinear(
+ original_linear=original_linear,
+ lora_config=self.lora_config,
+ projection_name=proj_name,
+ layer_idx=self.layer_idx,
+ )
+ setattr(self, proj_name, lora_linear)
+
  def get_attention_name(self):
  if self.is_sliding:
  return "sliding_window_attention"
@@ -764,6 +691,7 @@ class DecoderOnlyAttention(nn.Module):
  self.kvcache_partition_len,
  self.use_attention_mask,
  self.use_position_ids,
+ self.quantization,
  )
  elif self.attn_impl == "eager":
  return AttentionOp(
@@ -772,28 +700,46 @@ class DecoderOnlyAttention(nn.Module):
  self.num_key_value_heads,
  self.use_attention_mask,
  self.use_position_ids,
+ self.quantization,
  )
  else:
  raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

  def __post_init__(self):
- self.q_proj = self._original_mod.q_proj
- self.k_proj = self._original_mod.k_proj
- self.v_proj = self._original_mod.v_proj
- self.o_proj = self._original_mod.o_proj
-
- def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Initialize LoRA weights if configured, which will replace linear layers
+ if self.lora_config:
+ self._init_lora_weights()
+ else:
+ # Use original linear layers if no LoRA
+ self.q_proj = self._original_mod.q_proj
+ self.k_proj = self._original_mod.k_proj
+ self.v_proj = self._original_mod.v_proj
+ self.o_proj = self._original_mod.o_proj
+
+ def projection(
+ self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Projects input hidden states into query, key, and value representations.

  Args:
  hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+ lora_int_id: Adapter ID tensor for LoRA selection [batch_size]

  Returns:
  Tuple of (query_states, key_states, value_states)
  """
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+ if self.lora_config:
+ # LoRALinear handles both base projection and LoRA in one forward pass
+ query_states = self.q_proj(hidden_states, lora_int_id)
+ key_states = self.k_proj(hidden_states, lora_int_id)
+ value_states = self.v_proj(hidden_states, lora_int_id)
+ else:
+ # Standard linear projection without LoRA
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
  return query_states, key_states, value_states

  def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
@@ -802,6 +748,16 @@ class DecoderOnlyAttention(nn.Module):
  def get_attn_scale(self):
  return 1 / math.sqrt(self.head_dim)

+ def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+ if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
+ k_scale = getattr(self.k_proj, "k_scale", None)
+ v_scale = getattr(self.v_proj, "v_scale", None)
+ else:
+ k_scale = None
+ v_scale = None
+
+ return k_scale, v_scale
+
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -811,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
  block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  batch_size, query_length, _ = hidden_states.size()

- query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+ query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)

  query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
  key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -831,6 +788,8 @@ class DecoderOnlyAttention(nn.Module):
  if batch_size > 1 and "prefill" in self.phase:
  raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

+ k_scale, v_scale = self.maybe_get_kvcache_scale()
+
  attn_output = self.get_attention_op()(
  query_states,
  key_states,
@@ -842,9 +801,18 @@ class DecoderOnlyAttention(nn.Module):
  scale=self.scale,
  block_tables=block_tables,
  block_size=self.kvcache_block_size,
+ k_scale=k_scale,
+ v_scale=v_scale,
  )

- attn_outputs = self.o_proj(attn_output)
+ # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+ if self.lora_config:
+ # LoRALinear handles both base projection and LoRA in one forward pass
+ attn_outputs = self.o_proj(attn_output, lora_int_id)
+ else:
+ # Standard linear projection without LoRA
+ attn_outputs = self.o_proj(attn_output)
+
  return attn_outputs


@@ -858,7 +826,13 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):

  class AttentionOp(nn.Module):
  def __init__(
- self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+ self,
+ num_heads: int,
+ head_dim: int,
+ num_key_value_heads: int,
+ use_attention_mask: bool,
+ use_position_ids: bool,
+ quantization: Optional[RBLNQuantizationConfig] = None,
  ):
  super().__init__()
  self.num_heads = num_heads
@@ -867,10 +841,10 @@ class AttentionOp(nn.Module):
  self.phase = "prefill"
  self.use_attention_mask = use_attention_mask
  self.use_position_ids = use_position_ids
+ self.quantization = quantization

  def get_attn_op_name(self):
  phase = "decode" if self.phase == "decode" else "prefill"
-
  if self.use_attention_mask and not self.use_position_ids:
  attn_op_name = "paged_attn_"
  else:
@@ -878,6 +852,9 @@ class AttentionOp(nn.Module):

  attn_op_name += phase

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ attn_op_name += "_kv_fp8"
+
  return attn_op_name

  def forward(
@@ -892,6 +869,8 @@ class AttentionOp(nn.Module):
  scale: torch.Tensor,
  block_tables: torch.Tensor,
  block_size: int,
+ k_scale: Optional[torch.Tensor] = None,
+ v_scale: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Compute attention with static shapes and explicit cache management.

@@ -904,6 +883,10 @@
  past_value_state: Previous value cache states
  seq_position: Current position in sequence
  scale: Scale applied to attn weights
+ block_tables: Block tables for paged attention
+ block_size: Block size for paged attention
+ k_scale: Scale applied to key
+ v_scale: Scale applied to value

  Returns:
  Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -940,13 +923,19 @@
  "block_size": block_size,
  }

- if self.use_attention_mask != self.use_position_ids:
+ if self.use_attention_mask:
  op_args["mask"] = attn_mask

  if self.phase == "prefill" or self.phase == "image_prefill":
  if not self.use_attention_mask or self.use_position_ids:
  op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ if past_key_state.dtype != torch.float8_e4m3fn:
+ raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+ op_args["k_scale"] = k_scale
+ op_args["v_scale"] = v_scale
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -960,97 +949,6 @@
  return attn_output


- def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
- """Slice cos[cache_position], sin[cache_position] vector for the query."""
- if cache_position.shape[0] > 1:
- cos_all = []
- sin_all = []
- for i in range(cache_position.shape[0]):
- cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
- sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
- cos = torch.cat(cos_all, dim=0)
- sin = torch.cat(sin_all, dim=0)
- else:
- cos = cos[cache_position].unsqueeze(unsqueeze_dim)
- sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
- return cos, sin
-
-
- def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin):
- """Applies Rotary Position Embedding to the query and key tensors."""
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
- def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
- # Partial rotary embedding
- query_rot, query_pass = (
- query_states[..., :ndim],
- query_states[..., ndim:],
- )
- key_rot, key_pass = (
- key_states[..., :ndim],
- key_states[..., ndim:],
- )
-
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
- # [batch_size, seq_length, num_heads, head_dim]
- query_states = torch.cat((query_rot, query_pass), dim=-1)
- key_states = torch.cat((key_rot, key_pass), dim=-1)
- return query_states, key_states
-
-
- class RotaryEmbedding(nn.Module):
- """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
- def __init__(
- self,
- config: PretrainedConfig,
- max_seq_len_cached: int,
- ):
- super().__init__()
-
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- rope_type = "default"
-
- inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
- cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
- cache_position_expanded = cache_position[:, None]
-
- if rope_type == "dynamic":
- freqs = cache_position_expanded.float() * inv_freq.float()
- else:
- inv_freq_expanded = inv_freq[None, :]
- freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
- emb = torch.cat((freqs, freqs), dim=-1)
-
- cos = emb.cos() * attention_scaling
- sin = emb.sin() * attention_scaling
-
- self.register_buffer("_cos_cached", cos, persistent=False)
- self.register_buffer("_sin_cached", sin, persistent=False)
-
- def forward(self, x, seq_len):
- return (
- self._cos_cached[:seq_len].to(dtype=x.dtype),
- self._sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
  class FlashAttentionOp(AttentionOp):
  def __init__(
  self,
@@ -1060,6 +958,7 @@ class FlashAttentionOp(AttentionOp):
  kvcache_partition_len: int,
  use_attention_mask: bool,
  use_position_ids: bool,
+ quantization: Optional[RBLNQuantizationConfig] = None,
  ):
  super().__init__(
  num_heads=num_heads,
@@ -1067,6 +966,7 @@ class FlashAttentionOp(AttentionOp):
  num_key_value_heads=num_key_value_heads,
  use_attention_mask=use_attention_mask,
  use_position_ids=use_position_ids,
+ quantization=quantization,
  )
  self.kvcache_partition_size = kvcache_partition_len

@@ -1079,6 +979,9 @@

  attn_op_name += phase

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ attn_op_name += "_kv_fp8"
+
  return attn_op_name

  def forward(
@@ -1093,6 +996,8 @@ class FlashAttentionOp(AttentionOp):
  scale,
  block_tables,
  block_size,
+ k_scale=None,
+ v_scale=None,
  ):
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2)
@@ -1133,6 +1038,12 @@
  if not self.use_attention_mask or self.use_position_ids:
  op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ if past_key_state.dtype != torch.float8_e4m3fn:
+ raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+ op_args["k_scale"] = k_scale
+ op_args["v_scale"] = v_scale
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -1160,14 +1071,19 @@ class SlidingWindowAttentionOp(AttentionOp):
  query_state: torch.Tensor,
  key_state: torch.Tensor,
  value_state: torch.Tensor,
- attn_mask: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
  past_key_state: torch.Tensor,
  past_value_state: torch.Tensor,
  seq_position: Tuple[torch.Tensor],
  scale: torch.Tensor,
  block_tables: torch.Tensor,
  block_size: int,
+ k_scale: Optional[torch.Tensor] = None,
+ v_scale: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ assert self.quantization is None, "Sliding window attention does not support quantization"
+ assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
+
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2)
  value_state = value_state.unsqueeze(2)
@@ -1199,8 +1115,7 @@ class SlidingWindowAttentionOp(AttentionOp):
  }

  if self.phase == "prefill" or self.phase == "image_prefill":
- if not self.use_attention_mask or self.use_position_ids:
- op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
+ op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.

  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1213,3 +1128,97 @@ class SlidingWindowAttentionOp(AttentionOp):
  attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

  return attn_output
+
+
+ class RotaryEmbedding(nn.Module):
+ """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ max_seq_len_cached: int,
+ ):
+ super().__init__()
+
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ rope_type = "default"
+
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+ cache_position = torch.arange(0, max_seq_len_cached)
+ cache_position_expanded = cache_position[:, None]
+
+ if rope_type == "dynamic":
+ freqs = cache_position_expanded.float() * inv_freq.float()
+ else:
+ inv_freq_expanded = inv_freq[None, :]
+ freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ cos = emb.cos() * attention_scaling
+ sin = emb.sin() * attention_scaling
+
+ self.register_buffer("_cos_cached", cos, persistent=False)
+ self.register_buffer("_sin_cached", sin, persistent=False)
+
+ def forward(self, x, seq_len):
+ return (
+ self._cos_cached[:seq_len].to(dtype=torch.float32),
+ self._sin_cached[:seq_len].to(dtype=torch.float32),
+ )
+
+
+ def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+ """Slice cos[cache_position], sin[cache_position] vector for the query."""
+ if cache_position.shape[0] > 1:
+ cos_all = []
+ sin_all = []
+ for i in range(cache_position.shape[0]):
+ cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+ sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+ cos = torch.cat(cos_all, dim=0)
+ sin = torch.cat(sin_all, dim=0)
+ else:
+ cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+ sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+ return cos, sin
+
+
+ def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+ """Applies Rotary Position Embedding to the query and key tensors."""
+ dtype = q.dtype
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(dtype)
+ k_embed = k_embed.to(dtype)
+ return q_embed, k_embed
+
+
+ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Partial rotary embedding
+ query_rot, query_pass = (
+ query_states[..., :ndim],
+ query_states[..., ndim:],
+ )
+ key_rot, key_pass = (
+ key_states[..., :ndim],
+ key_states[..., ndim:],
+ )
+
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+ return query_states, key_states
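
The RoPE helpers added at the bottom of the new file (RotaryEmbedding, slice_and_unsqueeze_cos_sin, rotate_half, apply_rotary_pos_emb, apply_rotary_pos_emb_partial) are self-contained, and the rewritten apply_rotary_pos_emb now casts the rotated query/key back to the input dtype. Below is a minimal sketch of how the two core helpers compose; the helper bodies are copied from the diff above, while the toy shapes and the inline cos/sin construction are illustrative assumptions (in the package itself, RotaryEmbedding precomputes these tables once per max_seq_len and slice_and_unsqueeze_cos_sin selects the entries for the current cache positions).

import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input (body copied from the diff)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Applies Rotary Position Embedding to the query and key tensors (body copied from the diff)."""
    dtype = q.dtype
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    q_embed = q_embed.to(dtype)
    k_embed = k_embed.to(dtype)
    return q_embed, k_embed


# Toy shapes (assumed for illustration): batch=1, heads=2, seq_len=4, head_dim=8.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)

# Inline cos/sin tables standing in for RotaryEmbedding's cached buffers;
# they broadcast over the head axis: [1, 1, seq_len, head_dim].
pos = torch.arange(4, dtype=torch.float32)[:, None]                            # [seq_len, 1]
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))   # [head_dim // 2]
emb = torch.cat([pos * inv_freq, pos * inv_freq], dim=-1)                      # [seq_len, head_dim]
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape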