optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +8 -1
- optimum/rbln/utils/logging.py +38 -1
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/METADATA +0 -106
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
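
The hunks below correspond to optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (+329 -328 in the list above). They drop the inline torch.library custom-op definitions in favor of registration helpers from the new optimum/rbln/ops package. As a minimal sketch (not part of the diff), the registration that replaces the removed inline definitions looks roughly like this, assuming the two helpers are re-exported from the package root as the relative import in the first hunk implies:

# Illustrative sketch only (not part of the diff): custom-op registration now
# lives in the new optimum/rbln/ops package and is invoked on demand.
from optimum.rbln.ops import register_rbln_custom_attention, register_rbln_custom_flash_attention

register_rbln_custom_attention()        # presumably defines torch.ops.rbln_custom_ops.attn_decode / attn_prefill
register_rbln_custom_flash_attention()  # presumably defines torch.ops.rbln_custom_ops.flash_attn_decode / flash_attn_prefill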
@@ -27,129 +27,82 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 
+from ....ops import register_rbln_custom_attention, register_rbln_custom_flash_attention
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 
 
-if is_torch_greater_or_equal_than_2_4:
-    register_fake = torch.library.register_fake
-else:
-    register_fake = torch.library.impl_abstract
-
-
 logger = logging.get_logger(__name__)
-"""
-##############################################################################
-# RBLN custom operation (python interface)
-# torch.compile custom operation
-# torch.library.define - kernel declaration
-# torch.library.impl - kernel implementation
-# torch.library.impl_abstract - symbolic trace
-##############################################################################
-"""
-
-# RBLN custom op(flash attention decode)
-torch.library.define(
-    "rbln_custom_ops::flash_attn_decode",
-    "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
-)
-
-
-@torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
-def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
-    """
-    WORKAROUND:
-    Partition is declared as an argument to the function, even though it is
-    not actually used in the CPU implementation, this allows the rbln compiler
-    to perform flash attention operations with partition as an argument.
-    """
-    assert kcache.dim() == k.dim()
-    assert vcache.dim() == v.dim()
-    assert k.size(-2) == v.size(-2)
-    assert partition.dim() == 1
-    b = 0
-    if seq.dim() == 1:
-        s = seq[0]
-    elif seq.dim() == 0:
-        s = seq
-    else:
-        assert False
-    e = s + k.size(-2)
-    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
-    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
-    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
-    attn_weight = attn_weight + mask
-    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
-    attn_output = torch.matmul(attn_weight, updated_v)
-    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
-
-
-@register_fake("rbln_custom_ops::flash_attn_decode")
-def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
-    return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
-
-
-# RBLN custom op(flash attention prefill)
-torch.library.define(
-    "rbln_custom_ops::flash_attn_prefill",
-    "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
-)
-
-
-@torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
-def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition):
-    """
-    WORKAROUND:
-    Partition is declared as an argument to the function, even though it is
-    not actually used in the CPU implementation, this allows the rbln compiler
-    to perform flash attention operations with partition as an argument.
-    """
-    assert kcache.dim() == k.dim()
-    assert vcache.dim() == v.dim()
-    assert k.size(-2) == v.size(-2)
-    assert partition.dim() == 1
-    if batch.dim() == 1:
-        b = batch[0]
-    elif batch.dim() == 0:
-        b = batch
-    else:
-        assert False
-    if seq.dim() == 1:
-        s = seq[0]
-    elif seq.dim() == 0:
-        s = seq
-    else:
-        assert False
-    e = s + k.size(-2)
-    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
-    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
-    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
-    attn_weight = attn_weight + mask
-    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
-    attn_output = torch.matmul(attn_weight, updated_v)
-    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
-
-
-@register_fake("rbln_custom_ops::flash_attn_prefill")
-def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
-    return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
 
+DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
+DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
+MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
+MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
+MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
 
-# RBLN custom op(cache update)
-torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
 
+def validate_attention_method(
+    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
+) -> Tuple[str, int]:
+    if rbln_kvcache_partition_len is not None:
+        if rbln_attn_impl == "eager":
+            raise ValueError(
+                f"`rbln_kvcache_partition_len` is set to {rbln_kvcache_partition_len}, but KV cache partitioning"
+                " is not supported with 'eager' attention. Please set `rbln_kvcache_partition_len` to None, "
+                "or switch `rbln_attn_impl` to 'flash_attn' to use KV cache partitioning."
+            )
+        elif rbln_attn_impl is None:
+            rbln_attn_impl = "flash_attn"
+            logger.warning(
+                "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
+                "Since KV cache partitioning is only supported with flash attention, "
+                "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
+            )
 
-
-
-
-
+    rbln_attn_impl = "eager" if rbln_attn_impl is None else rbln_attn_impl
+    if rbln_attn_impl not in ["eager", "flash_attn"]:
+        raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
+
+    if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
+        rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+    ## Checking Constraints...
+    # Constraint of eager attention:
+    # - `max_seq_len` <= 32k
+
+    # Constraints of flash attention:
+    # 1. `max_seq_len` should be multiple of `partition_len`.
+    # 2. 4k <= `partition_len` <= 32k.
+    # 3. `max_seq_len` should be larger then 8k.
+    if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+        raise ValueError(
+            f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
+            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
+            f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+            " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
+        )
 
+    if rbln_attn_impl == "flash_attn":
+        if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
+            raise ValueError(
+                f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
+                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
+            )
+        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+            raise ValueError(
+                f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+                f"Please provide a valid value within this range."
+            )
+        elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+            raise ValueError(
+                f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
+                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
+                "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
+            )
 
-
-def rbln_cache_update_abstract(cache, value, batch, seq):
-    return torch.empty_like(cache)
+    return rbln_attn_impl, rbln_kvcache_partition_len
 
 
 class DecoderOnlyWrapper(nn.Module):
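
For reference, a small usage sketch (not part of the diff) of the validate_attention_method helper added above, showing how the defaults and constraints interact; it assumes the helper is imported from the module path given in the file list:

# Illustrative only: exercises the constraint logic introduced in the hunk above.
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import validate_attention_method

# With flash attention and no explicit partition length, the default 16_384 is chosen;
# 32_768 is accepted because it is a multiple of the partition length and >= 8_192.
impl, partition_len = validate_attention_method(
    rbln_attn_impl="flash_attn", rbln_kvcache_partition_len=None, rbln_max_seq_len=32_768
)
assert (impl, partition_len) == ("flash_attn", 16_384)

# Eager attention is capped at 32_768; longer sequences raise a ValueError.
try:
    validate_attention_method(rbln_attn_impl="eager", rbln_kvcache_partition_len=None, rbln_max_seq_len=65_536)
except ValueError as err:
    print(err)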
@@ -169,11 +122,23 @@ class DecoderOnlyWrapper(nn.Module):
         causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
         max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
         use_rotary_emb (bool): Whether to use rotary position embeddings
+        attn_impl (str): The attention implementation to use.
+            - "eager": Uses the standard attention.
+            - "flash_attn": Uses flash attention. When set,
+              the key/value cache is partitioned into chunks of length
+              `kvcache_partition_len`.
         kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
-
+            This is only relevant if `attn_impl` is set to "flash_attn`
     """
 
-    def __init__(
+    def __init__(
+        self,
+        causal_lm: PreTrainedModel,
+        max_seq_len: int,
+        use_rotary_emb: bool,
+        attn_impl: str,
+        kvcache_partition_len: Optional[int] = None,
+    ):
         super().__init__()
         self.config = causal_lm.config
 
@@ -182,14 +147,21 @@ class DecoderOnlyWrapper(nn.Module):
         else:
             self.rotary_emb = None
 
-
-
-
-
-
+        self.attn_impl = attn_impl
+        if self.attn_impl == "flash_attn":
+            self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+            register_rbln_custom_flash_attention()
+        elif self.attn_impl == "eager":
+            self.kvcache_partition_len = None
+            register_rbln_custom_attention()
         else:
-            self.attn_impl
-
+            raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
+
+        if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+            raise ValueError(
+                f"kvcache_partition_len({kvcache_partition_len}) should be lower"
+                f" or equal to max_seq_len({max_seq_len})!"
+            )
 
         self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
 
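
A hedged sketch (not part of the diff) of how the wrapper is now constructed with the two new arguments; the checkpoint name is a placeholder:

# Illustrative only: the new constructor arguments shown in the two hunks above.
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("some/causal-lm-checkpoint")  # placeholder checkpoint
wrapper = DecoderOnlyWrapper(
    causal_lm=hf_model,
    max_seq_len=32_768,
    use_rotary_emb=True,
    attn_impl="flash_attn",        # registers the RBLN flash-attention custom ops
    kvcache_partition_len=16_384,  # must not exceed max_seq_len; defaults to 16_384 when omitted
)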
@@ -213,12 +185,12 @@ class DecoderOnlyWrapper(nn.Module):
 
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = DecoderOnlyModel(causal_lm.model, new_layers)
+        new_model = DecoderOnlyModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
     @property
-    def phase(self):
+    def phase(self) -> str:
         return self._phase
 
     @phase.setter
@@ -226,21 +198,32 @@ class DecoderOnlyWrapper(nn.Module):
         self._phase = phase
         self.causal_lm.phase = phase
 
-    def forward(
-        self
-
-
-
-
-
-
-
+    def forward(self, *args):
+        if self.phase == "decode":
+            (
+                input_ids_or_inputs_embeds,
+                attention_mask,
+                cache_position,
+                *past_key_values,
+            ) = args
+            batch_position = torch.tensor(0, dtype=torch.int16)
+            query_position = None
+        elif self.phase == "prefill":
+            (
+                input_ids_or_inputs_embeds,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_position,
+                *past_key_values,
+            ) = args
+        else:
+            raise ValueError(f"Unknown phase: {self.phase}")
+
         if input_ids_or_inputs_embeds.ndim == 2:
-            # It is input_ids
             input_ids = input_ids_or_inputs_embeds
             inputs_embeds = None
         elif input_ids_or_inputs_embeds.ndim == 3:
-            # It is inputs_embeds
            input_ids = None
            inputs_embeds = input_ids_or_inputs_embeds
         else:
@@ -248,15 +231,9 @@ class DecoderOnlyWrapper(nn.Module):
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
             raise ValueError(
-                f"Different past_key_values to model's config. {len(past_key_values)} != {self.num_hidden_layers}"
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
             )
 
-        seq_len = input_ids_or_inputs_embeds.shape[1]
-        if seq_len == 1:
-            self.phase = "decode"
-        else:
-            self.phase = "prefill"
-
         # [key, value] * n_layer -> ( (key, value) ) * n_layer
         # cache shape : batch, n_heads, 1, max_seq_len, head_dim
         _past_key_values = []
@@ -286,8 +263,7 @@ class DecoderOnlyWrapper(nn.Module):
             _present_key_values = _present_key_values + (key_states, value_states)
         present_key_values = _present_key_values
 
-
-        return logit, present_key_values, batch_position + query_position
+        return logit, present_key_values
 
 
 class DecoderOnlyForCausalLM(nn.Module):
@@ -371,13 +347,12 @@ class DecoderOnlyModel(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """
 
-
-
-    def __init__(self, model, layers: List["DecoderOnlyLayer"]):
+    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self._phase = "prefill"
+        self.partition_len = partition_len
 
     @property
     def phase(self):
@@ -389,10 +364,26 @@ class DecoderOnlyModel(nn.Module):
         for layer in self.layers:
             layer.phase = phase
 
+    @property
+    def attn_impl(self) -> str:
+        return "eager" if self.partition_len is None else "flash_attn"
+
     @property
     def hidden_multiplier(self):
         return 1
 
+    def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
+        if self.attn_impl != "flash_attn":
+            raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
+
+        partition_len = self.partition_len
+        num_partition = max_seq_len // partition_len
+
+        cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
+        pidx = torch.arange(num_partition)
+        cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
+        return cache_pos_for_partitions
+
     def get_last_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.norm
 
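
A worked sketch (not part of the diff) of the per-partition cache positions that convert_sequence_positions_for_flash_attn produces, reproducing the arithmetic of the method added above:

# Illustrative only: mirrors the clamp arithmetic from the method added above.
import torch

seq_positions = torch.tensor([20_000])        # current cache fill position for one batch entry
max_seq_len, partition_len = 32_768, 16_384
num_partition = max_seq_len // partition_len  # the KV cache is split into 2 partitions

cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
pidx = torch.arange(num_partition)
print(torch.clamp(cs - pidx * partition_len, 0, partition_len))
# tensor([[16384,  3616]]) -> partition 0 is completely filled, partition 1 holds 3,616 positions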
@@ -425,7 +416,6 @@ class DecoderOnlyModel(nn.Module):
             inputs_embeds = self.get_embedding()(input_ids)
 
         hidden_states = inputs_embeds * self.hidden_multiplier
-        attention_mask = (1 - attention_mask) * self.mask_fmin
 
         # get cos,sin vector if needed
         if rotary_emb is not None:
@@ -446,14 +436,19 @@ class DecoderOnlyModel(nn.Module):
             cos, sin = None, None
 
         # (batch, seq_len) -> (batch,)
-
+        seq_positions = cache_position[:, 0]
+        if self.attn_impl == "flash_attn":
+            max_seq_len = past_key_values[0][0].shape[-2]
+            seq_positions = self.convert_sequence_positions_for_flash_attn(
+                seq_positions=seq_positions, max_seq_len=max_seq_len
+            )
 
         present_key_values = past_key_values
         for layer in self.layers:
             hidden_states, present_key_values = layer(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
-
+                seq_positions=seq_positions,
                 batch_position=batch_position,
                 past_key_values=present_key_values,
                 cos=cos,
@@ -514,20 +509,19 @@ class DecoderOnlyLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
+        seq_positions: torch.LongTensor,
         batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
-
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
         hidden_states, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
+            seq_positions=seq_positions,
             batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
@@ -561,15 +555,34 @@ class DecoderOnlyAttention(nn.Module):
         self.layer_idx = self_attn.layer_idx
         self.num_heads = self._original_mod.num_heads
         self.head_dim = self._original_mod.head_dim
-        self.
+        self._phase = "prefill"
+        self.scale = torch.tensor(self.get_attn_scale())
+
+        if hasattr(self._original_mod, "num_key_value_heads"):
+            self.num_key_value_heads = self._original_mod.num_key_value_heads
+        else:
+            self.num_key_value_heads = self._original_mod.num_heads
+
+        self.attention = self.get_attention()
         self.__post_init__()
 
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.attention.phase = phase
+
+    def get_attention(self):
+        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)
+
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.o_proj
-        self.num_key_value_heads = self._original_mod.num_key_value_heads
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Projects input hidden states into query, key, and value representations.
@@ -588,97 +601,17 @@ class DecoderOnlyAttention(nn.Module):
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
         return apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-    def
-        self
-        query_state,
-        key_state,
-        value_state,
-        attn_mask,
-        batch_idx,
-        past_key_state,
-        past_value_state,
-        current_step,
-        # below are designed for Midm, GPT which requires to support scaling for attention weights
-        # TODO(jongho): Merge and manage scales generally
-        layer_idx=None,
-        scale_attn_weights: bool = None,
-        scale_attn_by_inverse_layer_idx: bool = None,
-        scale_qk_by_inverse_layer_idx: bool = None,
-    ):
-        """Compute attention with static shapes and explicit cache management.
-
-        Args:
-            query_state: Query tensor [1, num_heads, 1, head_dim]
-            key_state: Key tensor [1, num_heads, seq_len, head_dim]
-            value_state: Value tensor [1, num_heads, seq_len, head_dim]
-            attn_mask: Attention mask tensor
-            batch_idx: Batch index for cache lookup
-            past_key_state: Previous key cache states
-            past_value_state: Previous value cache states
-            current_step: Current position in sequence
-
-        Returns:
-            Tuple of (attention_output, key_state, value_state)
-        """
-        # Implementation details.
-        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-        key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
-        value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
-
-        query_state = query_state.view(
-            1,
-            self.num_key_value_heads,
-            self.num_heads // self.num_key_value_heads,
-            -1, # seq len
-            self.head_dim,
-        ) #
-
-        kend = current_step + key_state.shape[-2]
-        vend = current_step + value_state.shape[-2]
-
-        key_state = (
-            past_key_state[batch_idx]
-            .unsqueeze(0)
-            .unsqueeze(2)
-            .slice_scatter(key_state, dim=-2, start=current_step, end=kend)
-        )
-        value_state = (
-            past_value_state[batch_idx]
-            .unsqueeze(0)
-            .unsqueeze(2)
-            .slice_scatter(value_state, dim=-2, start=current_step, end=vend)
-        )
-
-        attn_weight = torch.matmul(query_state, key_state.transpose(3, 4))
-        attn_weight = attn_weight / math.sqrt(self.head_dim)
-
-        if layer_idx is not None and (scale_attn_by_inverse_layer_idx or scale_qk_by_inverse_layer_idx):
-            attn_weight = attn_weight / float(layer_idx + 1)
-
-        attn_weight += attn_mask
-
-        if layer_idx is not None and scale_qk_by_inverse_layer_idx:
-            attn_weight = attn_weight * float(layer_idx + 1)
-
-        attn_weight = nn.functional.softmax(attn_weight, dim=-1)
-
-        attn_output = torch.matmul(attn_weight, value_state)
-
-        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
-
-        return attn_output, key_state, value_state
+    def get_attn_scale(self):
+        return 1 / math.sqrt(self.head_dim)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
+        seq_positions: torch.LongTensor,
         batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
+        cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()
@@ -698,22 +631,24 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and self.phase == "prefill":
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
+        # TODO(jongho): flash attn legacy. (clone)
+        _seq_positions = seq_positions.clone().unsqueeze(1)
+
         _key_states = []
         _value_states = []
         _attn_outputs = []
         for b in range(batch_size):
-
-            attn_output, key_state, value_state = self.
+            seq_position = _seq_positions[b][0]
+            attn_output, key_state, value_state = self.attention(
                 query_states[b].unsqueeze(0),
                 key_states[b].unsqueeze(0),
                 value_states[b].unsqueeze(0),
-                attention_mask[b].unsqueeze(0)
-                if self.phase == "decode"
-                else attention_mask, # TODO(jongho): fix when msoftmax is supported
+                attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
                 past_key_state=past_key_values[self.layer_idx][0],
                 past_value_state=past_key_values[self.layer_idx][1],
-
-
+                batch_position=b if self.phase == "decode" else batch_position,
+                seq_position=seq_position,
+                scale=self.scale,
             )
             _key_states.append(key_state)
             _value_states.append(value_state)
@@ -727,6 +662,87 @@ class DecoderOnlyAttention(nn.Module):
         return attn_outputs, past_key_values
 
 
+class AttentionOp(nn.Module):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.phase = "prefill"
+
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        batch_position: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: torch.Tensor,
+        scale: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute attention with static shapes and explicit cache management.
+
+        Args:
+            query_state: Query tensor [1, num_heads, 1, head_dim]
+            key_state: Key tensor [1, num_heads, seq_len, head_dim]
+            value_state: Value tensor [1, num_heads, seq_len, head_dim]
+            attn_mask: Attention mask tensor ∈ {0, 1}
+            batch_position: Batch index for cache lookup
+            past_key_state: Previous key cache states
+            past_value_state: Previous value cache states
+            seq_position: Current position in sequence
+            scale: Scale applied to attn weights
+
+        Returns:
+            Tuple of (attention_output, key_state, value_state)
+        """
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1, # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_decode(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                seq_position,
+                scale,
+            )
+
+        else:
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_prefill(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                batch_position,
+                seq_position,
+                scale,
+            )
+
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+        return attn_output, key_state.squeeze(2), value_state.squeeze(2)
+
+
 def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
     """Slice cos[cache_position], sin[cache_position] vector for the query."""
     if cache_position.shape[0] > 1:
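
The eager path above delegates the cache update and softmax to torch.ops.rbln_custom_ops.attn_decode / attn_prefill. As a rough reference (not part of the diff), the pattern matches the CPU fallback removed in the first hunk: scatter the incoming key/value block into the static cache, then run scaled dot-product attention over the updated cache. A simplified plain-PyTorch sketch, with the grouped-head reshapes omitted:

# Illustrative only: simplified reference for the cache-update + attention pattern,
# following the CPU implementation removed at the top of this diff. Shapes and the
# grouped key/value-head handling are simplified, the mask is additive here as in the
# removed reference, and the compiled RBLN op may differ in detail.
import torch
from torch import nn

def reference_attention(q, k, v, mask, kcache, vcache, seq, scale, batch=0):
    # Write the incoming key/value block into the static cache at offset `seq`.
    end = seq + k.size(-2)
    k_all = kcache[batch].unsqueeze(0).slice_scatter(k, dim=-2, start=seq, end=end)
    v_all = vcache[batch].unsqueeze(0).slice_scatter(v, dim=-2, start=seq, end=end)
    # Scaled dot-product attention over the full (updated) cache.
    attn = torch.matmul(q, k_all.transpose(-2, -1)) * scale + mask
    attn = nn.functional.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(attn, v_all), k_all, v_all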
@@ -821,40 +837,83 @@ class RotaryEmbedding(nn.Module):
 
 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
     def __init__(self, self_attn, kvcache_partition_len):
+        self.kvcache_partition_size = kvcache_partition_len
         super().__init__(self_attn=self_attn)
-        self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def get_attention(self):
+        return FlashAttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.kvcache_partition_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
+        )
+        # b, num_head, query, head_dim
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            seq_position = seq_positions[b][0] # FIXME: Remove take-take pattern matching
+            attn_output, key_state, value_state = self.attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_position=b if self.phase == "decode" else batch_position,
+                seq_position=seq_position,
+                scale=self.scale,
+            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)
+
+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values
 
-        return cache_pos_for_partitions
 
-
+class FlashAttentionOp(AttentionOp):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, kvcache_partition_len: int):
+        super().__init__(num_heads=num_heads, head_dim=head_dim, num_key_value_heads=num_key_value_heads)
+        self.kvcache_partition_size = kvcache_partition_len
+
+    def forward(
         self,
         query_state,
         key_state,
         value_state,
         attn_mask,
-
+        batch_position,
         past_key_state,
         past_value_state,
-
+        seq_position,
+        scale,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-        key_state = key_state.unsqueeze(2)
+        key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
         attn_mask = attn_mask.unsqueeze(2)
 
@@ -866,9 +925,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             self.head_dim,
         )
 
-        # RBLN custom flash attention(decode), dummy batch index
         if self.phase == "decode":
-            sidx = cache_pos_for_partitions[batch_idx][0]
             attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
                 query_state,
                 key_state,
@@ -876,11 +933,11 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
                 attn_mask,
                 past_key_state.unsqueeze(2),
                 past_value_state.unsqueeze(2),
-
+                seq_position,
+                scale,
                 self.kvcache_partition_size,
             )
         else:
-            sidx = cache_pos_for_partitions[0][0]
             attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
                 query_state,
                 key_state,
@@ -888,8 +945,9 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
                 attn_mask,
                 past_key_state.unsqueeze(2),
                 past_value_state.unsqueeze(2),
-
-
+                batch_position,
+                seq_position,
+                scale,
                 self.kvcache_partition_size,
             )
 
@@ -899,60 +957,3 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
         attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
 
         return attn_output, key_state, value_state
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        current_steps: torch.LongTensor,
-        batch_position: torch.Tensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-        # b, num_head, query, head_dim
-
-        max_seq_len = past_key_values[self.layer_idx][0].shape[-2]
-
-        if cos is not None and sin is not None:
-            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        cache_pos_for_partitions = self.get_cache_pos_for_partitions(
-            current_steps, batch_size=batch_size, max_seq_len=max_seq_len
-        ) # batch_size, num_partitions
-
-        _key_states = []
-        _value_states = []
-        _attn_outputs = []
-        for b in range(batch_size):
-            attn_output, key_state, value_state = self.rbln_flash_attention(
-                query_states[b].unsqueeze(0),
-                key_states[b].unsqueeze(0),
-                value_states[b].unsqueeze(0),
-                attention_mask[b].unsqueeze(0)
-                if self.phase == "decode"
-                else attention_mask, # TODO(jongho): fix when msoftmax is supported
-                past_key_state=past_key_values[self.layer_idx][0],
-                past_value_state=past_key_values[self.layer_idx][1],
-                batch_idx=b if self.phase == "decode" else batch_position,
-                cache_pos_for_partitions=cache_pos_for_partitions,
-            )
-            _key_states.append(key_state)
-            _value_states.append(value_state)
-            _attn_outputs.append(attn_output)
-        key_states = torch.cat(_key_states, dim=0)
-        value_states = torch.cat(_value_states, dim=0)
-        attn_outputs = torch.cat(_attn_outputs, dim=0)
-
-        attn_outputs = self.o_proj(attn_outputs)
-        past_key_values[self.layer_idx] = key_states, value_states
-        return attn_outputs, past_key_values