optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
@@ -22,623 +22,740 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import math
25
- from typing import Dict, Optional, Tuple
25
+ from typing import List, Optional, Tuple
26
26
 
27
27
  import torch
28
28
  from torch import nn
29
- from transformers import PretrainedConfig
30
- from transformers.modeling_outputs import (
31
- BaseModelOutputWithPast,
32
- )
29
+ from transformers import PretrainedConfig, PreTrainedModel
33
30
 
31
+ from ....ops import register_rbln_custom_attention, register_rbln_custom_flash_attention
34
32
  from ....utils import logging
35
- from ...cache_utils import RebelDynamicCache
36
33
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
37
34
 
38
35
 
39
36
  logger = logging.get_logger(__name__)
40
- """
41
- ##############################################################################
42
- # RBLN custom operation (python interface)
43
- # torch.compile custom operation
44
- # torch.library.define - kernel declaration
45
- # torch.library.impl - kernel implementation
46
- # torch.library.impl_abstract - symbolic trace
47
- ##############################################################################
48
- """
49
-
50
- # RBLN custom op(flash attention decode)
51
- torch.library.define(
52
- "rbln_custom_ops::flash_attn_decode",
53
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
54
- )
55
-
56
-
57
- @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
58
- def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
59
- """
60
- WORKAROUND:
61
- Partition is declared as an argument to the function, even though it is
62
- not actually used in the CPU implementation, this allows the rbln compiler
63
- to perform flash attention operations with partition as an argument.
64
- """
65
- assert kcache.dim() == k.dim()
66
- assert vcache.dim() == v.dim()
67
- assert k.size(-2) == v.size(-2)
68
- assert partition.dim() == 1
69
- b = 0
70
- if seq.dim() == 1:
71
- s = seq[0]
72
- elif seq.dim() == 0:
73
- s = seq
74
- else:
75
- assert False
76
- e = s + k.size(-2)
77
- updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
78
- updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
79
- attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
80
- attn_weight = attn_weight + mask
81
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
82
- attn_output = torch.matmul(attn_weight, updated_v)
83
- return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
84
37
 
38
+ DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
39
+ DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
40
+ MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
41
+ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
42
+ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
85
43
 
86
- @torch.library.impl_abstract("rbln_custom_ops::flash_attn_decode")
87
- def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
88
- return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
89
44
 
45
+ def validate_attention_method(
46
+ rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
47
+ ) -> Tuple[str, int]:
48
+ if rbln_kvcache_partition_len is not None:
49
+ if rbln_attn_impl == "eager":
50
+ raise ValueError(
51
+ f"`rbln_kvcache_partition_len` is set to {rbln_kvcache_partition_len}, but KV cache partitioning"
52
+ " is not supported with 'eager' attention. Please set `rbln_kvcache_partition_len` to None, "
53
+ "or switch `rbln_attn_impl` to 'flash_attn' to use KV cache partitioning."
54
+ )
55
+ elif rbln_attn_impl is None:
56
+ rbln_attn_impl = "flash_attn"
57
+ logger.warning(
58
+ "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
59
+ "Since KV cache partitioning is only supported with flash attention, "
60
+ "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
61
+ )
90
62
 
91
- # RBLN custom op(flash attention prefill)
92
- torch.library.define(
93
- "rbln_custom_ops::flash_attn_prefill",
94
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
95
- )
63
+ rbln_attn_impl = "eager" if rbln_attn_impl is None else rbln_attn_impl
64
+ if rbln_attn_impl not in ["eager", "flash_attn"]:
65
+ raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
66
+
67
+ if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
68
+ rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
69
+
70
+ ## Checking Constraints...
71
+ # Constraint of eager attention:
72
+ # - `max_seq_len` <= 32k
73
+
74
+ # Constraints of flash attention:
75
+ # 1. `max_seq_len` should be multiple of `partition_len`.
76
+ # 2. 4k <= `partition_len` <= 32k.
77
+ # 3. `max_seq_len` should be at least 8k.
78
+ if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
79
+ raise ValueError(
80
+ f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
81
+ f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
82
+ f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
83
+ " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
84
+ )
96
85
 
86
+ if rbln_attn_impl == "flash_attn":
87
+ if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
88
+ raise ValueError(
89
+ f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
90
+ f"when using 'flash_attn'. Please adjust either value to meet this requirement."
91
+ )
92
+ elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
93
+ raise ValueError(
94
+ f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
95
+ f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
96
+ f"Please provide a valid value within this range."
97
+ )
98
+ elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
99
+ raise ValueError(
100
+ f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
101
+ f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
102
+ "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
103
+ )
97
104
 
98
- @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
99
- def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition):
100
- """
101
- WORKAROUND:
102
- Partition is declared as an argument to the function, even though it is
103
- not actually used in the CPU implementation, this allows the rbln compiler
104
- to perform flash attention operations with partition as an argument.
105
+ return rbln_attn_impl, rbln_kvcache_partition_len
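To make the constraints above concrete, here is a small sketch of how the validator resolves a few configurations. The import path is taken from the file list above (item 51); the expected results follow directly from the checks in the function.

    # Sketch: expected behaviour of validate_attention_method for a few inputs.
    from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
        validate_attention_method,
    )

    validate_attention_method(None, None, 4096)           # -> ("eager", None)
    validate_attention_method("flash_attn", None, 32768)  # -> ("flash_attn", 16384), default partition
    validate_attention_method(None, 4096, 32768)          # warns, -> ("flash_attn", 4096)
    # validate_attention_method("eager", 4096, 8192)      # raises: partitioning requires flash_attn
    # validate_attention_method("flash_attn", 4096, 6144) # raises: 6144 is not a multiple of 4096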
106
+
107
+
108
+ class DecoderOnlyWrapper(nn.Module):
109
+ """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
110
+
111
+ This wrapper is designed to:
112
+ 1. Convert Huggingface decoder models for RBLN compilation with static shapes
113
+ 2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
114
+ 3. Manage different attention implementations (standard and flash attention)
115
+ 4. Support both prefill and decode phases
116
+
117
+ Notes:
118
+ - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
119
+ - Wrapper should not contain neural network graph operations (including memory view handling)
120
+
121
+ Args:
122
+ causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
123
+ max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
124
+ use_rotary_emb (bool): Whether to use rotary position embeddings
125
+ attn_impl (str): The attention implementation to use.
126
+ - "eager": Uses the standard attention.
127
+ - "flash_attn": Uses flash attention. When set,
128
+ the key/value cache is partitioned into chunks of length
129
+ `kvcache_partition_len`.
130
+ kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
131
+ This is only relevant if `attn_impl` is set to "flash_attn".
105
132
  """
106
- assert kcache.dim() == k.dim()
107
- assert vcache.dim() == v.dim()
108
- assert k.size(-2) == v.size(-2)
109
- assert partition.dim() == 1
110
- if batch.dim() == 1:
111
- b = batch[0]
112
- elif batch.dim() == 0:
113
- b = batch
114
- else:
115
- assert False
116
- if seq.dim() == 1:
117
- s = seq[0]
118
- elif seq.dim() == 0:
119
- s = seq
120
- else:
121
- assert False
122
- e = s + k.size(-2)
123
- updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
124
- updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
125
- attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
126
- attn_weight = attn_weight + mask
127
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
128
- attn_output = torch.matmul(attn_weight, updated_v)
129
- return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
130
133
 
134
+ def __init__(
135
+ self,
136
+ causal_lm: PreTrainedModel,
137
+ max_seq_len: int,
138
+ use_rotary_emb: bool,
139
+ attn_impl: str,
140
+ kvcache_partition_len: Optional[int] = None,
141
+ ):
142
+ super().__init__()
143
+ self.config = causal_lm.config
131
144
 
132
- @torch.library.impl_abstract("rbln_custom_ops::flash_attn_prefill")
133
- def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
134
- return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
135
-
145
+ if use_rotary_emb:
146
+ self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
147
+ else:
148
+ self.rotary_emb = None
149
+
150
+ self.attn_impl = attn_impl
151
+ if self.attn_impl == "flash_attn":
152
+ self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
153
+ register_rbln_custom_flash_attention()
154
+ elif self.attn_impl == "eager":
155
+ self.kvcache_partition_len = None
156
+ register_rbln_custom_attention()
157
+ else:
158
+ raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
136
159
 
137
- # RBLN custom op(cache update)
138
- torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
160
+ if kvcache_partition_len and kvcache_partition_len > max_seq_len:
161
+ raise ValueError(
162
+ f"kvcache_partition_len({kvcache_partition_len}) should be lower"
163
+ f" or equal to max_seq_len({max_seq_len})!"
164
+ )
139
165
 
166
+ self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
140
167
 
141
- @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
142
- def rbln_cache_update_cpu(cache, value, batch, seq):
143
- updated_cache = cache[batch].slice_scatter(value, dim=-2, start=batch[0], end=batch[0] + seq[0])
144
- return updated_cache
168
+ self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
169
+ self._phase = "prefill"
145
170
 
171
+ def get_rotary_emb(self, max_seq_len):
172
+ return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
146
173
 
147
- @torch.library.impl_abstract("rbln_custom_ops::rbln_cache_update")
148
- def rbln_cache_update_abstract(cache, value, batch, seq):
149
- return torch.empty_like(cache)
174
+ def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
175
+ new_layers = []
176
+ for layer in causal_lm.model.layers:
177
+ if self.attn_impl == "eager":
178
+ new_self_attn = DecoderOnlyAttention(layer.self_attn)
179
+ elif self.attn_impl == "flash_attn":
180
+ new_self_attn = DecoderOnlyFlashAttention(
181
+ layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
182
+ )
183
+ else:
184
+ raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
185
+
186
+ new_layer = DecoderOnlyLayer(layer, new_self_attn)
187
+ new_layers.append(new_layer)
188
+ new_model = DecoderOnlyModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
189
+ new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
190
+ return new_causal_lm
191
+
192
+ @property
193
+ def phase(self) -> str:
194
+ return self._phase
195
+
196
+ @phase.setter
197
+ def phase(self, phase: str):
198
+ self._phase = phase
199
+ self.causal_lm.phase = phase
200
+
201
+ def forward(self, *args):
202
+ if self.phase == "decode":
203
+ (
204
+ input_ids_or_inputs_embeds,
205
+ attention_mask,
206
+ cache_position,
207
+ *past_key_values,
208
+ ) = args
209
+ batch_position = torch.tensor(0, dtype=torch.int16)
210
+ query_position = None
211
+ elif self.phase == "prefill":
212
+ (
213
+ input_ids_or_inputs_embeds,
214
+ attention_mask,
215
+ cache_position,
216
+ batch_position,
217
+ query_position,
218
+ *past_key_values,
219
+ ) = args
220
+ else:
221
+ raise ValueError(f"Unknown phase: {self.phase}")
150
222
 
223
+ if input_ids_or_inputs_embeds.ndim == 2:
224
+ input_ids = input_ids_or_inputs_embeds
225
+ inputs_embeds = None
226
+ elif input_ids_or_inputs_embeds.ndim == 3:
227
+ input_ids = None
228
+ inputs_embeds = input_ids_or_inputs_embeds
229
+ else:
230
+ raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
151
231
 
152
- class DecoderOnlyAttention:
153
- def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
154
- # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
155
- key_state = key_state.unsqueeze(2)
156
- value_state = value_state.unsqueeze(2)
157
- attn_mask = attn_mask.unsqueeze(2)
232
+ if len(past_key_values) != 2 * self.num_hidden_layers:
233
+ raise ValueError(
234
+ f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
235
+ )
158
236
 
159
- query_state = query_state.view(
160
- 1,
161
- self.num_key_value_heads,
162
- self.num_heads // self.num_key_value_heads,
163
- -1,
164
- self.head_dim,
237
+ # [key, value] * n_layer -> ( (key, value) ) * n_layer
238
+ # cache shape : batch, n_heads, 1, max_seq_len, head_dim
239
+ _past_key_values = []
240
+ for i in range(self.config.num_hidden_layers):
241
+ key_states = past_key_values[i * 2]
242
+ value_states = past_key_values[i * 2 + 1]
243
+ past_key_value = [key_states, value_states]
244
+ _past_key_values.append(past_key_value)
245
+ past_key_values = _past_key_values
246
+
247
+ logit, present_key_values = self.causal_lm(
248
+ input_ids=input_ids,
249
+ inputs_embeds=inputs_embeds,
250
+ attention_mask=attention_mask,
251
+ cache_position=cache_position,
252
+ batch_position=batch_position,
253
+ query_position=query_position,
254
+ past_key_values=past_key_values,
255
+ rotary_emb=self.rotary_emb,
165
256
  )
166
257
 
167
- key_state, value_state = past_key_value.update(
168
- key_state, value_state, self.layer_idx, batch_idx, read_first_step=is_prefill
169
- )
258
+ # ((key, value)) * n_layer -> [key, value] * n_layer
259
+ _present_key_values = ()
260
+ for i in range(self.num_hidden_layers):
261
+ key_states = present_key_values[i][0]
262
+ value_states = present_key_values[i][1]
263
+ _present_key_values = _present_key_values + (key_states, value_states)
264
+ present_key_values = _present_key_values
170
265
 
171
- attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
172
- attn_weight += attn_mask
173
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_state.dtype)
174
- attn_output = torch.matmul(attn_weight, value_state)
266
+ return logit, present_key_values
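Because the wrapper is traced with torch.jit.trace, everything is passed positionally and the per-layer key/value caches are flattened into the argument list. A minimal sketch of how the prefill-phase inputs could be assembled; all shapes and dtypes here are illustrative assumptions, not the exact compiled shapes.

    import torch

    # Illustrative only: flatten [(key, value)] * n_layer into key_0, value_0, key_1, value_1, ...
    num_layers, batch, n_kv_heads, max_seq_len, head_dim = 2, 1, 4, 8192, 64
    past_key_values = [
        (torch.zeros(batch, n_kv_heads, max_seq_len, head_dim),
         torch.zeros(batch, n_kv_heads, max_seq_len, head_dim))
        for _ in range(num_layers)
    ]
    flat_kv = [t for kv in past_key_values for t in kv]

    prefill_args = (
        torch.zeros(1, 128, dtype=torch.int64),    # input_ids (2-D) or inputs_embeds (3-D)
        torch.ones(1, 1, 128, max_seq_len),        # attention_mask (assumed layout)
        torch.arange(128).unsqueeze(0),            # cache_position, shape (batch, seq)
        torch.tensor(0, dtype=torch.int16),        # batch_position: cache slot being filled
        torch.tensor(127),                         # query_position: index of last valid token
        *flat_kv,
    )
    # The decode phase drops batch_position and query_position:
    #   decode_args = (input_ids, attention_mask, cache_position, *flat_kv)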
175
267
 
176
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
177
- attn_output = attn_output.transpose(1, 2).contiguous()
178
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
179
268
 
180
- return attn_output, key_state, value_state
269
+ class DecoderOnlyForCausalLM(nn.Module):
270
+ """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
181
271
 
182
- def forward(
183
- self,
184
- hidden_states: torch.Tensor,
185
- attention_mask: Optional[torch.Tensor] = None,
186
- position_ids: Optional[torch.LongTensor] = None,
187
- past_key_value: Optional[RebelDynamicCache] = None,
188
- batch_index: Optional[torch.Tensor] = None,
189
- output_attentions: bool = False,
190
- cos: Optional[torch.Tensor] = None,
191
- sin: Optional[torch.Tensor] = None,
192
- **kwargs,
193
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
194
- bsz, q_len, _ = hidden_states.size()
195
- query_states = self.q_proj(hidden_states)
196
- key_states = self.k_proj(hidden_states)
197
- value_states = self.v_proj(hidden_states)
272
+ This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
273
+ 1. Managing model phases (prefill/decode) throughout the computation graph
274
+ 2. Handling output shape alignments for static compilation
275
+ 3. Coordinating between the original model and RBLN-optimized components
198
276
 
199
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
200
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
201
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
202
-
203
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
204
-
205
- # Decoder (bsz > 1)
206
- if bsz > 1:
207
- iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
208
- for b in range(bsz):
209
- attn_output, key_state, value_state = DecoderOnlyAttention._attn(
210
- self,
211
- query_states[b].unsqueeze(0),
212
- key_states[b].unsqueeze(0),
213
- value_states[b].unsqueeze(0),
214
- attention_mask[b].unsqueeze(0),
215
- past_key_value,
216
- batch_idx=b,
217
- is_prefill=False,
218
- )
277
+ The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
278
+ focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.
219
279
 
220
- iterate_results["key_states"].append(key_state)
221
- iterate_results["value_states"].append(value_state)
222
- iterate_results["attn_output"].append(attn_output)
280
+ Args:
281
+ causal_lm (PreTrainedModel): Original Huggingface causal language model
282
+ model (DecoderOnlyModel): RBLN-optimized model instance
223
283
 
224
- key_states = torch.cat(iterate_results["key_states"], dim=0)
225
- value_states = torch.cat(iterate_results["value_states"], dim=0)
226
- attn_output = torch.cat(iterate_results["attn_output"], dim=0)
227
- # Prefill & Decoder (bsz == 1)
228
- else:
229
- attn_output, key_states, value_states = DecoderOnlyAttention._attn(
230
- self,
231
- query_states,
232
- key_states,
233
- value_states,
234
- attention_mask,
235
- past_key_value,
236
- batch_idx=batch_index,
237
- is_prefill=True,
238
- )
239
-
240
- attn_output = self.o_proj(attn_output)
284
+ Attributes:
285
+ config: Configuration from the original causal language model
286
+ _original_mod: Reference to the original model for components like lm_head
287
+ model: RBLN-optimized decoder model instance
288
+ _phase: Current processing phase ("prefill" or "decode")
289
+ """
241
290
 
242
- if not output_attentions:
243
- attn_weight = None
291
+ def __init__(self, causal_lm: PreTrainedModel, model):
292
+ super().__init__()
293
+ self.config = causal_lm.config
294
+ self._original_mod = causal_lm
295
+ self.model = model
296
+ self._phase = "prefill"
244
297
 
245
- return attn_output, attn_weight, key_states, value_states
298
+ @property
299
+ def phase(self):
300
+ return self._phase
246
301
 
302
+ @phase.setter
303
+ def phase(self, phase: str):
304
+ self._phase = phase
305
+ self.model.phase = phase
247
306
 
248
- class DecoderOnlyFlashAttention:
249
307
  def forward(
250
308
  self,
251
- hidden_states: torch.Tensor,
252
- attention_mask: Optional[torch.Tensor] = None,
253
- position_ids: Optional[torch.LongTensor] = None,
254
- past_key_value: Optional[RebelDynamicCache] = None,
255
- batch_index: Optional[torch.Tensor] = None,
256
- output_attentions: bool = False,
257
- cos: Optional[torch.Tensor] = None,
258
- sin: Optional[torch.Tensor] = None,
259
- cache_pos_for_partitions: Optional[torch.Tensor] = None,
260
- kvcache_partition_size: Optional[torch.Tensor] = None,
261
- **kwargs,
262
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
263
- bsz, q_len, _ = hidden_states.size()
264
- query_states = self.q_proj(hidden_states)
265
- key_states = self.k_proj(hidden_states)
266
- value_states = self.v_proj(hidden_states)
267
-
268
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
269
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
270
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
271
-
272
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
273
-
274
- # Decoder (bsz > 1)
275
- if bsz > 1:
276
- all_key_states = []
277
- all_value_states = []
278
- all_attn_output = []
279
-
280
- for b in range(bsz):
281
- query_state = query_states[b].unsqueeze(0)
282
- attn_mask = attention_mask[b].unsqueeze(0)
283
- key_state = key_states[b].unsqueeze(0)
284
- value_state = value_states[b].unsqueeze(0)
285
-
286
- # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
287
- key_state = key_state.unsqueeze(2)
288
- value_state = value_state.unsqueeze(2)
289
- attn_mask = attn_mask.unsqueeze(2)
290
-
291
- query_state = query_state.view(
292
- 1,
293
- self.num_key_value_heads,
294
- self.num_heads // self.num_key_value_heads,
295
- q_len,
296
- self.head_dim,
297
- )
298
-
299
- # RBLN custom flash attention(decode), dummy batch index
300
- sidx = cache_pos_for_partitions[b][0]
301
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
302
- query_state,
303
- key_state,
304
- value_state,
305
- attn_mask,
306
- past_key_value.key_cache[self.layer_idx].unsqueeze(2),
307
- past_key_value.value_cache[self.layer_idx].unsqueeze(2),
308
- sidx,
309
- kvcache_partition_size,
310
- )
309
+ input_ids: torch.Tensor = None,
310
+ inputs_embeds: torch.Tensor = None,
311
+ attention_mask: torch.Tensor = None,
312
+ cache_position: torch.Tensor = None,
313
+ batch_position: torch.Tensor = None,
314
+ query_position: torch.Tensor = None,
315
+ past_key_values: Tuple[Tuple[torch.Tensor]] = None,
316
+ rotary_emb: nn.Module = None,
317
+ ):
318
+ # outputs
319
+ hidden_states, present_key_values = self.model(
320
+ input_ids=input_ids,
321
+ inputs_embeds=inputs_embeds,
322
+ attention_mask=attention_mask,
323
+ cache_position=cache_position,
324
+ batch_position=batch_position,
325
+ past_key_values=past_key_values,
326
+ rotary_emb=rotary_emb,
327
+ )
311
328
 
312
- # reshape for removing repeat_kv
313
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
314
- attn_output = attn_output.transpose(1, 2).contiguous()
315
- attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)
329
+ if self.phase == "prefill":
330
+ hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
316
331
 
317
- all_key_states.append(key_state)
318
- all_value_states.append(value_state)
319
- all_attn_output.append(attn_output)
332
+ logits = self._original_mod.lm_head(hidden_states)
333
+ output = (logits, present_key_values)
334
+ return output
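In the prefill phase the hidden states are sliced down to the single query_position row before lm_head, so only the logits of the last valid token are emitted and the output shape stays static. A quick shape check of that indexing, with illustrative sizes:

    import torch

    hidden_states = torch.randn(1, 128, 4096)              # (batch=1, prefill_len, hidden_size), illustrative
    query_position = torch.tensor(70, dtype=torch.int16)   # index of the last valid token
    sliced = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
    print(sliced.shape)                                     # torch.Size([1, 1, 4096])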
320
335
 
321
- key_states = torch.cat(all_key_states, dim=0)
322
- value_states = torch.cat(all_value_states, dim=0)
323
- attn_output = torch.cat(all_attn_output, dim=0)
324
336
 
325
- else:
326
- # reshape for removing repeat_kv
327
- key_states = key_states.unsqueeze(2)
328
- value_states = value_states.unsqueeze(2)
329
- attention_mask = attention_mask.unsqueeze(2)
330
- query_states = query_states.view(
331
- 1,
332
- self.num_key_value_heads,
333
- self.num_heads // self.num_key_value_heads,
334
- q_len,
335
- self.head_dim,
336
- )
337
+ class DecoderOnlyModel(nn.Module):
338
+ """A modified decoder-only model implementation optimized for RBLN compilation.
337
339
 
338
- assert batch_index.dim() == 0
339
- assert not output_attentions
340
- bidx = batch_index
341
- sidx = cache_pos_for_partitions[0][0]
342
- attn_output, key_states, value_states = torch.ops.rbln_custom_ops.flash_attn_prefill(
343
- query_states,
344
- key_states,
345
- value_states,
346
- attention_mask,
347
- past_key_value.key_cache[self.layer_idx].unsqueeze(2),
348
- past_key_value.value_cache[self.layer_idx].unsqueeze(2),
349
- bidx,
350
- sidx,
351
- kvcache_partition_size,
352
- )
340
+ Args:
341
+ model: Original Huggingface model to adapt
342
+ layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
353
343
 
354
- # reshape for removing repeat_kv
355
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
356
- attn_output = attn_output.transpose(1, 2).contiguous()
357
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
344
+ Attributes:
345
+ _original_mod: Reference to original Huggingface model
346
+ layers: ModuleList of RBLN-optimized transformer layers
347
+ _phase: Current processing phase ("prefill" or "decode")
348
+ """
358
349
 
359
- attn_output = self.o_proj(attn_output)
350
+ def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None):
351
+ super().__init__()
352
+ self._original_mod = model
353
+ self.layers = nn.ModuleList(layers)
354
+ self._phase = "prefill"
355
+ self.partition_len = partition_len
356
+
357
+ @property
358
+ def phase(self):
359
+ return self._phase
360
+
361
+ @phase.setter
362
+ def phase(self, phase: str):
363
+ self._phase = phase
364
+ for layer in self.layers:
365
+ layer.phase = phase
366
+
367
+ @property
368
+ def attn_impl(self) -> str:
369
+ return "eager" if self.partition_len is None else "flash_attn"
370
+
371
+ @property
372
+ def hidden_multiplier(self):
373
+ return 1
374
+
375
+ def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
376
+ if self.attn_impl != "flash_attn":
377
+ raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
378
+
379
+ partition_len = self.partition_len
380
+ num_partition = max_seq_len // partition_len
381
+
382
+ cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
383
+ pidx = torch.arange(num_partition)
384
+ cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
385
+ return cache_pos_for_partitions
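A worked example of the clamp above: with max_seq_len=16384 and partition_len=4096, a cached sequence position of 9000 means the first two partitions are full, the third holds 808 entries, and the fourth is still empty.

    import torch

    seq_positions = torch.tensor([9000])                  # (batch,)
    partition_len, max_seq_len = 4096, 16384
    num_partition = max_seq_len // partition_len          # 4
    cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
    pidx = torch.arange(num_partition)
    print(torch.clamp(cs - pidx * partition_len, 0, partition_len))
    # tensor([[4096, 4096,  808,    0]])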
386
+
387
+ def get_last_layernorm(self) -> nn.LayerNorm:
388
+ return self._original_mod.norm
389
+
390
+ def get_embedding(self) -> nn.Embedding:
391
+ return self._original_mod.embed_tokens
392
+
393
+ def get_pos_embedding(self) -> nn.Embedding:
394
+ raise NotImplementedError(
395
+ "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
396
+ )
360
397
 
361
- if not output_attentions:
362
- attn_weight = None
398
+ def forward(
399
+ self,
400
+ input_ids: torch.Tensor = None,
401
+ inputs_embeds: torch.Tensor = None,
402
+ attention_mask: torch.Tensor = None,
403
+ cache_position: torch.Tensor = None,
404
+ batch_position: torch.Tensor = None,
405
+ past_key_values: Tuple[Tuple[torch.Tensor]] = None,
406
+ rotary_emb: nn.Module = None,
407
+ ):
408
+ # retrieve input_ids and inputs_embeds
409
+ if (input_ids is None) ^ (inputs_embeds is not None):
410
+ raise ValueError(
411
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
412
+ )
363
413
 
364
- return attn_output, attn_weight, key_states, value_states
414
+ # embed positions
415
+ if inputs_embeds is None:
416
+ inputs_embeds = self.get_embedding()(input_ids)
365
417
 
418
+ hidden_states = inputs_embeds * self.hidden_multiplier
366
419
 
367
- DECODERONLY_ATTENTION_CLASSES = {
368
- "eager": DecoderOnlyAttention,
369
- "flash_attn_rbln": DecoderOnlyFlashAttention,
370
- # "sdpa": DecoderOnlySdpaAttention,
371
- }
420
+ # get cos,sin vector if needed
421
+ if rotary_emb is not None:
422
+ cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1]) # dtype carrier, max_seq_len
423
+ cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
424
+ else:
425
+ batch_size = inputs_embeds.shape[0]
426
+ if cache_position.shape[0] > 1:
427
+ position_embeds = []
428
+ for b_idx in range(batch_size):
429
+ position_embed = self.get_pos_embedding()(cache_position[b_idx])
430
+ position_embeds.append(position_embed)
431
+
432
+ position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
433
+ else:
434
+ position_embeds = self.get_pos_embedding()(cache_position)
435
+ hidden_states = hidden_states + position_embeds
436
+ cos, sin = None, None
437
+
438
+ # (batch, seq_len) -> (batch,)
439
+ seq_positions = cache_position[:, 0]
440
+ if self.attn_impl == "flash_attn":
441
+ max_seq_len = past_key_values[0][0].shape[-2]
442
+ seq_positions = self.convert_sequence_positions_for_flash_attn(
443
+ seq_positions=seq_positions, max_seq_len=max_seq_len
444
+ )
372
445
 
446
+ present_key_values = past_key_values
447
+ for layer in self.layers:
448
+ hidden_states, present_key_values = layer(
449
+ hidden_states=hidden_states,
450
+ attention_mask=attention_mask,
451
+ seq_positions=seq_positions,
452
+ batch_position=batch_position,
453
+ past_key_values=present_key_values,
454
+ cos=cos,
455
+ sin=sin,
456
+ )
373
457
 
374
- class DecoderOnlyWrapper(torch.nn.Module):
375
- def __init__(self, model, max_seq_len, kvcache_partition_len=None):
376
- super().__init__()
377
- self.config = model.config
378
- self.model = model.model
379
- self.lm_head = model.lm_head
380
- self.max_seq_len = max_seq_len
381
- self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
382
-
383
- if kvcache_partition_len is not None:
384
- # WORKAROUND : for passing partition length as a value to the rbln compiler.
385
- # What is actually used is the shape of this tensor.
386
- self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
387
- self.attn_implementation = "flash_attn_rbln"
388
- logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
389
- else:
390
- self.kvcache_partition_size = None
391
- self.attn_implementation = "eager"
458
+ hidden_states = self.get_last_layernorm()(hidden_states)
459
+ return hidden_states, present_key_values
392
460
 
393
- def get_forward_dict(self):
394
- forward_dict = {
395
- "wrapper": DecoderOnlyModel.forward,
396
- "model": DecoderOnlyDecoderLayer.forward,
397
- "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
398
- }
399
- return forward_dict
400
461
 
401
- def forward(
402
- self,
403
- input_ids_or_inputs_embeds,
404
- attention_mask,
405
- cache_position,
406
- batch_position,
407
- query_idx,
408
- *past_key_values,
409
- ):
410
- if input_ids_or_inputs_embeds.ndim == 2:
411
- # input_ids
412
- input_ids = input_ids_or_inputs_embeds
413
- inputs_embeds = None
414
- elif input_ids_or_inputs_embeds.ndim == 3:
415
- # inputs_embeds
416
- input_ids = None
417
- inputs_embeds = input_ids_or_inputs_embeds
418
- else:
419
- raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
462
+ class DecoderOnlyLayer(nn.Module):
463
+ """A single transformer layer adapted for RBLN compilation with static shapes.
420
464
 
421
- # Formatting list of past_kv to DynamicCache class.
422
- past_key_values = RebelDynamicCache.from_input_format(
423
- cache_position,
424
- self.config.num_hidden_layers,
425
- *past_key_values,
426
- )
465
+ This layer implements a modified transformer block that includes:
466
+ 1. Self-attention mechanism (either standard or flash attention)
467
+ 2. Feed-forward network (FFN)
468
+ 3. Layer normalization
469
+ 4. Residual connections
427
470
 
428
- batch_size = input_ids_or_inputs_embeds.size()[0]
429
- seq_len = input_ids_or_inputs_embeds.size()[1]
430
-
431
- if self.attn_implementation == "eager":
432
- cache_pos_for_partitions = None
433
- elif self.attn_implementation == "flash_attn_rbln":
434
- p_len = self.kvcache_partition_size.size()[0]
435
- num_partition = self.max_seq_len // p_len
436
- if self.max_seq_len % p_len > 0:
437
- raise ValueError(
438
- f"The partition length({p_len}) must be exactly divisible by the max_seq_len({self.max_seq_len})."
439
- )
440
- cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
471
+ The layer is specifically designed to:
472
+ - Support compilation to RBLN custom ops
473
+ - Maintain static tensor shapes throughout computations
474
+ - Handle both prefill and decode phases efficiently
475
+ - Manage attention state transitions properly
441
476
 
442
- if batch_size > 1: # decode
443
- for b_idx in range(batch_size):
444
- decoding_step = cache_position[b_idx]
445
- cache_pos = decoding_step
446
- for p_idx in range(num_partition):
447
- input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
448
- input_1 = torch.tensor(p_len, dtype=torch.int32)
449
- min = torch.minimum(input_0, input_1)
450
- cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
451
- cache_pos_for_partitions[b_idx][p_idx] = cache_pos_for_partition
452
- else: # prefill
453
- cache_pos = cache_position[0][0]
454
- for p_idx in range(num_partition):
455
- input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
456
- input_1 = torch.tensor(p_len, dtype=torch.int32)
457
- min = torch.minimum(input_0, input_1)
458
- cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
459
- cache_pos_for_partitions[0][p_idx] = cache_pos_for_partition
460
- else:
461
- raise NotImplementedError(f"Unknown attn_implementation: {self.attn_implementation}")
477
+ Args:
478
+ layer: Original transformer layer module to wrap
479
+ self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
462
480
 
463
- forward_dict = self.get_forward_dict()
464
- outputs = forward_dict["wrapper"](
465
- self.model,
466
- input_ids=input_ids,
467
- inputs_embeds=inputs_embeds,
468
- attention_mask=attention_mask,
469
- position_ids=cache_position,
470
- past_key_values=past_key_values,
471
- batch_ids=batch_position,
472
- rotary_pos_emb=self.rotary_emb,
473
- cache_pos_for_partitions=cache_pos_for_partitions,
474
- kvcache_partition_size=self.kvcache_partition_size,
475
- forward_dict=forward_dict,
476
- )
481
+ Attributes:
482
+ _original_mod: Reference to original layer for accessing components
483
+ self_attn: Modified attention mechanism mapped to RBLN ops at compile time
484
+ phase: Current operation phase ("prefill" or "decode")
485
+ """
477
486
 
478
- hidden_states = outputs[0]
479
- if seq_len != 1:
480
- hidden_states = hidden_states[:, query_idx.to(torch.int).unsqueeze(0)]
487
+ def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
488
+ super().__init__()
489
+ self._original_mod = layer
490
+ self.self_attn = self_attn
491
+ self._phase = "prefill"
481
492
 
482
- logits = self.lm_head(hidden_states)
493
+ @property
494
+ def phase(self):
495
+ return self._phase
483
496
 
484
- output = (logits,) + outputs[1:]
497
+ @phase.setter
498
+ def phase(self, phase: str):
499
+ self._phase = phase
500
+ self.self_attn.phase = phase
485
501
 
486
- return output, batch_position + query_idx
502
+ def get_pre_attention_layernorm(self) -> nn.LayerNorm:
503
+ return self._original_mod.input_layernorm
487
504
 
505
+ def get_post_attention_layernorm(self) -> nn.LayerNorm:
506
+ return self._original_mod.post_attention_layernorm
488
507
 
489
- class DecoderOnlyDecoderLayer:
490
508
  def forward(
491
509
  self,
492
510
  hidden_states: torch.Tensor,
493
- layer_idx: int,
494
- attention_mask: Optional[torch.Tensor] = None,
495
- position_ids: Optional[torch.LongTensor] = None,
496
- past_key_value: Optional[RebelDynamicCache] = None,
497
- output_attentions: Optional[bool] = None,
498
- use_cache: Optional[bool] = None,
499
- batch_ids: Optional[torch.Tensor] = None,
511
+ attention_mask: torch.Tensor,
512
+ seq_positions: torch.LongTensor,
513
+ batch_position: torch.Tensor,
514
+ past_key_values: Tuple[Tuple[torch.Tensor]],
500
515
  cos: Optional[torch.Tensor] = None,
501
516
  sin: Optional[torch.Tensor] = None,
502
- cache_pos_for_partitions: Optional[torch.Tensor] = None,
503
- kvcache_partition_size: Optional[torch.Tensor] = None,
504
- forward_dict: Optional[Dict[str, classmethod]] = None,
505
- **kwargs,
506
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
517
+ ):
507
518
  residual = hidden_states
519
+ hidden_states = self.get_pre_attention_layernorm()(hidden_states)
508
520
 
509
- hidden_states = self.input_layernorm(hidden_states)
510
-
511
- hidden_states, self_attn_weight, k, v = forward_dict["decoder_layer"](
512
- self.self_attn,
521
+ hidden_states, present_key_values = self.self_attn(
513
522
  hidden_states=hidden_states,
514
523
  attention_mask=attention_mask,
515
- position_ids=position_ids,
516
- past_key_value=past_key_value,
517
- output_attentions=output_attentions,
518
- batch_index=batch_ids,
519
- use_cache=use_cache,
524
+ seq_positions=seq_positions,
525
+ batch_position=batch_position,
526
+ past_key_values=past_key_values,
520
527
  cos=cos,
521
528
  sin=sin,
522
- cache_pos_for_partitions=cache_pos_for_partitions,
523
- kvcache_partition_size=kvcache_partition_size,
524
- **kwargs,
525
529
  )
526
- past_key_value.assign(k, v, layer_idx)
527
-
528
530
  hidden_states = residual + hidden_states
529
531
 
530
532
  # Fully Connected
531
533
  residual = hidden_states
532
- hidden_states = self.post_attention_layernorm(hidden_states)
533
- hidden_states = self.mlp(hidden_states)
534
+ hidden_states = self.get_post_attention_layernorm()(hidden_states)
535
+ hidden_states = self._original_mod.mlp(hidden_states)
534
536
  hidden_states = residual + hidden_states
535
537
 
536
- outputs = (hidden_states,)
538
+ return hidden_states, present_key_values
539
+
540
+
541
+ class DecoderOnlyAttention(nn.Module):
542
+ """Attention implementation for decoder-only models optimized for RBLN compilation.
543
+
544
+ This class implements a modified version of the standard attention mechanism that:
545
+ 1. Supports static shape requirements for RBLN compilation
546
+ 2. Handles explicit batch and position management
547
+
548
+ Args:
549
+ self_attn: Original attention module from the base model
550
+ """
551
+
552
+ def __init__(self, self_attn):
553
+ super().__init__()
554
+ self._original_mod = self_attn
555
+ self.layer_idx = self_attn.layer_idx
556
+ self.num_heads = self._original_mod.num_heads
557
+ self.head_dim = self._original_mod.head_dim
558
+ self._phase = "prefill"
559
+ self.scale = torch.tensor(self.get_attn_scale())
560
+
561
+ if hasattr(self._original_mod, "num_key_value_heads"):
562
+ self.num_key_value_heads = self._original_mod.num_key_value_heads
563
+ else:
564
+ self.num_key_value_heads = self._original_mod.num_heads
565
+
566
+ self.attention = self.get_attention()
567
+ self.__post_init__()
568
+
569
+ @property
570
+ def phase(self):
571
+ return self._phase
572
+
573
+ @phase.setter
574
+ def phase(self, phase: str):
575
+ self._phase = phase
576
+ self.attention.phase = phase
537
577
 
538
- if output_attentions:
539
- outputs += (self_attn_weight,)
578
+ def get_attention(self):
579
+ return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)
540
580
 
541
- if use_cache:
542
- outputs += (past_key_value,)
581
+ def __post_init__(self):
582
+ self.q_proj = self._original_mod.q_proj
583
+ self.k_proj = self._original_mod.k_proj
584
+ self.v_proj = self._original_mod.v_proj
585
+ self.o_proj = self._original_mod.o_proj
543
586
 
544
- return outputs
587
+ def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
588
+ """Projects input hidden states into query, key, and value representations.
545
589
 
590
+ Args:
591
+ hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
592
+
593
+ Returns:
594
+ Tuple of (query_states, key_states, value_states)
595
+ """
596
+ query_states = self.q_proj(hidden_states)
597
+ key_states = self.k_proj(hidden_states)
598
+ value_states = self.v_proj(hidden_states)
599
+ return query_states, key_states, value_states
600
+
601
+ def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
602
+ return apply_rotary_pos_emb(query_states, key_states, cos, sin)
603
+
604
+ def get_attn_scale(self):
605
+ return 1 / math.sqrt(self.head_dim)
546
606
 
547
- class DecoderOnlyModel:
548
607
  def forward(
549
608
  self,
550
- input_ids: torch.LongTensor = None,
551
- attention_mask: Optional[torch.Tensor] = None,
552
- position_ids: Optional[torch.LongTensor] = None,
553
- past_key_values: Optional[RebelDynamicCache] = None,
554
- batch_ids: Optional[torch.Tensor] = None,
555
- inputs_embeds: Optional[torch.FloatTensor] = None,
556
- use_cache: Optional[bool] = True,
557
- output_attentions: Optional[bool] = False,
558
- output_hidden_states: Optional[bool] = False,
559
- cache_pos_for_partitions: Optional[torch.Tensor] = None,
560
- kvcache_partition_size: Optional[torch.Tensor] = None,
561
- forward_dict: Optional[Dict[str, classmethod]] = None,
562
- rotary_pos_emb=None,
563
- ) -> BaseModelOutputWithPast:
564
- # retrieve input_ids and inputs_embeds
565
- if (input_ids is None) ^ (inputs_embeds is not None):
566
- raise ValueError(
567
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
568
- )
609
+ hidden_states: torch.Tensor,
610
+ attention_mask: torch.Tensor,
611
+ seq_positions: torch.LongTensor,
612
+ batch_position: torch.Tensor,
613
+ past_key_values: Tuple[Tuple[torch.Tensor]],
614
+ cos: Optional[torch.Tensor] = None,
615
+ sin: Optional[torch.Tensor] = None,
616
+ ):
617
+ batch_size, query_length, _ = hidden_states.size()
569
618
 
570
- # embed positions
571
- if inputs_embeds is None:
572
- inputs_embeds = self.embed_tokens(input_ids)
573
-
574
- hidden_states = inputs_embeds
575
- attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
576
-
577
- # get cos,sin vector
578
- cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
579
- cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
580
-
581
- # decoder layers
582
- all_hidden_states = () if output_hidden_states else None
583
- all_self_attns = () if output_attentions else None
584
-
585
- for layer_idx, decoder_layer in enumerate(self.layers):
586
- if output_hidden_states:
587
- all_hidden_states += (hidden_states,)
588
- layer_outputs = forward_dict["model"](
589
- decoder_layer,
590
- hidden_states,
591
- layer_idx,
592
- attention_mask=attention_mask,
593
- position_ids=position_ids,
594
- past_key_value=past_key_values,
595
- output_attentions=output_attentions,
596
- use_cache=use_cache,
597
- batch_ids=batch_ids,
598
- cos=cos,
599
- sin=sin,
600
- cache_pos_for_partitions=cache_pos_for_partitions,
601
- kvcache_partition_size=kvcache_partition_size,
602
- forward_dict=forward_dict,
619
+ query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
620
+
621
+ query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
622
+ key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
623
+ value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
624
+ 1, 2
625
+ )
626
+ # b, num_head, query, head_dim
627
+
628
+ if cos is not None and sin is not None:
629
+ query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
630
+
631
+ if batch_size > 1 and self.phase == "prefill":
632
+ raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
633
+
634
+ # TODO(jongho): flash attn legacy. (clone)
635
+ _seq_positions = seq_positions.clone().unsqueeze(1)
636
+
637
+ _key_states = []
638
+ _value_states = []
639
+ _attn_outputs = []
640
+ for b in range(batch_size):
641
+ seq_position = _seq_positions[b][0]
642
+ attn_output, key_state, value_state = self.attention(
643
+ query_states[b].unsqueeze(0),
644
+ key_states[b].unsqueeze(0),
645
+ value_states[b].unsqueeze(0),
646
+ attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
647
+ past_key_state=past_key_values[self.layer_idx][0],
648
+ past_value_state=past_key_values[self.layer_idx][1],
649
+ batch_position=b if self.phase == "decode" else batch_position,
650
+ seq_position=seq_position,
651
+ scale=self.scale,
603
652
  )
653
+ _key_states.append(key_state)
654
+ _value_states.append(value_state)
655
+ _attn_outputs.append(attn_output)
656
+ key_states = torch.cat(_key_states, dim=0)
657
+ value_states = torch.cat(_value_states, dim=0)
658
+ attn_outputs = torch.cat(_attn_outputs, dim=0)
604
659
 
605
- hidden_states = layer_outputs[0]
660
+ attn_outputs = self.o_proj(attn_outputs)
661
+ past_key_values[self.layer_idx] = key_states, value_states
662
+ return attn_outputs, past_key_values
606
663
 
607
- updated_cache = layer_outputs[2 if output_attentions else 1]
608
664
 
609
- if output_attentions:
610
- all_self_attns += (layer_outputs[1],)
665
+ class AttentionOp(nn.Module):
666
+ def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
667
+ super().__init__()
668
+ self.num_heads = num_heads
669
+ self.head_dim = head_dim
670
+ self.num_key_value_heads = num_key_value_heads
671
+ self.phase = "prefill"
611
672
 
612
- hidden_states = self.norm(hidden_states)
673
+ def forward(
674
+ self,
675
+ query_state: torch.Tensor,
676
+ key_state: torch.Tensor,
677
+ value_state: torch.Tensor,
678
+ attn_mask: torch.Tensor,
679
+ batch_position: torch.Tensor,
680
+ past_key_state: torch.Tensor,
681
+ past_value_state: torch.Tensor,
682
+ seq_position: torch.Tensor,
683
+ scale: torch.Tensor,
684
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
685
+ """Compute attention with static shapes and explicit cache management.
686
+
687
+ Args:
688
+ query_state: Query tensor [1, num_heads, 1, head_dim]
689
+ key_state: Key tensor [1, num_heads, seq_len, head_dim]
690
+ value_state: Value tensor [1, num_heads, seq_len, head_dim]
691
+ attn_mask: Attention mask tensor ∈ {0, 1}
692
+ batch_position: Batch index for cache lookup
693
+ past_key_state: Previous key cache states
694
+ past_value_state: Previous value cache states
695
+ seq_position: Current position in sequence
696
+ scale: Scale applied to attn weights
697
+
698
+ Returns:
699
+ Tuple of (attention_output, key_state, value_state)
700
+ """
701
+ # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
702
+ key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
703
+ value_state = value_state.unsqueeze(2)
704
+ attn_mask = attn_mask.unsqueeze(2)
613
705
 
614
- # add hidden states from the last decoder layer
615
- if output_hidden_states:
616
- all_hidden_states += (hidden_states,)
706
+ query_state = query_state.view(
707
+ 1,
708
+ self.num_key_value_heads,
709
+ self.num_heads // self.num_key_value_heads,
710
+ -1, # seq len
711
+ self.head_dim,
712
+ )
617
713
 
618
- # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
619
- next_cache = updated_cache.to_legacy_cache()
714
+ if self.phase == "decode":
715
+ attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_decode(
716
+ query_state,
717
+ key_state,
718
+ value_state,
719
+ attn_mask,
720
+ past_key_state.unsqueeze(2),
721
+ past_value_state.unsqueeze(2),
722
+ seq_position,
723
+ scale,
724
+ )
620
725
 
621
- return BaseModelOutputWithPast(
622
- last_hidden_state=hidden_states,
623
- past_key_values=next_cache,
624
- hidden_states=all_hidden_states,
625
- attentions=all_self_attns,
626
- )
726
+ else:
727
+ attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_prefill(
728
+ query_state,
729
+ key_state,
730
+ value_state,
731
+ attn_mask,
732
+ past_key_state.unsqueeze(2),
733
+ past_value_state.unsqueeze(2),
734
+ batch_position,
735
+ seq_position,
736
+ scale,
737
+ )
627
738
 
739
+ attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
740
+ attn_output = attn_output.transpose(1, 2).contiguous()
741
+ attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
742
+
743
+ return attn_output, key_state.squeeze(2), value_state.squeeze(2)
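The attn_decode / attn_prefill custom ops are presumably registered by the helpers imported from optimum.rbln.ops at the top of this file; their semantics mirror the CPU reference implementations removed earlier in this diff (scatter the new key/value into the cache, then scaled-dot-product attention). A condensed reference sketch of the decode case for cache slot 0, using a hypothetical helper name and the explicit `scale` argument that the new op takes instead of the old hard-coded 1/sqrt(128):

    import torch
    from torch import nn

    def attn_decode_reference(q, k, v, mask, kcache, vcache, seq_position, scale):
        """Sketch of the decode-attention semantics (not the real custom op, which
        also returns key/value tensors that the caller writes back into the cache)."""
        s = int(seq_position)                      # write offset into the cache
        e = s + k.size(-2)
        k_upd = kcache[0].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
        v_upd = vcache[0].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
        attn = torch.matmul(q, k_upd.transpose(-1, -2)) * scale + mask
        attn = nn.functional.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        return torch.matmul(attn, v_upd)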
628
744
 
-def slice_and_unsqueeze_cos_sin(cos, sin, position_ids, unsqueeze_dim=1):
-    """Slice cos[position_ids], sin[position_ids] vector for the query."""
-    if position_ids.shape[0] > 1:
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
         cos_all = []
         sin_all = []
-        for i in range(position_ids.shape[0]):
-            cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
         cos = torch.cat(cos_all, dim=0)
         sin = torch.cat(sin_all, dim=0)
     else:
-        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)

     return cos, sin
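A small usage sketch of slice_and_unsqueeze_cos_sin under assumed shapes (a per-position cos/sin table plus one cache_position row per batch element); the sizes below are made up for illustration:

import torch

max_seq, rotary_dim = 16, 8
cos_table = torch.randn(max_seq, rotary_dim)   # cos values cached per position
sin_table = torch.randn(max_seq, rotary_dim)
cache_position = torch.tensor([[3], [7]])      # current decoding position of each batch element

cos, sin = slice_and_unsqueeze_cos_sin(cos_table, sin_table, cache_position)
# Each batch element gets its own slice, with a broadcast head axis inserted at unsqueeze_dim=1:
print(cos.shape)  # torch.Size([2, 1, 1, 8])  ->  [batch, 1, q_len, rotary_dim]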
 
@@ -658,6 +775,26 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed


+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states
+
+
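apply_rotary_pos_emb_partial rotates only the first ndim channels of each head and passes the remaining channels through untouched. The sketch below reproduces that behaviour end to end; apply_rotary_pos_emb_ref is an assumed, standard rotate-half formulation defined here only so the example is self-contained, standing in for the file's apply_rotary_pos_emb:

import torch

def rotate_half(x):
    # swap and negate the two halves of the rotary dimensions (standard RoPE helper)
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_ref(q, k, cos, sin):
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

batch, heads, seq, head_dim, ndim = 1, 2, 4, 8, 4   # rotate only the first 4 of 8 channels
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos = torch.randn(batch, 1, seq, ndim)
sin = torch.randn(batch, 1, seq, ndim)

q_rot, q_pass = q[..., :ndim], q[..., ndim:]
k_rot, k_pass = k[..., :ndim], k[..., ndim:]
q_rot, k_rot = apply_rotary_pos_emb_ref(q_rot, k_rot, cos, sin)
q_out = torch.cat((q_rot, q_pass), dim=-1)
k_out = torch.cat((k_rot, k_pass), dim=-1)

# channels beyond ndim are unchanged by the partial rotary embedding
assert torch.equal(q_out[..., ndim:], q[..., ndim:])
assert torch.equal(k_out[..., ndim:], k[..., ndim:])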
 class RotaryEmbedding(nn.Module):
     """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

@@ -674,14 +811,14 @@ class RotaryEmbedding(nn.Module):
             rope_type = "default"

         inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-        position_ids = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
-        position_ids_expanded = position_ids[:, None]
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]

         if rope_type == "dynamic":
-            freqs = position_ids_expanded.float() * inv_freq.float()
+            freqs = cache_position_expanded.float() * inv_freq.float()
         else:
             inv_freq_expanded = inv_freq[None, :]
-            freqs = position_ids_expanded.float() @ inv_freq_expanded.float()
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()

         emb = torch.cat((freqs, freqs), dim=-1)

@@ -696,3 +833,127 @@ class RotaryEmbedding(nn.Module):
             self._cos_cached[:seq_len].to(dtype=x.dtype),
             self._sin_cached[:seq_len].to(dtype=x.dtype),
         )
+
+
+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, self_attn, kvcache_partition_len):
+        self.kvcache_partition_size = kvcache_partition_len
+        super().__init__(self_attn=self_attn)
+
+    def get_attention(self):
+        return FlashAttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.kvcache_partition_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
+        )
+        # b, num_head, query, head_dim
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            seq_position = seq_positions[b][0]  # FIXME: Remove take-take pattern matching
+            attn_output, key_state, value_state = self.attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_position=b if self.phase == "decode" else batch_position,
+                seq_position=seq_position,
+                scale=self.scale,
+            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)
+
+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values
+
+
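For orientation, the view/transpose calls at the top of DecoderOnlyFlashAttention.forward perform the usual [batch, seq, hidden] to [batch, heads, seq, head_dim] head split that the "# b, num_head, query, head_dim" comment refers to. A tiny sketch with made-up sizes (in the real class the head counts come from the wrapped self_attn module):

import torch

batch, q_len = 2, 5
num_heads, num_kv_heads, head_dim = 8, 2, 16

query_states = torch.randn(batch, q_len, num_heads * head_dim)
key_states = torch.randn(batch, q_len, num_kv_heads * head_dim)

query_states = query_states.view(batch, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(batch, q_len, num_kv_heads, head_dim).transpose(1, 2)
print(query_states.shape, key_states.shape)  # torch.Size([2, 8, 5, 16]) torch.Size([2, 2, 5, 16])

The per-sample loop that follows appears to exist because each sequence in the batch sits at its own seq_position in the static cache, so every batch element is attended against the cache independently and the results are concatenated back together afterwards.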
+class FlashAttentionOp(AttentionOp):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, kvcache_partition_len: int):
+        super().__init__(num_heads=num_heads, head_dim=head_dim, num_key_value_heads=num_key_value_heads)
+        self.kvcache_partition_size = kvcache_partition_len
+
+    def forward(
+        self,
+        query_state,
+        key_state,
+        value_state,
+        attn_mask,
+        batch_position,
+        past_key_state,
+        past_value_state,
+        seq_position,
+        scale,
+    ):
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                seq_position,
+                scale,
+                self.kvcache_partition_size,
+            )
+        else:
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                batch_position,
+                seq_position,
+                scale,
+                self.kvcache_partition_size,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+        return attn_output, key_state, value_state
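FlashAttentionOp mirrors AttentionOp but additionally passes self.kvcache_partition_size to the flash_attn_decode / flash_attn_prefill custom ops, which appears to let the kernel process the KV cache in fixed-size partitions rather than in one pass. As a conceptual illustration only, not the RBLN kernel, the sketch below shows why per-partition processing with a running maximum and denominator (the online-softmax trick used by flash attention) gives the same result, up to floating-point rounding, as a single softmax over the full cache:

import torch

def partitioned_attention(q, k, v, partition_len):
    # q: [heads, 1, head_dim]; k, v: [heads, seq, head_dim]
    heads, seq, head_dim = k.shape
    scale = head_dim ** -0.5
    running_max = torch.full((heads, 1, 1), float("-inf"))  # max score seen so far
    denom = torch.zeros(heads, 1, 1)                        # running softmax denominator
    acc = torch.zeros(heads, 1, head_dim)                   # running weighted sum of values
    for start in range(0, seq, partition_len):
        k_part = k[:, start : start + partition_len]
        v_part = v[:, start : start + partition_len]
        scores = torch.matmul(q, k_part.transpose(-1, -2)) * scale       # [heads, 1, part]
        new_max = torch.maximum(running_max, scores.max(dim=-1, keepdim=True).values)
        probs = torch.exp(scores - new_max)
        correction = torch.exp(running_max - new_max)   # rescale contributions of earlier partitions
        denom = denom * correction + probs.sum(dim=-1, keepdim=True)
        acc = acc * correction + torch.matmul(probs, v_part)
        running_max = new_max
    return acc / denom

heads, seq, head_dim = 2, 128, 16
q, k, v = torch.randn(heads, 1, head_dim), torch.randn(heads, seq, head_dim), torch.randn(heads, seq, head_dim)
full = torch.softmax(torch.matmul(q, k.transpose(-1, -2)) * head_dim ** -0.5, dim=-1) @ v
assert torch.allclose(partitioned_attention(q, k, v, partition_len=32), full, atol=1e-5)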