optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,7 @@
  # limitations under the License.

  import math
- from typing import List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

  import torch
  from torch import nn
@@ -21,106 +21,16 @@ from transformers import PretrainedConfig, PreTrainedModel

  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from .configuration_decoderonly import CacheImplType
+ from ...utils.rbln_quantization import RBLNQuantizationConfig
+ from .configuration_lora import RBLNLoRAConfig
+ from .lora_architecture import LoRALinear


- logger = logging.get_logger(__name__)
-
- DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
- DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
- MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
- MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
- MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
- MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
- def set_default_values(
- attn_impl: Optional[str] = None,
- kvcache_partition_len: Optional[int] = None,
- kvcache_block_size: Optional[int] = None,
- max_seq_len: Optional[int] = None,
- ) -> Tuple[str, int, int]:
- if attn_impl is None:
- attn_impl = "eager"
-
- if kvcache_partition_len is not None:
- if attn_impl == "eager":
- attn_impl = "flash_attn"
- logger.warning(
- "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
- "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
- "`attn_impl` has been automatically switched to 'flash_attn'."
- )
-
- if kvcache_partition_len is None and attn_impl == "flash_attn":
- kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
- if kvcache_block_size is None:
- if attn_impl == "eager":
- kvcache_block_size = max_seq_len
- else:
- kvcache_block_size = kvcache_partition_len
-
- return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
- def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
- if attn_impl not in ["eager", "flash_attn"]:
- raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
+ if TYPE_CHECKING:
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig

- ## Checking Constraints...
- # Constraint of eager attention:
- # - `max_seq_len` <= 32k

- # Constraints of flash attention:
- # 1. `max_seq_len` should be multiple of `partition_len`.
- # 2. 4k <= `partition_len` <= 32k.
- # 3. `max_seq_len` should be larger then 8k.
- if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
- raise ValueError(
- f"`max_seq_len` is set to {max_seq_len}, "
- f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
- f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
- " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
- )
-
- if attn_impl == "flash_attn":
- if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
- raise ValueError(
- f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
- f"when using 'flash_attn'. Please adjust either value to meet this requirement."
- )
- elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
- raise ValueError(
- f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
- f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
- f"Please provide a valid value within this range."
- )
- elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
- raise ValueError(
- f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
- f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
- "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
- )
-
- if kvcache_block_size is not None:
- if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
- raise ValueError(
- f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
- f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
- )
- elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
- raise ValueError(
- f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
- f"must always be set equal to the `max_seq_len` {max_seq_len}."
- )
-
-
- def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
- if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
- raise ValueError(
- f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
- )
+ logger = logging.get_logger(__name__)


  class DecoderOnlyWrapper(nn.Module):
@@ -137,40 +47,22 @@ class DecoderOnlyWrapper(nn.Module):
  - Wrapper should not contain neural network graph operations (including memory view handling)

  Args:
- causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
- max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
+ model (PreTrainedModel): The Huggingface causal language model to wrap
+ rbln_config: The RBLN model configuration containing all necessary parameters
  use_rotary_emb (bool): Whether to use rotary position embeddings
- attn_impl (str): The attention implementation to use.
- - "eager": Uses the standard attention.
- - "flash_attn": Uses flash attention. When set,
- the key/value cache is partitioned into chunks of length
- `kvcache_partition_len`.
- kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
- This is only relevant if `attn_impl` is set to "flash_attn`
  """

  _use_learned_pos_emb = False

- def __init__(
- self,
- causal_lm: PreTrainedModel,
- max_seq_len: int,
- use_rotary_emb: bool,
- attn_impl: str,
- cache_impl: CacheImplType,
- use_inputs_embeds: bool,
- use_attention_mask: bool,
- use_position_ids: bool,
- kvcache_partition_len: Optional[int] = None,
- kvcache_block_size: Optional[int] = None,
- sliding_window: Optional[int] = None,
- sliding_window_layers: Optional[List[int]] = None,
- ):
+ def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
  super().__init__()
- self.config = causal_lm.config
+ self.quantization = rbln_config.quantization
+ self.config = model.config
+ self.is_causal_lm = getattr(model, "lm_head", None) is not None
+ self.rbln_config = rbln_config

  if use_rotary_emb:
- rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+ rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
  if isinstance(rotary_embs, tuple):
  self.rotary_emb_global, self.rotary_emb_local = rotary_embs
  else:
@@ -178,43 +70,27 @@ class DecoderOnlyWrapper(nn.Module):
  else:
  self.rotary_emb = None

- self.attn_impl = attn_impl
- self.kvcache_block_size = kvcache_block_size
- self.use_attention_mask = use_attention_mask
- self.use_position_ids = use_position_ids
- self.use_inputs_embeds = use_inputs_embeds
- self.sliding_window_layers = sliding_window_layers
- self.cache_impl = cache_impl
- self.sliding_window = sliding_window
-
- if self.attn_impl == "flash_attn":
- self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
- elif self.attn_impl == "eager":
- self.kvcache_partition_len = None
- else:
- raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
-
- if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+ if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
  raise ValueError(
- f"kvcache_partition_len({kvcache_partition_len}) should be lower"
- f" or equal to max_seq_len({max_seq_len})!"
+ f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
+ f" or equal to max_seq_len({rbln_config.max_seq_len})!"
  )

- self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
+ self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
  self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
  self._phase = "prefill"

  def get_rotary_emb(self, max_seq_len):
  return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

- def get_decoder_layers(self, causal_lm: PreTrainedModel):
- return causal_lm.model.layers
+ def get_decoder_layers(self, model: PreTrainedModel):
+ return model.model.layers if self.is_causal_lm else model.layers

  def get_attn_layer(self, layer: nn.Module):
  return layer.self_attn

- def get_model_layer(self, causal_lm: PreTrainedModel):
- return causal_lm.model
+ def get_model_layer(self, model: PreTrainedModel):
+ return model.model if self.is_causal_lm else model

  def get_rbln_attn_class(self):
  return DecoderOnlyAttention
@@ -228,34 +104,28 @@ class DecoderOnlyWrapper(nn.Module):
  def get_rbln_causal_lm_class(self):
  return DecoderOnlyForCausalLM

- def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
+ def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
  new_layers = []
- for layer_idx, layer in enumerate(self.get_decoder_layers(causal_lm)):
+ for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+ is_sliding = layer_idx in self.rbln_config.sliding_window_layers
  new_self_attn = self.get_rbln_attn_class()(
- self.get_attn_layer(layer),
- self.use_attention_mask,
- self.use_position_ids,
- kvcache_block_size=self.sliding_window
- if layer_idx in self.sliding_window_layers
- else self.kvcache_block_size,
- is_sliding=layer_idx in self.sliding_window_layers,
- attn_impl=self.attn_impl,
- kvcache_partition_len=self.kvcache_partition_len,
+ self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
  )
- new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+ new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
  new_layers.append(new_layer)

  new_model = self.get_rbln_model_class()(
- self.get_model_layer(causal_lm),
+ self.get_model_layer(model),
  new_layers,
- partition_len=self.kvcache_partition_len,
- max_seq_len=max_seq_len,
- kvcache_block_size=self.kvcache_block_size,
+ self.rbln_config,
  use_learned_pos_emb=self.__class__._use_learned_pos_emb,
- sliding_window_layers=self.sliding_window_layers,
  )
- new_causal_lm = self.get_rbln_causal_lm_class()(causal_lm, new_model)
- return new_causal_lm
+
+ if self.is_causal_lm:
+ new_model = self.get_rbln_causal_lm_class()(model, new_model)
+ return new_model
+ else:
+ return new_model

  @property
  def phase(self) -> str:
@@ -264,18 +134,24 @@ class DecoderOnlyWrapper(nn.Module):
  @phase.setter
  def phase(self, phase: str):
  self._phase = phase
- self.causal_lm.phase = phase
+ self.model.phase = phase

  def prepare_forward_args(self, *args):
  args = list(args)
- input_ids = None if self.use_inputs_embeds else args.pop(0)
- inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+ input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+ inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
  cache_position = args.pop(0)
- global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
- local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
- query_position = args.pop(0) if "prefill" in self.phase else None
- attention_mask = args.pop(0) if self.use_attention_mask else None
- position_ids = args.pop(0) if self.use_position_ids else None
+ global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
+ local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
+ query_position = (
+ args.pop(0)
+ # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+ if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+ else None
+ )
+ attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+ position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+ lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
  past_key_values = args

  if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -307,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
  query_position,
  attention_mask,
  position_ids,
+ lora_int_id,
  past_key_values,
  rotary_emb,
  )
@@ -321,11 +198,12 @@ class DecoderOnlyWrapper(nn.Module):
  query_position,
  attention_mask,
  position_ids,
+ lora_int_id,
  past_key_values,
  rotary_emb,
  ) = self.prepare_forward_args(*args)

- logit = self.causal_lm(
+ logit = self.model(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -336,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
  rotary_emb=rotary_emb,
  global_block_tables=global_block_tables,
  local_block_tables=local_block_tables,
+ lora_int_id=lora_int_id,
  )

  return logit
@@ -392,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  rotary_emb: nn.Module = None,
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  # outputs
  hidden_states = self.model(
@@ -405,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  rotary_emb=rotary_emb,
  global_block_tables=global_block_tables,
  local_block_tables=local_block_tables,
+ lora_int_id=lora_int_id,
  )

  if "prefill" in self.phase:
@@ -427,6 +308,8 @@ class DecoderOnlyModel(nn.Module):
  Args:
  model: Original Huggingface model to adapt
  layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+ rbln_config: RBLN model configuration
+ use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)

  Attributes:
  _original_mod: Reference to original Huggingface model
@@ -438,21 +321,19 @@ class DecoderOnlyModel(nn.Module):
  self,
  model,
  layers: List["DecoderOnlyLayer"],
- partition_len=None,
- max_seq_len=None,
- kvcache_block_size=None,
+ rbln_config: "RBLNDecoderOnlyModelConfig",
  use_learned_pos_emb=None,
- sliding_window_layers=None,
  ):
  super().__init__()
  self._original_mod = model
  self.layers = nn.ModuleList(layers)
+ self.rbln_config = rbln_config
  self._phase = "prefill"
- self.partition_len = partition_len
- self.kvcache_block_size = kvcache_block_size
- self.max_seq_len = max_seq_len
+ self.partition_len = rbln_config.kvcache_partition_len
+ self.kvcache_block_size = rbln_config.kvcache_block_size
+ self.max_seq_len = rbln_config.max_seq_len
  self.use_learned_pos_emb = use_learned_pos_emb
- self.sliding_window_layers = sliding_window_layers
+ self.sliding_window_layers = rbln_config.sliding_window_layers

  @property
  def phase(self):
@@ -516,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
  rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  # retrieve input_ids and inputs_embeds
  if (input_ids is None) ^ (inputs_embeds is not None):
@@ -588,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
  cos=cos,
  sin=sin,
  block_tables=local_block_tables if is_sliding else global_block_tables,
+ lora_int_id=lora_int_id,
  )

  hidden_states = self.get_last_layernorm()(hidden_states)
@@ -619,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
  phase: Current operation phase ("prefill" or "decode")
  """

- def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+ def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
  super().__init__()
  self._original_mod = layer
  self.self_attn = self_attn
  self._phase = "prefill"
+ self.lora_config = lora_config
+
+ # Replace target Linear modules in MLP with LoRALinear if configured
+ if self.lora_config:
+ mlp = self.get_mlp()
+ for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+ if hasattr(mlp, proj_name):
+ original_linear = getattr(mlp, proj_name)
+ if isinstance(original_linear, nn.Linear):
+ lora_linear = LoRALinear(
+ original_linear=original_linear,
+ lora_config=self.lora_config,
+ projection_name=proj_name,
+ layer_idx=self.self_attn.layer_idx,
+ )
+ setattr(mlp, proj_name, lora_linear)

  @property
  def phase(self):
@@ -640,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
  def get_post_attention_layernorm(self) -> nn.LayerNorm:
  return self._original_mod.post_attention_layernorm

+ def get_mlp(self) -> nn.Module:
+ return self._original_mod.mlp
+
+ def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+ mlp = self.get_mlp()
+ if self.lora_config and lora_int_id is not None:
+ gate = mlp.gate_proj(hidden_states, lora_int_id)
+ up = mlp.up_proj(hidden_states, lora_int_id)
+ act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+ if act_fn is None:
+ gate = torch.nn.functional.silu(gate)
+ else:
+ gate = act_fn(gate)
+ fused = gate * up
+ hidden_states = mlp.down_proj(fused, lora_int_id)
+ else:
+ hidden_states = mlp(hidden_states)
+ return hidden_states
+
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -649,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
  block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  residual = hidden_states
  hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -661,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
  cos=cos,
  sin=sin,
  block_tables=block_tables,
+ lora_int_id=lora_int_id,
  )
  hidden_states = residual + hidden_states

  # Fully Connected
  residual = hidden_states
  hidden_states = self.get_post_attention_layernorm()(hidden_states)
- hidden_states = self._original_mod.mlp(hidden_states)
+ hidden_states = self.forward_mlp(hidden_states, lora_int_id)
  hidden_states = residual + hidden_states

  return hidden_states
@@ -682,32 +602,27 @@ class DecoderOnlyAttention(nn.Module):

  Args:
  self_attn: Original attention module from the base model
- use_attention_mask: Whether to use attention mask
- use_position_ids: Whether to use position ids
- kvcache_block_size: Block size for KV cache
+ rbln_config: RBLN model configuration containing attention parameters
  is_sliding: Whether this is sliding window attention
- attn_impl: Attention implementation type ("eager" or "flash_attn")
  """

  def __init__(
  self,
  self_attn,
- use_attention_mask,
- use_position_ids,
- kvcache_block_size,
+ rbln_config: "RBLNDecoderOnlyModelConfig",
  is_sliding=False,
- attn_impl="eager",
- kvcache_partition_len=None,
  ):
  super().__init__()
  self._original_mod = self_attn
+ self.rbln_config = rbln_config
  self.layer_idx = self_attn.layer_idx
  self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
  self._original_mod.config, "num_attention_heads"
  )
  self.head_dim = self._original_mod.head_dim
  self._phase = "prefill"
- self.scale = torch.tensor(self.get_attn_scale())
+ self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+ self.quantization = rbln_config.quantization

  if hasattr(self._original_mod, "num_key_value_heads"):
  self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -716,20 +631,29 @@ class DecoderOnlyAttention(nn.Module):
  else:
  self.num_key_value_heads = self.num_heads

- self.use_attention_mask = use_attention_mask
- self.use_position_ids = use_position_ids
+ self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
+ self.use_position_ids = rbln_config.use_position_ids
  self.is_sliding = is_sliding
- self.attn_impl = attn_impl
-
- if self.is_sliding and self.attn_impl != "eager":
- raise NotImplementedError("Sliding window attention is only supported with eager attention.")
-
- self.kvcache_partition_len = kvcache_partition_len
+ self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
+ self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
+ self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+ self.lora_config = rbln_config.lora_config

  setattr(self, self.get_attention_name(), self.create_attention_op())
- self.kvcache_block_size = kvcache_block_size
  self.__post_init__()

+ def _init_lora_weights(self):
+ """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+ for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+ original_linear = getattr(self._original_mod, proj_name)
+ lora_linear = LoRALinear(
+ original_linear=original_linear,
+ lora_config=self.lora_config,
+ projection_name=proj_name,
+ layer_idx=self.layer_idx,
+ )
+ setattr(self, proj_name, lora_linear)
+
  def get_attention_name(self):
  if self.is_sliding:
  return "sliding_window_attention"
@@ -767,6 +691,7 @@ class DecoderOnlyAttention(nn.Module):
  self.kvcache_partition_len,
  self.use_attention_mask,
  self.use_position_ids,
+ self.quantization,
  )
  elif self.attn_impl == "eager":
  return AttentionOp(
@@ -775,28 +700,46 @@ class DecoderOnlyAttention(nn.Module):
  self.num_key_value_heads,
  self.use_attention_mask,
  self.use_position_ids,
+ self.quantization,
  )
  else:
  raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

  def __post_init__(self):
- self.q_proj = self._original_mod.q_proj
- self.k_proj = self._original_mod.k_proj
- self.v_proj = self._original_mod.v_proj
- self.o_proj = self._original_mod.o_proj
-
- def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Initialize LoRA weights if configured, which will replace linear layers
+ if self.lora_config:
+ self._init_lora_weights()
+ else:
+ # Use original linear layers if no LoRA
+ self.q_proj = self._original_mod.q_proj
+ self.k_proj = self._original_mod.k_proj
+ self.v_proj = self._original_mod.v_proj
+ self.o_proj = self._original_mod.o_proj
+
+ def projection(
+ self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Projects input hidden states into query, key, and value representations.

  Args:
  hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+ lora_int_id: Adapter ID tensor for LoRA selection [batch_size]

  Returns:
  Tuple of (query_states, key_states, value_states)
  """
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+ if self.lora_config:
+ # LoRALinear handles both base projection and LoRA in one forward pass
+ query_states = self.q_proj(hidden_states, lora_int_id)
+ key_states = self.k_proj(hidden_states, lora_int_id)
+ value_states = self.v_proj(hidden_states, lora_int_id)
+ else:
+ # Standard linear projection without LoRA
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
  return query_states, key_states, value_states

  def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
@@ -805,6 +748,16 @@ class DecoderOnlyAttention(nn.Module):
  def get_attn_scale(self):
  return 1 / math.sqrt(self.head_dim)

+ def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+ if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
+ k_scale = getattr(self.k_proj, "k_scale", None)
+ v_scale = getattr(self.v_proj, "v_scale", None)
+ else:
+ k_scale = None
+ v_scale = None
+
+ return k_scale, v_scale
+
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -814,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
  block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  batch_size, query_length, _ = hidden_states.size()

- query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+ query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)

  query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
  key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -834,6 +788,8 @@ class DecoderOnlyAttention(nn.Module):
  if batch_size > 1 and "prefill" in self.phase:
  raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

+ k_scale, v_scale = self.maybe_get_kvcache_scale()
+
  attn_output = self.get_attention_op()(
  query_states,
  key_states,
@@ -845,9 +801,18 @@ class DecoderOnlyAttention(nn.Module):
  scale=self.scale,
  block_tables=block_tables,
  block_size=self.kvcache_block_size,
+ k_scale=k_scale,
+ v_scale=v_scale,
  )

- attn_outputs = self.o_proj(attn_output)
+ # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+ if self.lora_config:
+ # LoRALinear handles both base projection and LoRA in one forward pass
+ attn_outputs = self.o_proj(attn_output, lora_int_id)
+ else:
+ # Standard linear projection without LoRA
+ attn_outputs = self.o_proj(attn_output)
+
  return attn_outputs


@@ -861,7 +826,13 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):

  class AttentionOp(nn.Module):
  def __init__(
- self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+ self,
+ num_heads: int,
+ head_dim: int,
+ num_key_value_heads: int,
+ use_attention_mask: bool,
+ use_position_ids: bool,
+ quantization: Optional[RBLNQuantizationConfig] = None,
  ):
  super().__init__()
  self.num_heads = num_heads
@@ -870,16 +841,20 @@ class AttentionOp(nn.Module):
  self.phase = "prefill"
  self.use_attention_mask = use_attention_mask
  self.use_position_ids = use_position_ids
+ self.quantization = quantization

  def get_attn_op_name(self):
  phase = "decode" if self.phase == "decode" else "prefill"
- if self.use_attention_mask:
+ if self.use_attention_mask and not self.use_position_ids:
  attn_op_name = "paged_attn_"
  else:
  attn_op_name = "paged_causal_attn_"

  attn_op_name += phase

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ attn_op_name += "_kv_fp8"
+
  return attn_op_name

  def forward(
@@ -894,6 +869,8 @@ class AttentionOp(nn.Module):
  scale: torch.Tensor,
  block_tables: torch.Tensor,
  block_size: int,
+ k_scale: Optional[torch.Tensor] = None,
+ v_scale: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Compute attention with static shapes and explicit cache management.

@@ -906,6 +883,10 @@ class AttentionOp(nn.Module):
  past_value_state: Previous value cache states
  seq_position: Current position in sequence
  scale: Scale applied to attn weights
+ block_tables: Block tables for paged attention
+ block_size: Block size for paged attention
+ k_scale: Scale applied to key
+ v_scale: Scale applied to value

  Returns:
  Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -942,13 +923,19 @@ class AttentionOp(nn.Module):
  "block_size": block_size,
  }

- if self.use_attention_mask != self.use_position_ids:
+ if self.use_attention_mask:
  op_args["mask"] = attn_mask

  if self.phase == "prefill" or self.phase == "image_prefill":
  if not self.use_attention_mask or self.use_position_ids:
  op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ if past_key_state.dtype != torch.float8_e4m3fn:
+ raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+ op_args["k_scale"] = k_scale
+ op_args["v_scale"] = v_scale
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -962,97 +949,6 @@
  return attn_output


- def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
- """Slice cos[cache_position], sin[cache_position] vector for the query."""
- if cache_position.shape[0] > 1:
- cos_all = []
- sin_all = []
- for i in range(cache_position.shape[0]):
- cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
- sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
- cos = torch.cat(cos_all, dim=0)
- sin = torch.cat(sin_all, dim=0)
- else:
- cos = cos[cache_position].unsqueeze(unsqueeze_dim)
- sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
- return cos, sin
-
-
- def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin):
- """Applies Rotary Position Embedding to the query and key tensors."""
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
- def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
- # Partial rotary embedding
- query_rot, query_pass = (
- query_states[..., :ndim],
- query_states[..., ndim:],
- )
- key_rot, key_pass = (
- key_states[..., :ndim],
- key_states[..., ndim:],
- )
-
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
- # [batch_size, seq_length, num_heads, head_dim]
- query_states = torch.cat((query_rot, query_pass), dim=-1)
- key_states = torch.cat((key_rot, key_pass), dim=-1)
- return query_states, key_states
-
-
- class RotaryEmbedding(nn.Module):
- """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
- def __init__(
- self,
- config: PretrainedConfig,
- max_seq_len_cached: int,
- ):
- super().__init__()
-
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- rope_type = "default"
-
- inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
- cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
- cache_position_expanded = cache_position[:, None]
-
- if rope_type == "dynamic":
- freqs = cache_position_expanded.float() * inv_freq.float()
- else:
- inv_freq_expanded = inv_freq[None, :]
- freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
- emb = torch.cat((freqs, freqs), dim=-1)
-
- cos = emb.cos() * attention_scaling
- sin = emb.sin() * attention_scaling
-
- self.register_buffer("_cos_cached", cos, persistent=False)
- self.register_buffer("_sin_cached", sin, persistent=False)
-
- def forward(self, x, seq_len):
- return (
- self._cos_cached[:seq_len].to(dtype=x.dtype),
- self._sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
  class FlashAttentionOp(AttentionOp):
  def __init__(
  self,
@@ -1062,6 +958,7 @@ class FlashAttentionOp(AttentionOp):
  kvcache_partition_len: int,
  use_attention_mask: bool,
  use_position_ids: bool,
+ quantization: Optional[RBLNQuantizationConfig] = None,
  ):
  super().__init__(
  num_heads=num_heads,
@@ -1069,18 +966,22 @@ class FlashAttentionOp(AttentionOp):
  num_key_value_heads=num_key_value_heads,
  use_attention_mask=use_attention_mask,
  use_position_ids=use_position_ids,
+ quantization=quantization,
  )
  self.kvcache_partition_size = kvcache_partition_len

  def get_attn_op_name(self):
  phase = "decode" if self.phase == "decode" else "prefill"
- if self.use_attention_mask:
+ if self.use_attention_mask and not self.use_position_ids:
  attn_op_name = "paged_flash_attn_"
  else:
  attn_op_name = "paged_flash_causal_attn_"

  attn_op_name += phase

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ attn_op_name += "_kv_fp8"
+
  return attn_op_name

  def forward(
@@ -1095,6 +996,8 @@ class FlashAttentionOp(AttentionOp):
  scale,
  block_tables,
  block_size,
+ k_scale=None,
+ v_scale=None,
  ):
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2)
@@ -1128,13 +1031,19 @@ class FlashAttentionOp(AttentionOp):
  "partition": self.kvcache_partition_size,
  }

- if self.use_attention_mask != self.use_position_ids:
+ if self.use_attention_mask:
  op_args["mask"] = attn_mask

  if self.phase == "prefill" or self.phase == "image_prefill":
  if not self.use_attention_mask or self.use_position_ids:
  op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

+ if self.quantization and self.quantization.kv_caches == "fp8":
+ if past_key_state.dtype != torch.float8_e4m3fn:
+ raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+ op_args["k_scale"] = k_scale
+ op_args["v_scale"] = v_scale
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -1151,8 +1060,8 @@ class FlashAttentionOp(AttentionOp):
  class SlidingWindowAttentionOp(AttentionOp):
  def get_attn_op_name(self):
  phase = "decode" if self.phase == "decode" else "prefill"
- if self.use_attention_mask:
- raise NotImplementedError("Attention mask is not supported for sliding window attention.")
+ if not self.use_attention_mask:
+ raise NotImplementedError("Attention mask is needed for sliding window attention.")

  attn_op_name = "paged_sliding_window_attn_" + phase
  return attn_op_name
@@ -1162,14 +1071,19 @@ class SlidingWindowAttentionOp(AttentionOp):
  query_state: torch.Tensor,
  key_state: torch.Tensor,
  value_state: torch.Tensor,
- attn_mask: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
  past_key_state: torch.Tensor,
  past_value_state: torch.Tensor,
  seq_position: Tuple[torch.Tensor],
  scale: torch.Tensor,
  block_tables: torch.Tensor,
  block_size: int,
+ k_scale: Optional[torch.Tensor] = None,
+ v_scale: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ assert self.quantization is None, "Sliding window attention does not support quantization"
+ assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
+
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2)
  value_state = value_state.unsqueeze(2)
@@ -1201,8 +1115,7 @@ class SlidingWindowAttentionOp(AttentionOp):
  }

  if self.phase == "prefill" or self.phase == "image_prefill":
- if not self.use_attention_mask or self.use_position_ids:
- op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+ op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1215,3 +1128,97 @@ class SlidingWindowAttentionOp(AttentionOp):
  attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

  return attn_output
+
+
+ class RotaryEmbedding(nn.Module):
+ """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ max_seq_len_cached: int,
+ ):
+ super().__init__()
+
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ rope_type = "default"
+
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+ cache_position = torch.arange(0, max_seq_len_cached)
+ cache_position_expanded = cache_position[:, None]
+
+ if rope_type == "dynamic":
+ freqs = cache_position_expanded.float() * inv_freq.float()
+ else:
+ inv_freq_expanded = inv_freq[None, :]
+ freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ cos = emb.cos() * attention_scaling
+ sin = emb.sin() * attention_scaling
+
+ self.register_buffer("_cos_cached", cos, persistent=False)
+ self.register_buffer("_sin_cached", sin, persistent=False)
+
+ def forward(self, x, seq_len):
+ return (
+ self._cos_cached[:seq_len].to(dtype=torch.float32),
+ self._sin_cached[:seq_len].to(dtype=torch.float32),
+ )
+
+
+ def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+ """Slice cos[cache_position], sin[cache_position] vector for the query."""
+ if cache_position.shape[0] > 1:
+ cos_all = []
+ sin_all = []
+ for i in range(cache_position.shape[0]):
+ cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+ sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+ cos = torch.cat(cos_all, dim=0)
+ sin = torch.cat(sin_all, dim=0)
+ else:
+ cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+ sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+ return cos, sin
+
+
+ def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+ """Applies Rotary Position Embedding to the query and key tensors."""
+ dtype = q.dtype
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(dtype)
+ k_embed = k_embed.to(dtype)
+ return q_embed, k_embed
+
+
+ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Partial rotary embedding
+ query_rot, query_pass = (
+ query_states[..., :ndim],
+ query_states[..., ndim:],
+ )
+ key_rot, key_pass = (
+ key_states[..., :ndim],
+ key_states[..., ndim:],
+ )
+
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+ return query_states, key_states
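

The refactored prepare_forward_args() above derives which positional inputs the compiled graph receives from flags on rbln_config instead of from individual constructor arguments. The following is a minimal host-side sketch of that flat layout, assuming the flag names shown in the diff; the helper itself and its tensors dict are hypothetical and only illustrate the ordering.

# Illustrative sketch only: mirrors the argument order consumed by
# DecoderOnlyWrapper.prepare_forward_args() in the diff above.
from typing import Dict, List

import torch


def build_forward_args(
    *,
    phase: str,
    is_causal_lm: bool,
    use_inputs_embeds: bool,
    use_global_attention: bool,
    use_local_attention: bool,
    use_attention_mask: bool,
    use_position_ids: bool,
    has_lora: bool,
    tensors: Dict[str, torch.Tensor],
    past_key_values: List[torch.Tensor],
) -> List[torch.Tensor]:
    args: List[torch.Tensor] = []
    # 1. exactly one of input_ids / inputs_embeds
    args.append(tensors["inputs_embeds"] if use_inputs_embeds else tensors["input_ids"])
    # 2. cache_position is always present
    args.append(tensors["cache_position"])
    # 3-4. block tables, depending on the cache layout
    if use_global_attention:
        args.append(tensors["global_block_tables"])
    if use_local_attention:
        args.append(tensors["local_block_tables"])
    # 5. query_position: causal-LM prefill or sliding-window cache position
    if "prefill" in phase and (is_causal_lm or use_local_attention):
        args.append(tensors["query_position"])
    # 6-7. optional attention mask and position ids
    if use_attention_mask:
        args.append(tensors["attention_mask"])
    if use_position_ids:
        args.append(tensors["position_ids"])
    # 8. LoRA adapter selector, only when a lora_config is set
    if has_lora:
        args.append(tensors["lora_int_id"])
    # 9. finally the KV cache: two tensors (key, value) per hidden layer
    args.extend(past_key_values)
    return args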
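AttentionOp.get_attn_op_name() and FlashAttentionOp.get_attn_op_name() now also fold the fp8 KV-cache suffix into the custom-op name. Below is a standalone paraphrase of that rule with a few of the resulting names; it is illustrative only and not part of the library API.

# Illustrative sketch only: paraphrases the op-name composition shown in the diff above.
def attn_op_name(phase: str, flash: bool, use_attention_mask: bool, use_position_ids: bool, kv_fp8: bool) -> str:
    phase = "decode" if phase == "decode" else "prefill"
    if use_attention_mask and not use_position_ids:
        name = "paged_flash_attn_" if flash else "paged_attn_"
    else:
        name = "paged_flash_causal_attn_" if flash else "paged_causal_attn_"
    name += phase
    if kv_fp8:
        name += "_kv_fp8"
    return name


# Example names that would be looked up on torch.ops.rbln_custom_ops:
assert attn_op_name("decode", flash=False, use_attention_mask=True, use_position_ids=False, kv_fp8=False) == "paged_attn_decode"
assert attn_op_name("prefill", flash=False, use_attention_mask=False, use_position_ids=True, kv_fp8=False) == "paged_causal_attn_prefill"
assert attn_op_name("decode", flash=True, use_attention_mask=False, use_position_ids=True, kv_fp8=True) == "paged_flash_causal_attn_decode_kv_fp8"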
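The relocated rotary helpers gather one cos/sin row per batch entry and unsqueeze a head axis so the result broadcasts against [batch, num_heads, q_len, head_dim] queries. A shape-level sketch with made-up sizes, mirroring the formulas above and assuming a decode-phase cache_position of shape [batch, 1]:

# Illustrative sketch only: shape walkthrough of slice_and_unsqueeze_cos_sin() + apply_rotary_pos_emb().
import torch

max_seq_len, head_dim, batch, num_heads = 16, 8, 2, 4
# RotaryEmbedding caches cos/sin tables of shape [max_seq_len, head_dim]
cos = torch.randn(max_seq_len, head_dim)
sin = torch.randn(max_seq_len, head_dim)

cache_position = torch.tensor([[3], [7]])  # one position per sequence in the batch

rows_cos, rows_sin = [], []
for i in range(cache_position.shape[0]):
    # cos[cache_position[i : i + 1]] -> [1, 1, head_dim]; unsqueeze(1) adds the head axis
    rows_cos.append(cos[cache_position[i : i + 1]].unsqueeze(1))
    rows_sin.append(sin[cache_position[i : i + 1]].unsqueeze(1))
cos_q = torch.cat(rows_cos, dim=0)  # [batch, 1, 1, head_dim], broadcastable over heads
sin_q = torch.cat(rows_sin, dim=0)

q = torch.randn(batch, num_heads, 1, head_dim)  # [batch, num_heads, q_len, head_dim]
rotated = torch.cat((-q[..., head_dim // 2 :], q[..., : head_dim // 2]), dim=-1)  # rotate_half()
q_embed = q * cos_q + rotated * sin_q  # same formula as apply_rotary_pos_emb()
print(q_embed.shape)  # torch.Size([2, 4, 1, 8])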