optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
@@ -13,18 +13,12 @@
  # limitations under the License.
 
  import math
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Tuple, Union
 
  import torch
  from torch import nn
  from transformers import PretrainedConfig, PreTrainedModel
 
- from ....ops import (
- register_rbln_custom_paged_attention,
- register_rbln_custom_paged_causal_attention,
- register_rbln_custom_paged_flash_attention,
- register_rbln_custom_paged_flash_causal_attention,
- )
  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 
@@ -38,30 +32,39 @@ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
  MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
 
 
- def validate_attention_method(
- rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_kvcache_block_size: int, rbln_max_seq_len: int
- ) -> Tuple[str, int]:
- if rbln_kvcache_partition_len is not None:
- if rbln_attn_impl == "eager":
- raise ValueError(
- f"`rbln_kvcache_partition_len` is set to {rbln_kvcache_partition_len}, but KV cache partitioning"
- " is not supported with 'eager' attention. Please set `rbln_kvcache_partition_len` to None, "
- "or switch `rbln_attn_impl` to 'flash_attn' to use KV cache partitioning."
- )
- elif rbln_attn_impl is None:
- rbln_attn_impl = "flash_attn"
+ def set_default_values(
+ attn_impl: Optional[str] = None,
+ kvcache_partition_len: Optional[int] = None,
+ kvcache_block_size: Optional[int] = None,
+ max_seq_len: Optional[int] = None,
+ ) -> Tuple[str, int, int]:
+ if attn_impl is None:
+ attn_impl = "eager"
+
+ if kvcache_partition_len is not None:
+ if attn_impl == "eager":
+ attn_impl = "flash_attn"
  logger.warning(
- "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
- "Since KV cache partitioning is only supported with flash attention, "
- "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
+ "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
+ "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
+ "`attn_impl` has been automatically switched to 'flash_attn'."
  )
 
- rbln_attn_impl = "eager" if rbln_attn_impl is None else rbln_attn_impl
- if rbln_attn_impl not in ["eager", "flash_attn"]:
- raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
+ if kvcache_partition_len is None and attn_impl == "flash_attn":
+ kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+ if kvcache_block_size is None:
+ if attn_impl == "eager":
+ kvcache_block_size = max_seq_len
+ else:
+ kvcache_block_size = kvcache_partition_len
+
+ return attn_impl, kvcache_partition_len, kvcache_block_size
+
 
- if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
- rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+ def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
+ if attn_impl not in ["eager", "flash_attn"]:
+ raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
 
  ## Checking Constraints...
  # Constraint of eager attention:
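
The new set_default_values helper introduced above resolves unset knobs before validation runs. Below is a minimal standalone sketch of that resolution logic, not the package's own code: logging is omitted, the helper is renamed resolve_defaults to avoid confusion, and the value of DEFAULT_FLASH_ATTN_PARTITION_LENGTH is not shown in this hunk, so 16_384 is only a placeholder.

from typing import Optional, Tuple

DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384  # placeholder; the real constant is defined elsewhere in this file

def resolve_defaults(
    attn_impl: Optional[str] = None,
    kvcache_partition_len: Optional[int] = None,
    kvcache_block_size: Optional[int] = None,
    max_seq_len: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]:
    # The attention implementation defaults to eager.
    if attn_impl is None:
        attn_impl = "eager"
    # Asking for KV-cache partitioning implies flash attention.
    if kvcache_partition_len is not None and attn_impl == "eager":
        attn_impl = "flash_attn"
    # Flash attention without an explicit partition length falls back to the default.
    if kvcache_partition_len is None and attn_impl == "flash_attn":
        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
    # Eager blocks span the whole sequence; flash blocks match the partition length.
    if kvcache_block_size is None:
        kvcache_block_size = max_seq_len if attn_impl == "eager" else kvcache_partition_len
    return attn_impl, kvcache_partition_len, kvcache_block_size

# Example resolutions:
# resolve_defaults(max_seq_len=4096)                              -> ("eager", None, 4096)
# resolve_defaults(kvcache_partition_len=8192, max_seq_len=32768) -> ("flash_attn", 8192, 8192)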
@@ -71,47 +74,45 @@ def validate_attention_method(
  # 1. `max_seq_len` should be multiple of `partition_len`.
  # 2. 4k <= `partition_len` <= 32k.
  # 3. `max_seq_len` should be larger then 8k.
- if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+ if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
  raise ValueError(
- f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
+ f"`max_seq_len` is set to {max_seq_len}, "
  f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
- f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
- " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
+ f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+ " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
  )
 
- if rbln_attn_impl == "flash_attn":
- if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
+ if attn_impl == "flash_attn":
+ if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
  raise ValueError(
- f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
+ f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
  f"when using 'flash_attn'. Please adjust either value to meet this requirement."
  )
- elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+ elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
  raise ValueError(
- f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
- f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+ f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+ f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
  f"Please provide a valid value within this range."
  )
- elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+ elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
  raise ValueError(
- f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
- f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
- "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
+ f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
+ f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
+ "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
  )
 
- if rbln_kvcache_block_size is not None:
- if rbln_attn_impl == "flash_attn" and rbln_kvcache_partition_len != rbln_kvcache_block_size:
+ if kvcache_block_size is not None:
+ if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
  raise ValueError(
- f" When using 'flash attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
- f"must always be set equal to the `rbln_kvcache_partition_len` {rbln_kvcache_partition_len}."
+ f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
+ f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
  )
- elif rbln_attn_impl == "eager" and rbln_kvcache_block_size != rbln_max_seq_len:
+ elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
  raise ValueError(
- f" When using 'eager attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
- f"must always be set equal to the `rbln_max_seq_len` {rbln_max_seq_len}."
+ f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
+ f"must always be set equal to the `max_seq_len` {max_seq_len}."
  )
 
- return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size
-
 
  class DecoderOnlyWrapper(nn.Module):
  """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
@@ -162,16 +163,8 @@ class DecoderOnlyWrapper(nn.Module):
  self.use_attention_mask = use_attention_mask
  if self.attn_impl == "flash_attn":
  self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
- if self.use_attention_mask:
- register_rbln_custom_paged_flash_attention()
- else:
- register_rbln_custom_paged_flash_causal_attention()
  elif self.attn_impl == "eager":
  self.kvcache_partition_len = None
- if self.use_attention_mask:
- register_rbln_custom_paged_attention()
- else:
- register_rbln_custom_paged_causal_attention()
  else:
  raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
 
@@ -191,6 +184,7 @@ class DecoderOnlyWrapper(nn.Module):
 
  def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
  new_layers = []
+
  for layer in causal_lm.model.layers:
  if self.attn_impl == "eager":
  new_self_attn = DecoderOnlyAttention(
@@ -208,6 +202,7 @@ class DecoderOnlyWrapper(nn.Module):
 
  new_layer = DecoderOnlyLayer(layer, new_self_attn)
  new_layers.append(new_layer)
+
  new_model = DecoderOnlyModel(
  causal_lm.model,
  new_layers,
@@ -227,6 +222,53 @@ class DecoderOnlyWrapper(nn.Module):
  self._phase = phase
  self.causal_lm.phase = phase
 
+ def forward_common(
+ self,
+ input_ids_or_inputs_embeds: torch.Tensor,
+ cache_position: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_position: torch.Tensor,
+ block_tables: torch.Tensor,
+ rotary_emb: Union[nn.Module, torch.Tensor],
+ *past_key_values: List[torch.Tensor],
+ ):
+ if input_ids_or_inputs_embeds.ndim == 2:
+ input_ids = input_ids_or_inputs_embeds
+ inputs_embeds = None
+ elif input_ids_or_inputs_embeds.ndim == 3:
+ input_ids = None
+ inputs_embeds = input_ids_or_inputs_embeds
+ else:
+ raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+ if len(past_key_values) != 2 * self.num_hidden_layers:
+ raise ValueError(
+ f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+ )
+
+ # [key, value] * n_layer -> ( (key, value) ) * n_layer
+ # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+ _past_key_values = []
+ for i in range(self.config.num_hidden_layers):
+ key_states = past_key_values[i * 2]
+ value_states = past_key_values[i * 2 + 1]
+ past_key_value = [key_states, value_states]
+ _past_key_values.append(past_key_value)
+ past_key_values = _past_key_values
+
+ logit = self.causal_lm(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ query_position=query_position,
+ past_key_values=past_key_values,
+ rotary_emb=rotary_emb,
+ block_tables=block_tables,
+ )
+
+ return logit
+
  def forward(self, *args):
  if self.phase == "decode":
  if self.use_attention_mask:
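
The new forward_common above expects the KV cache as a flat argument list, two tensors per layer, and regroups them into per-layer (key, value) pairs before calling the underlying model. A minimal sketch of that layout convention, with shapes taken from the comment in the hunk; regroup_kv_cache is an illustrative helper and not part of the package:

from typing import List, Sequence
import torch

def regroup_kv_cache(past_key_values: Sequence[torch.Tensor], num_hidden_layers: int) -> List[List[torch.Tensor]]:
    # Flat layout: (k_0, v_0, k_1, v_1, ..., k_{L-1}, v_{L-1}),
    # each tensor shaped (batch, n_heads, 1, max_seq_len, head_dim).
    if len(past_key_values) != 2 * num_hidden_layers:
        raise ValueError("expected exactly two cache tensors (key, value) per decoder layer")
    return [[past_key_values[2 * i], past_key_values[2 * i + 1]] for i in range(num_hidden_layers)]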
@@ -269,43 +311,16 @@ class DecoderOnlyWrapper(nn.Module):
  else:
  raise ValueError(f"Unknown phase: {self.phase}")
 
- if input_ids_or_inputs_embeds.ndim == 2:
- input_ids = input_ids_or_inputs_embeds
- inputs_embeds = None
- elif input_ids_or_inputs_embeds.ndim == 3:
- input_ids = None
- inputs_embeds = input_ids_or_inputs_embeds
- else:
- raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
-
- if len(past_key_values) != 2 * self.num_hidden_layers:
- raise ValueError(
- f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
- )
-
- # [key, value] * n_layer -> ( (key, value) ) * n_layer
- # cache shape : batch, n_heads, 1, max_seq_len, head_dim
- _past_key_values = []
- for i in range(self.config.num_hidden_layers):
- key_states = past_key_values[i * 2]
- value_states = past_key_values[i * 2 + 1]
- past_key_value = [key_states, value_states]
- _past_key_values.append(past_key_value)
- past_key_values = _past_key_values
-
- logit = self.causal_lm(
- input_ids=input_ids,
- inputs_embeds=inputs_embeds,
- attention_mask=attention_mask,
- cache_position=cache_position,
- query_position=query_position,
- past_key_values=past_key_values,
- rotary_emb=self.rotary_emb,
- block_tables=block_tables,
+ return self.forward_common(
+ input_ids_or_inputs_embeds,
+ cache_position,
+ attention_mask,
+ query_position,
+ block_tables,
+ self.rotary_emb,
+ *past_key_values,
  )
 
- return logit
-
 
  class DecoderOnlyForCausalLM(nn.Module):
  """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
@@ -329,12 +344,13 @@ class DecoderOnlyForCausalLM(nn.Module):
  _phase: Current processing phase ("prefill" or "decode")
  """
 
- def __init__(self, causal_lm: PreTrainedModel, model):
+ def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
  super().__init__()
  self.config = causal_lm.config
  self._original_mod = causal_lm
  self.model = model
  self._phase = "prefill"
+ self.lm_head = self._original_mod.lm_head
 
  @property
  def phase(self):
@@ -370,7 +386,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  if self.phase == "prefill":
  hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
- logits = self._original_mod.lm_head(hidden_states)
+ logits = self.lm_head(hidden_states)
  return logits
 
 
@@ -462,8 +478,12 @@ class DecoderOnlyModel(nn.Module):
 
  # get cos,sin vector if needed
  if rotary_emb is not None:
- cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
- cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+ if isinstance(rotary_emb, torch.Tensor):
+ cos = rotary_emb[0]
+ sin = rotary_emb[1]
+ else:
+ cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
+ cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
  else:
  batch_size = inputs_embeds.shape[0]
  if cache_position.shape[0] > 1:
@@ -756,55 +776,55 @@ class AttentionOp(nn.Module):
  if self.phase == "decode":
  if self.use_attention_mask:
  attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
- query_state,
- key_state,
- value_state,
- attn_mask,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- block_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ mask=attn_mask,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=block_size,
  )
  else:
  attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
- query_state,
- key_state,
- value_state,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- block_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=block_size,
  )
 
  else:
  if self.use_attention_mask:
  attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
- query_state,
- key_state,
- value_state,
- attn_mask,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- block_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ mask=attn_mask,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=block_size,
  )
  else:
  attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
- query_state,
- key_state,
- value_state,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- block_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=block_size,
  )
 
  attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -840,7 +860,6 @@ def rotate_half(x):
 
  def apply_rotary_pos_emb(q, k, cos, sin):
  """Applies Rotary Position Embedding to the query and key tensors."""
-
  q_embed = (q * cos) + (rotate_half(q) * sin)
  k_embed = (k * cos) + (rotate_half(k) * sin)
  return q_embed, k_embed
@@ -1015,58 +1034,58 @@ class FlashAttentionOp(AttentionOp):
  if self.phase == "decode":
  if self.use_attention_mask:
  attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
- query_state,
- key_state,
- value_state,
- attn_mask,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- kvcache_block_size,
- self.kvcache_partition_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ mask=attn_mask,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=kvcache_block_size,
+ partition=self.kvcache_partition_size,
  )
  else:
  attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
- query_state,
- key_state,
- value_state,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- kvcache_block_size,
- self.kvcache_partition_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=kvcache_block_size,
+ partition=self.kvcache_partition_size,
  )
  else:
  if self.use_attention_mask:
  attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
- query_state,
- key_state,
- value_state,
- attn_mask,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- kvcache_block_size,
- self.kvcache_partition_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ mask=attn_mask,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=kvcache_block_size,
+ partition=self.kvcache_partition_size,
  )
  else:
  attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
- query_state,
- key_state,
- value_state,
- past_key_state.unsqueeze(2),
- past_value_state.unsqueeze(2),
- seq_position,
- scale,
- block_tables,
- kvcache_block_size,
- self.kvcache_partition_size,
+ q=query_state,
+ k=key_state,
+ v=value_state,
+ kcache=past_key_state.unsqueeze(2),
+ vcache=past_value_state.unsqueeze(2),
+ seq=seq_position,
+ scale=scale,
+ block_table=block_tables,
+ block_size=kvcache_block_size,
+ partition=self.kvcache_partition_size,
  )
 
  # reshape for removing repeat_kv