optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of optimum-rbln as potentially problematic.
- optimum/rbln/__init__.py +58 -9
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +24 -5
- optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +4 -5
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
- optimum/rbln/diffusers/pipelines/__init__.py +1 -5
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +4 -5
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +60 -0
- optimum/rbln/transformers/configuration_generic.py +4 -4
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +1 -4
- optimum/rbln/transformers/models/__init__.py +45 -30
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
- optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/hub.py +8 -47
- optimum/rbln/utils/runtime_utils.py +31 -5
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
The hunks below are the rendered diff for optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (+323 -454).

@@ -20,108 +20,13 @@ from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

 from ....utils import logging
+from ...modeling_attention_utils import DEFAULT_FLASH_ATTN_PARTITION_LENGTH
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from .configuration_decoderonly import CacheImplType


 logger = logging.get_logger(__name__)

-DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
-DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
-MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
-MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
-MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
-MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
-def set_default_values(
-    attn_impl: Optional[str] = None,
-    kvcache_partition_len: Optional[int] = None,
-    kvcache_block_size: Optional[int] = None,
-    max_seq_len: Optional[int] = None,
-) -> Tuple[str, int, int]:
-    if attn_impl is None:
-        attn_impl = "eager"
-
-    if kvcache_partition_len is not None:
-        if attn_impl == "eager":
-            attn_impl = "flash_attn"
-            logger.warning(
-                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
-                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
-                "`attn_impl` has been automatically switched to 'flash_attn'."
-            )
-
-    if kvcache_partition_len is None and attn_impl == "flash_attn":
-        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
-    if kvcache_block_size is None:
-        if attn_impl == "eager":
-            kvcache_block_size = max_seq_len
-        else:
-            kvcache_block_size = kvcache_partition_len
-
-    return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
-def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
-    if attn_impl not in ["eager", "flash_attn"]:
-        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
-
-    ## Checking Constraints...
-    # Constraint of eager attention:
-    # - `max_seq_len` <= 32k
-
-    # Constraints of flash attention:
-    # 1. `max_seq_len` should be multiple of `partition_len`.
-    # 2. 4k <= `partition_len` <= 32k.
-    # 3. `max_seq_len` should be larger then 8k.
-    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
-        raise ValueError(
-            f"`max_seq_len` is set to {max_seq_len}, "
-            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
-        )
-
-    if attn_impl == "flash_attn":
-        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
-                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
-            )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
-            raise ValueError(
-                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
-                f"Please provide a valid value within this range."
-            )
-        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
-                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
-            )
-
-    if kvcache_block_size is not None:
-        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
-            raise ValueError(
-                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
-            )
-        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
-            raise ValueError(
-                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `max_seq_len` {max_seq_len}."
-            )
-
-
-def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
-    if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
-        raise ValueError(
-            f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
-        )
-

 class DecoderOnlyWrapper(nn.Module):
     """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
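Note: the constants and helpers deleted in the hunk above are not simply dropped; the new import of DEFAULT_FLASH_ATTN_PARTITION_LENGTH and the newly added optimum/rbln/transformers/modeling_attention_utils.py (+252 in the file list) suggest they were relocated rather than removed. As a reading aid only, the sketch below restates the default-selection logic from the deleted `set_default_values` (warning and nesting trimmed) so it can be run without installing the package:

```python
# Standalone restatement of the deleted set_default_values logic (illustrative only).
from typing import Optional, Tuple

DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384


def set_default_values(
    attn_impl: Optional[str] = None,
    kvcache_partition_len: Optional[int] = None,
    kvcache_block_size: Optional[int] = None,
    max_seq_len: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]:
    if attn_impl is None:
        attn_impl = "eager"
    if kvcache_partition_len is not None and attn_impl == "eager":
        attn_impl = "flash_attn"  # KV cache partitioning is only supported with flash attention
    if kvcache_partition_len is None and attn_impl == "flash_attn":
        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
    if kvcache_block_size is None:
        kvcache_block_size = max_seq_len if attn_impl == "eager" else kvcache_partition_len
    return attn_impl, kvcache_partition_len, kvcache_block_size


print(set_default_values(max_seq_len=8_192))                           # ('eager', None, 8192)
print(set_default_values(attn_impl="flash_attn", max_seq_len=32_768))  # ('flash_attn', 16384, 16384)
```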
@@ -149,9 +54,11 @@ class DecoderOnlyWrapper(nn.Module):
         This is only relevant if `attn_impl` is set to "flash_attn`
     """

+    _use_learned_pos_emb = False
+
     def __init__(
         self,
-
+        model: PreTrainedModel,
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
@@ -159,14 +66,14 @@ class DecoderOnlyWrapper(nn.Module):
         use_inputs_embeds: bool,
         use_attention_mask: bool,
         use_position_ids: bool,
-        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         sliding_window: Optional[int] = None,
         sliding_window_layers: Optional[List[int]] = None,
     ):
         super().__init__()
-        self.config =
+        self.config = model.config
+        self.is_causal_lm = getattr(model, "lm_head", None) is not None

         if use_rotary_emb:
             rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
@@ -182,9 +89,10 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.use_inputs_embeds = use_inputs_embeds
-        self.use_learned_pos_emb = use_learned_pos_emb
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
+        self.use_global_attention = cache_impl in ["static", "hybrid"]
+        self.use_local_attention = cache_impl in ["hybrid", "sliding_window"]
         self.sliding_window = sliding_window

         if self.attn_impl == "flash_attn":
@@ -200,59 +108,67 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )

-        self.
+        self.model = self.convert_to_rbln_class(model, max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"

     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

-    def
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.layers if self.is_causal_lm else model.layers
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.self_attn
+
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model if self.is_causal_lm else model
+
+    def get_rbln_attn_class(self):
+        return DecoderOnlyAttention
+
+    def get_rbln_layer_class(self):
+        return DecoderOnlyLayer
+
+    def get_rbln_model_class(self):
+        return DecoderOnlyModel
+
+    def get_rbln_causal_lm_class(self):
+        return DecoderOnlyForCausalLM
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    self.use_attention_mask,
-                    self.use_position_ids,
-                    kvcache_block_size=self.kvcache_block_size,
-                    is_sliding=False,
-                )
-            elif self.attn_impl == "flash_attn":
-                new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
-                    kvcache_partition_len=self.kvcache_partition_len,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_attention_mask=self.use_attention_mask,
-                    use_position_ids=self.use_position_ids,
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+            is_sliding = layer_idx in self.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.use_attention_mask if not is_sliding else True,
+                self.use_position_ids,
+                kvcache_block_size=self.sliding_window
+                if layer_idx in self.sliding_window_layers
+                else self.kvcache_block_size,
+                is_sliding=is_sliding,
+                attn_impl=self.attn_impl if not is_sliding else "eager",
+                kvcache_partition_len=self.kvcache_partition_len,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)

-        new_model =
-
+        new_model = self.get_rbln_model_class()(
+            self.get_model_layer(model),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
-            use_learned_pos_emb=self.
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
             sliding_window_layers=self.sliding_window_layers,
         )
-
-
+
+        if self.is_causal_lm:
+            new_model = self.get_rbln_causal_lm_class()(model, new_model)
+            return new_model
+        else:
+            return new_model

     @property
     def phase(self) -> str:
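The hunk above replaces the inline eager/flash branching with overridable hooks (get_decoder_layers, get_attn_layer, get_rbln_attn_class, and friends). A minimal sketch of how a model-specific wrapper might specialize one of those hooks; MyModelAttention and MyModelWrapper are hypothetical names used only for illustration:

```python
# Hypothetical subclass illustrating the hook-based extension point added above.
# DecoderOnlyWrapper / DecoderOnlyAttention are the classes from this module;
# MyModelAttention / MyModelWrapper are made-up names for the sketch.
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    DecoderOnlyAttention,
    DecoderOnlyWrapper,
)


class MyModelAttention(DecoderOnlyAttention):
    def __post_init__(self):
        super().__post_init__()
        # model-specific projection or norm handling would go here


class MyModelWrapper(DecoderOnlyWrapper):
    _use_learned_pos_emb = False  # class-level flag introduced by this refactor

    def get_rbln_attn_class(self):
        # convert_to_rbln_class() will now instantiate MyModelAttention for every
        # decoder layer instead of the default DecoderOnlyAttention.
        return MyModelAttention
```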
@@ -261,16 +177,21 @@ class DecoderOnlyWrapper(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.
+        self.model.phase = phase

     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.use_inputs_embeds else args.pop(0)
         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.
-        local_block_tables = args.pop(0) if self.
-        query_position =
+        global_block_tables = args.pop(0) if self.use_global_attention else None
+        local_block_tables = args.pop(0) if self.use_local_attention else None
+        query_position = (
+            args.pop(0)
+            # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+            if ("prefill" in self.phase and (self.is_causal_lm or self.use_local_attention))
+            else None
+        )
         attention_mask = args.pop(0) if self.use_attention_mask else None
         position_ids = args.pop(0) if self.use_position_ids else None
         past_key_values = args
@@ -322,7 +243,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb,
         ) = self.prepare_forward_args(*args)

-        logit = self.
+        logit = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -679,9 +600,23 @@ class DecoderOnlyAttention(nn.Module):

     Args:
         self_attn: Original attention module from the base model
+        use_attention_mask: Whether to use attention mask
+        use_position_ids: Whether to use position ids
+        kvcache_block_size: Block size for KV cache
+        is_sliding: Whether this is sliding window attention
+        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """

-    def __init__(
+    def __init__(
+        self,
+        self_attn,
+        use_attention_mask,
+        use_position_ids,
+        kvcache_block_size,
+        is_sliding=False,
+        attn_impl="eager",
+        kvcache_partition_len=None,
+    ):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -702,10 +637,24 @@ class DecoderOnlyAttention(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.is_sliding = is_sliding
-        self.
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+
+        setattr(self, self.get_attention_name(), self.create_attention_op())
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()

+    def get_attention_name(self):
+        if self.is_sliding:
+            return "sliding_window_attention"
+        elif self.attn_impl == "flash_attn":
+            return "flash_attention"
+        else:
+            return "attention"
+
+    def get_attention_op(self):
+        return getattr(self, self.get_attention_name())
+
     @property
     def phase(self):
         return self._phase
@@ -713,17 +662,36 @@ class DecoderOnlyAttention(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.
+        getattr(self, self.get_attention_name()).phase = phase

-    def
+    def create_attention_op(self):
         if self.is_sliding:
             return SlidingWindowAttentionOp(
-                self.num_heads,
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
-
+        elif self.attn_impl == "flash_attn":
+            return FlashAttentionOp(
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.kvcache_partition_len,
+                self.use_attention_mask,
+                self.use_position_ids,
+            )
+        elif self.attn_impl == "eager":
             return AttentionOp(
-                self.num_heads,
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
+        else:
+            raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -780,7 +748,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

-        attn_output = self.
+        attn_output = self.get_attention_op()(
             query_states,
             key_states,
             value_states,
@@ -797,6 +765,14 @@ class DecoderOnlyAttention(nn.Module):
         return attn_outputs


+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.warning(
+            "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
+        )
+
+
 class AttentionOp(nn.Module):
     def __init__(
         self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
@@ -809,6 +785,18 @@ class AttentionOp(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids

+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_attn_"
+        else:
+            attn_op_name = "paged_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
@@ -857,63 +845,31 @@ class AttentionOp(nn.Module):
             self.head_dim,
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                mask=attn_mask if self.use_position_ids else None,
-            )
-
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.use_attention_mask:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
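The hunk above collapses the hard-coded prefill/decode call sites into a single dispatch: the operator name is derived from the mask/position-id configuration and the current phase, then resolved on torch.ops.rbln_custom_ops. A small standalone sketch of just the name-building step (the demo class and sample values below are illustrative, not part of the package; the lookup on the RBLN runtime is shown only as a comment):

```python
# Illustrative sketch of the operator-name dispatch added in this hunk.
# It mirrors AttentionOp.get_attn_op_name(); no RBLN runtime is required to run it.

class AttnOpNameDemo:
    def __init__(self, use_attention_mask: bool, use_position_ids: bool, phase: str):
        self.use_attention_mask = use_attention_mask
        self.use_position_ids = use_position_ids
        self.phase = phase

    def get_attn_op_name(self) -> str:
        phase = "decode" if self.phase == "decode" else "prefill"
        if self.use_attention_mask and not self.use_position_ids:
            attn_op_name = "paged_attn_"
        else:
            attn_op_name = "paged_causal_attn_"
        return attn_op_name + phase


if __name__ == "__main__":
    # masked, no position ids, decode phase -> "paged_attn_decode"
    print(AttnOpNameDemo(True, False, "decode").get_attn_op_name())
    # causal variant during prefill -> "paged_causal_attn_prefill"
    print(AttnOpNameDemo(False, True, "prefill").get_attn_op_name())
    # at call time the module resolves the op with:
    #   attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
```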
@@ -921,161 +877,6 @@ class AttentionOp(nn.Module):
         return attn_output


-def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-    """Slice cos[cache_position], sin[cache_position] vector for the query."""
-    if cache_position.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(cache_position.shape[0]):
-            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-    return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., :ndim],
-        query_states[..., ndim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., :ndim],
-        key_states[..., ndim:],
-    )
-
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
-    return query_states, key_states
-
-
-class RotaryEmbedding(nn.Module):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        max_seq_len_cached: int,
-    ):
-        super().__init__()
-
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            rope_type = "default"
-
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
-        cache_position_expanded = cache_position[:, None]
-
-        if rope_type == "dynamic":
-            freqs = cache_position_expanded.float() * inv_freq.float()
-        else:
-            inv_freq_expanded = inv_freq[None, :]
-            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        cos = emb.cos() * attention_scaling
-        sin = emb.sin() * attention_scaling
-
-        self.register_buffer("_cos_cached", cos, persistent=False)
-        self.register_buffer("_sin_cached", sin, persistent=False)
-
-    def forward(self, x, seq_len):
-        return (
-            self._cos_cached[:seq_len].to(dtype=x.dtype),
-            self._sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
-class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
-        self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(
-            self_attn=self_attn,
-            use_attention_mask=use_attention_mask,
-            use_position_ids=use_position_ids,
-            kvcache_block_size=kvcache_block_size,
-        )
-
-    def get_attention(self):
-        return FlashAttentionOp(
-            self.num_heads,
-            self.head_dim,
-            self.num_key_value_heads,
-            self.kvcache_partition_size,
-            self.use_attention_mask,
-            self.use_position_ids,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
-        if cos is not None and sin is not None:
-            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            kvcache_block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1095,6 +896,17 @@ class FlashAttentionOp(AttentionOp):
         )
         self.kvcache_partition_size = kvcache_partition_len

+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_flash_attn_"
+        else:
+            attn_op_name = "paged_flash_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state,
@@ -1106,7 +918,7 @@ class FlashAttentionOp(AttentionOp):
         seq_position,
         scale,
         block_tables,
-
+        block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1127,67 +939,32 @@ class FlashAttentionOp(AttentionOp):
             self.head_dim,
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                partition=self.kvcache_partition_size,
-                mask=attn_mask if self.use_position_ids else None,
-            )
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+            "partition": self.kvcache_partition_size,
+        }
+
+        if self.use_attention_mask:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1196,6 +973,14 @@ class FlashAttentionOp(AttentionOp):


 class SlidingWindowAttentionOp(AttentionOp):
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if not self.use_attention_mask:
+            raise NotImplementedError("Attention mask is needed for sliding window attention.")
+
+        attn_op_name = "paged_sliding_window_attn_" + phase
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
@@ -1226,37 +1011,121 @@ class SlidingWindowAttentionOp(AttentionOp):
             self.head_dim,
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                scale=scale,
-                block_table=block_tables,
-                block_size=block_size,
-                is_bidirectional=True if self.phase == "image_prefill" else False,
-            )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "cache_seq_len": seq_position[0],
+            "cache_offset": seq_position[1],
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

         return attn_output
+
+
+class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
+    ):
+        super().__init__()
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
+
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]
+
+        if rope_type == "dynamic":
+            freqs = cache_position_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)
+
+    def forward(self, x, seq_len):
+        return (
+            self._cos_cached[:seq_len].to(dtype=x.dtype),
+            self._sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states