optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
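
The decoderonly_architecture.py hunks below replace the old torch.library custom-op plumbing with registered RBLN attention ops and add validate_attention_method, which enforces the attention constraints spelled out in the new code (eager: max_seq_len <= 32,768; flash_attn: 4,096 <= partition_len <= 32,768, max_seq_len at least 8,192 and a >=2x multiple of partition_len, default partition 16,384). The following is a minimal usage sketch, not part of the diff; it assumes optimum-rbln 0.2.0 is installed and imports the function from the module path shown in this listing.

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    validate_attention_method,
)

# No partition length given: falls back to 'eager'; max_seq_len may not exceed 32_768.
impl, partition_len = validate_attention_method(
    rbln_attn_impl=None, rbln_kvcache_partition_len=None, rbln_max_seq_len=4_096
)
# impl == "eager", partition_len is None

# Flash attention: max_seq_len must be a >= 2x multiple of the partition length (default 16_384).
impl, partition_len = validate_attention_method(
    rbln_attn_impl="flash_attn", rbln_kvcache_partition_len=16_384, rbln_max_seq_len=32_768
)
# impl == "flash_attn", partition_len == 16_384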
@@ -22,623 +22,740 @@
 # from Rebellions Inc.

 import math
-from typing import
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
-from transformers import PretrainedConfig
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from transformers import PretrainedConfig, PreTrainedModel

+from ....ops import register_rbln_custom_attention, register_rbln_custom_flash_attention
 from ....utils import logging
-from ...cache_utils import RebelDynamicCache
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS


 logger = logging.get_logger(__name__)
-"""
-##############################################################################
-# RBLN custom operation (python interface)
-# torch.compile custom operation
-# torch.library.define - kernel declaration
-# torch.library.impl - kernel implementation
-# torch.library.impl_abstract - symbolic trace
-##############################################################################
-"""
-
-# RBLN custom op(flash attention decode)
-torch.library.define(
-    "rbln_custom_ops::flash_attn_decode",
-    "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
-)
-
-
-@torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
-def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
-    """
-    WORKAROUND:
-    Partition is declared as an argument to the function, even though it is
-    not actually used in the CPU implementation, this allows the rbln compiler
-    to perform flash attention operations with partition as an argument.
-    """
-    assert kcache.dim() == k.dim()
-    assert vcache.dim() == v.dim()
-    assert k.size(-2) == v.size(-2)
-    assert partition.dim() == 1
-    b = 0
-    if seq.dim() == 1:
-        s = seq[0]
-    elif seq.dim() == 0:
-        s = seq
-    else:
-        assert False
-    e = s + k.size(-2)
-    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
-    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
-    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
-    attn_weight = attn_weight + mask
-    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
-    attn_output = torch.matmul(attn_weight, updated_v)
-    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)

+DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
+DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
+MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
+MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
+MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768

-@torch.library.impl_abstract("rbln_custom_ops::flash_attn_decode")
-def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
-    return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)

+def validate_attention_method(
+    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
+) -> Tuple[str, int]:
+    if rbln_kvcache_partition_len is not None:
+        if rbln_attn_impl == "eager":
+            raise ValueError(
+                f"`rbln_kvcache_partition_len` is set to {rbln_kvcache_partition_len}, but KV cache partitioning"
+                " is not supported with 'eager' attention. Please set `rbln_kvcache_partition_len` to None, "
+                "or switch `rbln_attn_impl` to 'flash_attn' to use KV cache partitioning."
+            )
+        elif rbln_attn_impl is None:
+            rbln_attn_impl = "flash_attn"
+            logger.warning(
+                "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
+                "Since KV cache partitioning is only supported with flash attention, "
+                "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
+            )

-
-
-
-
-
+    rbln_attn_impl = "eager" if rbln_attn_impl is None else rbln_attn_impl
+    if rbln_attn_impl not in ["eager", "flash_attn"]:
+        raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
+
+    if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
+        rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+    ## Checking Constraints...
+    # Constraint of eager attention:
+    # - `max_seq_len` <= 32k
+
+    # Constraints of flash attention:
+    # 1. `max_seq_len` should be multiple of `partition_len`.
+    # 2. 4k <= `partition_len` <= 32k.
+    # 3. `max_seq_len` should be larger then 8k.
+    if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+        raise ValueError(
+            f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
+            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
+            f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+            " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
+        )

+    if rbln_attn_impl == "flash_attn":
+        if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
+            raise ValueError(
+                f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
+                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
+            )
+        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+            raise ValueError(
+                f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+                f"Please provide a valid value within this range."
+            )
+        elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+            raise ValueError(
+                f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
+                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
+                "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
+            )

-
-
-
-
-
-
-
+    return rbln_attn_impl, rbln_kvcache_partition_len
+
+
+class DecoderOnlyWrapper(nn.Module):
+    """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
+
+    This wrapper is designed to:
+    1. Convert Huggingface decoder models for RBLN compilation with static shapes
+    2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
+    3. Manage different attention implementations (standard and flash attention)
+    4. Support both prefill and decode phases
+
+    Notes:
+    - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
+    - Wrapper should not contain neural network graph operations (including memory view handling)
+
+    Args:
+        causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
+        max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
+        use_rotary_emb (bool): Whether to use rotary position embeddings
+        attn_impl (str): The attention implementation to use.
+            - "eager": Uses the standard attention.
+            - "flash_attn": Uses flash attention. When set,
+              the key/value cache is partitioned into chunks of length
+              `kvcache_partition_len`.
+        kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
+            This is only relevant if `attn_impl` is set to "flash_attn`
     """
-    assert kcache.dim() == k.dim()
-    assert vcache.dim() == v.dim()
-    assert k.size(-2) == v.size(-2)
-    assert partition.dim() == 1
-    if batch.dim() == 1:
-        b = batch[0]
-    elif batch.dim() == 0:
-        b = batch
-    else:
-        assert False
-    if seq.dim() == 1:
-        s = seq[0]
-    elif seq.dim() == 0:
-        s = seq
-    else:
-        assert False
-    e = s + k.size(-2)
-    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
-    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
-    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
-    attn_weight = attn_weight + mask
-    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
-    attn_output = torch.matmul(attn_weight, updated_v)
-    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)

+    def __init__(
+        self,
+        causal_lm: PreTrainedModel,
+        max_seq_len: int,
+        use_rotary_emb: bool,
+        attn_impl: str,
+        kvcache_partition_len: Optional[int] = None,
+    ):
+        super().__init__()
+        self.config = causal_lm.config

-
-
-
-
+        if use_rotary_emb:
+            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+        else:
+            self.rotary_emb = None
+
+        self.attn_impl = attn_impl
+        if self.attn_impl == "flash_attn":
+            self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+            register_rbln_custom_flash_attention()
+        elif self.attn_impl == "eager":
+            self.kvcache_partition_len = None
+            register_rbln_custom_attention()
+        else:
+            raise ValueError(f"Unknown attn_impl : {self.attn_impl}")

-
-
+        if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+            raise ValueError(
+                f"kvcache_partition_len({kvcache_partition_len}) should be lower"
+                f" or equal to max_seq_len({max_seq_len})!"
+            )

+        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)

-
-
-    updated_cache = cache[batch].slice_scatter(value, dim=-2, start=batch[0], end=batch[0] + seq[0])
-    return updated_cache
+        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
+        self._phase = "prefill"

+    def get_rotary_emb(self, max_seq_len):
+        return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

-
-
-
+    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = DecoderOnlyModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+    @property
+    def phase(self) -> str:
+        return self._phase
+
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.causal_lm.phase = phase
+
+    def forward(self, *args):
+        if self.phase == "decode":
+            (
+                input_ids_or_inputs_embeds,
+                attention_mask,
+                cache_position,
+                *past_key_values,
+            ) = args
+            batch_position = torch.tensor(0, dtype=torch.int16)
+            query_position = None
+        elif self.phase == "prefill":
+            (
+                input_ids_or_inputs_embeds,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_position,
+                *past_key_values,
+            ) = args
+        else:
+            raise ValueError(f"Unknown phase: {self.phase}")

+        if input_ids_or_inputs_embeds.ndim == 2:
+            input_ids = input_ids_or_inputs_embeds
+            inputs_embeds = None
+        elif input_ids_or_inputs_embeds.ndim == 3:
+            input_ids = None
+            inputs_embeds = input_ids_or_inputs_embeds
+        else:
+            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")

-
-
-
-
-            value_state = value_state.unsqueeze(2)
-            attn_mask = attn_mask.unsqueeze(2)
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )

-
-
-
-
-
-
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        logit, present_key_values = self.causal_lm(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            batch_position=batch_position,
+            query_position=query_position,
+            past_key_values=past_key_values,
+            rotary_emb=self.rotary_emb,
         )

-
-
-        )
+        # ((key, value)) * n_layer -> [key, value] * n_layer
+        _present_key_values = ()
+        for i in range(self.num_hidden_layers):
+            key_states = present_key_values[i][0]
+            value_states = present_key_values[i][1]
+            _present_key_values = _present_key_values + (key_states, value_states)
+        present_key_values = _present_key_values

-
-        attn_weight += attn_mask
-        attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_state.dtype)
-        attn_output = torch.matmul(attn_weight, value_state)
+        return logit, present_key_values

-        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)

-
+class DecoderOnlyForCausalLM(nn.Module):
+    """A specialized wrapper for Causal Language Models optimized for RBLN compilation.

-
-
-
-
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        batch_index: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+    This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
+    1. Managing model phases (prefill/decode) throughout the computation graph
+    2. Handling output shape alignments for static compilation
+    3. Coordinating between the original model and RBLN-optimized components

-
-
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        # Decoder (bsz > 1)
-        if bsz > 1:
-            iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
-            for b in range(bsz):
-                attn_output, key_state, value_state = DecoderOnlyAttention._attn(
-                    self,
-                    query_states[b].unsqueeze(0),
-                    key_states[b].unsqueeze(0),
-                    value_states[b].unsqueeze(0),
-                    attention_mask[b].unsqueeze(0),
-                    past_key_value,
-                    batch_idx=b,
-                    is_prefill=False,
-                )
+    The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
+    focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.

-
-
-
+    Args:
+        causal_lm (PreTrainedModel): Original Huggingface causal language model
+        model (DecoderOnlyModel): RBLN-optimized model instance

-
-
-
-
-
-
-                    self,
-                    query_states,
-                    key_states,
-                    value_states,
-                    attention_mask,
-                    past_key_value,
-                    batch_idx=batch_index,
-                    is_prefill=True,
-                )
-
-        attn_output = self.o_proj(attn_output)
+    Attributes:
+        config: Configuration from the original causal language model
+        _original_mod: Reference to the original model for components like lm_head
+        model: RBLN-optimized decoder model instance
+        _phase: Current processing phase ("prefill" or "decode")
+    """

-
-
+    def __init__(self, causal_lm: PreTrainedModel, model):
+        super().__init__()
+        self.config = causal_lm.config
+        self._original_mod = causal_lm
+        self.model = model
+        self._phase = "prefill"

-
+    @property
+    def phase(self):
+        return self._phase

+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.model.phase = phase

-class DecoderOnlyFlashAttention:
     def forward(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        # Decoder (bsz > 1)
-        if bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query_state = query_states[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key_state = key_states[b].unsqueeze(0)
-                value_state = value_states[b].unsqueeze(0)
-
-                # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-                key_state = key_state.unsqueeze(2)
-                value_state = value_state.unsqueeze(2)
-                attn_mask = attn_mask.unsqueeze(2)
-
-                query_state = query_state.view(
-                    1,
-                    self.num_key_value_heads,
-                    self.num_heads // self.num_key_value_heads,
-                    q_len,
-                    self.head_dim,
-                )
-
-                # RBLN custom flash attention(decode), dummy batch index
-                sidx = cache_pos_for_partitions[b][0]
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_value.key_cache[self.layer_idx].unsqueeze(2),
-                    past_key_value.value_cache[self.layer_idx].unsqueeze(2),
-                    sidx,
-                    kvcache_partition_size,
-                )
+        input_ids: torch.Tensor = None,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_position: torch.Tensor = None,
+        query_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: nn.Module = None,
+    ):
+        # outputs
+        hidden_states, present_key_values = self.model(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
+            rotary_emb=rotary_emb,
+        )

-
-
-            attn_output = attn_output.transpose(1, 2).contiguous()
-            attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)
+        if self.phase == "prefill":
+            hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

-
-
-
+        logits = self._original_mod.lm_head(hidden_states)
+        output = (logits, present_key_values)
+        return output

-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)

-
-
-            key_states = key_states.unsqueeze(2)
-            value_states = value_states.unsqueeze(2)
-            attention_mask = attention_mask.unsqueeze(2)
-            query_states = query_states.view(
-                1,
-                self.num_key_value_heads,
-                self.num_heads // self.num_key_value_heads,
-                q_len,
-                self.head_dim,
-            )
+class DecoderOnlyModel(nn.Module):
+    """A modified decoder-only model implementation optimized for RBLN compilation.

-
-
-
-            sidx = cache_pos_for_partitions[0][0]
-            attn_output, key_states, value_states = torch.ops.rbln_custom_ops.flash_attn_prefill(
-                query_states,
-                key_states,
-                value_states,
-                attention_mask,
-                past_key_value.key_cache[self.layer_idx].unsqueeze(2),
-                past_key_value.value_cache[self.layer_idx].unsqueeze(2),
-                bidx,
-                sidx,
-                kvcache_partition_size,
-            )
+    Args:
+        model: Original Huggingface model to adapt
+        layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN

-
-
-
-
+    Attributes:
+        _original_mod: Reference to original Huggingface model
+        layers: ModuleList of RBLN-optimized transformer layers
+        _phase: Current processing phase ("prefill" or "decode")
+    """

-
+    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None):
+        super().__init__()
+        self._original_mod = model
+        self.layers = nn.ModuleList(layers)
+        self._phase = "prefill"
+        self.partition_len = partition_len
+
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        for layer in self.layers:
+            layer.phase = phase
+
+    @property
+    def attn_impl(self) -> str:
+        return "eager" if self.partition_len is None else "flash_attn"
+
+    @property
+    def hidden_multiplier(self):
+        return 1
+
+    def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
+        if self.attn_impl != "flash_attn":
+            raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
+
+        partition_len = self.partition_len
+        num_partition = max_seq_len // partition_len
+
+        cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
+        pidx = torch.arange(num_partition)
+        cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
+        return cache_pos_for_partitions
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.norm
+
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.embed_tokens
+
+    def get_pos_embedding(self) -> nn.Embedding:
+        raise NotImplementedError(
+            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
+        )

-
-
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: nn.Module = None,
+    ):
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )

-
+        # embed positions
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embedding()(input_ids)

+        hidden_states = inputs_embeds * self.hidden_multiplier

-
-
-
-
-
+        # get cos,sin vector if needed
+        if rotary_emb is not None:
+            cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1])  # dtype carrier, max_seq_len
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+        else:
+            batch_size = inputs_embeds.shape[0]
+            if cache_position.shape[0] > 1:
+                position_embeds = []
+                for b_idx in range(batch_size):
+                    position_embed = self.get_pos_embedding()(cache_position[b_idx])
+                    position_embeds.append(position_embed)
+
+                position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+            else:
+                position_embeds = self.get_pos_embedding()(cache_position)
+            hidden_states = hidden_states + position_embeds
+            cos, sin = None, None
+
+        # (batch, seq_len) -> (batch,)
+        seq_positions = cache_position[:, 0]
+        if self.attn_impl == "flash_attn":
+            max_seq_len = past_key_values[0][0].shape[-2]
+            seq_positions = self.convert_sequence_positions_for_flash_attn(
+                seq_positions=seq_positions, max_seq_len=max_seq_len
+            )

+        present_key_values = past_key_values
+        for layer in self.layers:
+            hidden_states, present_key_values = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                seq_positions=seq_positions,
+                batch_position=batch_position,
+                past_key_values=present_key_values,
+                cos=cos,
+                sin=sin,
+            )

-
-
-        super().__init__()
-        self.config = model.config
-        self.model = model.model
-        self.lm_head = model.lm_head
-        self.max_seq_len = max_seq_len
-        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-        if kvcache_partition_len is not None:
-            # WORKAROUND : for passing partition length as a value to the rbln compiler.
-            # What is actually used is the shape of this tensor.
-            self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
-            self.attn_implementation = "flash_attn_rbln"
-            logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
-        else:
-            self.kvcache_partition_size = None
-            self.attn_implementation = "eager"
+        hidden_states = self.get_last_layernorm()(hidden_states)
+        return hidden_states, present_key_values

-    def get_forward_dict(self):
-        forward_dict = {
-            "wrapper": DecoderOnlyModel.forward,
-            "model": DecoderOnlyDecoderLayer.forward,
-            "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
-        }
-        return forward_dict

-
-
-        input_ids_or_inputs_embeds,
-        attention_mask,
-        cache_position,
-        batch_position,
-        query_idx,
-        *past_key_values,
-    ):
-        if input_ids_or_inputs_embeds.ndim == 2:
-            # input_ids
-            input_ids = input_ids_or_inputs_embeds
-            inputs_embeds = None
-        elif input_ids_or_inputs_embeds.ndim == 3:
-            # inputs_embeds
-            input_ids = None
-            inputs_embeds = input_ids_or_inputs_embeds
-        else:
-            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+class DecoderOnlyLayer(nn.Module):
+    """A single transformer layer adapted for RBLN compilation with static shapes.

-
-
-
-
-
-        )
+    This layer implements a modified transformer block that includes:
+    1. Self-attention mechanism (either standard or flash attention)
+    2. Feed-forward network (FFN)
+    3. Layer normalization
+    4. Residual connections

-
-
-
-
-
-        elif self.attn_implementation == "flash_attn_rbln":
-            p_len = self.kvcache_partition_size.size()[0]
-            num_partition = self.max_seq_len // p_len
-            if self.max_seq_len % p_len > 0:
-                raise ValueError(
-                    f"The partition length({p_len}) must be exactly divisible by the max_seq_len({self.max_seq_len})."
-                )
-            cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
+    The layer is specifically designed to:
+    - Support compilation to RBLN custom ops
+    - Maintain static tensor shapes throughout computations
+    - Handle both prefill and decode phases efficiently
+    - Manage attention state transitions properly

-
-
-
-                cache_pos = decoding_step
-                for p_idx in range(num_partition):
-                    input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
-                    input_1 = torch.tensor(p_len, dtype=torch.int32)
-                    min = torch.minimum(input_0, input_1)
-                    cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
-                    cache_pos_for_partitions[b_idx][p_idx] = cache_pos_for_partition
-        else:  # prefill
-            cache_pos = cache_position[0][0]
-            for p_idx in range(num_partition):
-                input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
-                input_1 = torch.tensor(p_len, dtype=torch.int32)
-                min = torch.minimum(input_0, input_1)
-                cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
-                cache_pos_for_partitions[0][p_idx] = cache_pos_for_partition
-        else:
-            raise NotImplementedError(f"Unknown attn_implementation: {self.attn_implementation}")
+    Args:
+        layer: Original transformer layer module to wrap
+        self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN

-
-
-
-
-
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            past_key_values=past_key_values,
-            batch_ids=batch_position,
-            rotary_pos_emb=self.rotary_emb,
-            cache_pos_for_partitions=cache_pos_for_partitions,
-            kvcache_partition_size=self.kvcache_partition_size,
-            forward_dict=forward_dict,
-        )
+    Attributes:
+        _original_mod: Reference to original layer for accessing components
+        self_attn: Modified attention mechanism mapped to RBLN ops at compile time
+        phase: Current operation phase ("prefill" or "decode")
+    """

-
-
-
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+        super().__init__()
+        self._original_mod = layer
+        self.self_attn = self_attn
+        self._phase = "prefill"

-
+    @property
+    def phase(self):
+        return self._phase

-
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.self_attn.phase = phase

-
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.input_layernorm

+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.post_attention_layernorm

-class DecoderOnlyDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-
-
-
-
-        output_attentions: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor,
+        seq_positions: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ):
         residual = hidden_states
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)

-        hidden_states = self.
-
-        hidden_states, self_attn_weight, k, v = forward_dict["decoder_layer"](
-            self.self_attn,
+        hidden_states, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
-
-            batch_index=batch_ids,
-            use_cache=use_cache,
+            seq_positions=seq_positions,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
             cos=cos,
             sin=sin,
-            cache_pos_for_partitions=cache_pos_for_partitions,
-            kvcache_partition_size=kvcache_partition_size,
-            **kwargs,
         )
-        past_key_value.assign(k, v, layer_idx)
-
         hidden_states = residual + hidden_states

         # Fully Connected
         residual = hidden_states
-        hidden_states = self.
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = self._original_mod.mlp(hidden_states)
         hidden_states = residual + hidden_states

-
+        return hidden_states, present_key_values
+
+
+class DecoderOnlyAttention(nn.Module):
+    """Attention implementation for decoder-only models optimized for RBLN compilation.
+
+    This class implements a modified version of the standard attention mechanism that:
+    1. Supports static shape requirements for RBLN compilation
+    2. Handles explicit batch and position management
+
+    Args:
+        self_attn: Original attention module from the base model
+    """
+
+    def __init__(self, self_attn):
+        super().__init__()
+        self._original_mod = self_attn
+        self.layer_idx = self_attn.layer_idx
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.head_dim
+        self._phase = "prefill"
+        self.scale = torch.tensor(self.get_attn_scale())
+
+        if hasattr(self._original_mod, "num_key_value_heads"):
+            self.num_key_value_heads = self._original_mod.num_key_value_heads
+        else:
+            self.num_key_value_heads = self._original_mod.num_heads
+
+        self.attention = self.get_attention()
+        self.__post_init__()
+
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.attention.phase = phase

-
-
+    def get_attention(self):
+        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)

-
-
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.o_proj

-
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Projects input hidden states into query, key, and value representations.

+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+
+        Returns:
+            Tuple of (query_states, key_states, value_states)
+        """
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
+
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    def get_attn_scale(self):
+        return 1 / math.sqrt(self.head_dim)

-class DecoderOnlyModel:
     def forward(
         self,
-
-        attention_mask:
-
-
-
-
-
-
-
-        cache_pos_for_partitions: Optional[torch.Tensor] = None,
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> BaseModelOutputWithPast:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()

-
-
-
-
-
-
-
-        #
-
-        cos
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
+        )
+        # b, num_head, query, head_dim
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        if batch_size > 1 and self.phase == "prefill":
+            raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
+
+        # TODO(jongho): flash attn legacy. (clone)
+        _seq_positions = seq_positions.clone().unsqueeze(1)
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            seq_position = _seq_positions[b][0]
+            attn_output, key_state, value_state = self.attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_position=b if self.phase == "decode" else batch_position,
+                seq_position=seq_position,
+                scale=self.scale,
            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)

-
+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values

-        updated_cache = layer_outputs[2 if output_attentions else 1]

-
-
+class AttentionOp(nn.Module):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.phase = "prefill"

-
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        batch_position: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: torch.Tensor,
+        scale: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute attention with static shapes and explicit cache management.
+
+        Args:
+            query_state: Query tensor [1, num_heads, 1, head_dim]
+            key_state: Key tensor [1, num_heads, seq_len, head_dim]
+            value_state: Value tensor [1, num_heads, seq_len, head_dim]
+            attn_mask: Attention mask tensor ∈ {0, 1}
+            batch_position: Batch index for cache lookup
+            past_key_state: Previous key cache states
+            past_value_state: Previous value cache states
+            seq_position: Current position in sequence
+            scale: Scale applied to attn weights
+
+        Returns:
+            Tuple of (attention_output, key_state, value_state)
+        """
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)

-
-
-
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )

-
-
+        if self.phase == "decode":
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_decode(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                seq_position,
+                scale,
+            )

-
-
-
-
-
-
+        else:
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_prefill(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                batch_position,
+                seq_position,
+                scale,
+            )

+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+        return attn_output, key_state.squeeze(2), value_state.squeeze(2)

-
-
-
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
         cos_all = []
         sin_all = []
-        for i in range(
-            cos_all.append(cos[
-            sin_all.append(sin[
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
         cos = torch.cat(cos_all, dim=0)
         sin = torch.cat(sin_all, dim=0)
     else:
-        cos = cos[
-        sin = sin[
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)

     return cos, sin

@@ -658,6 +775,26 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed


+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states
+
+
 class RotaryEmbedding(nn.Module):
     """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

@@ -674,14 +811,14 @@ class RotaryEmbedding(nn.Module):
             rope_type = "default"

         inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-
-
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]

         if rope_type == "dynamic":
-            freqs =
+            freqs = cache_position_expanded.float() * inv_freq.float()
         else:
             inv_freq_expanded = inv_freq[None, :]
-            freqs =
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()

         emb = torch.cat((freqs, freqs), dim=-1)

@@ -696,3 +833,127 @@ class RotaryEmbedding(nn.Module):
             self._cos_cached[:seq_len].to(dtype=x.dtype),
             self._sin_cached[:seq_len].to(dtype=x.dtype),
         )
+
+
+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, self_attn, kvcache_partition_len):
+        self.kvcache_partition_size = kvcache_partition_len
+        super().__init__(self_attn=self_attn)
+
+    def get_attention(self):
+        return FlashAttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.kvcache_partition_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
+        )
+        # b, num_head, query, head_dim
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            seq_position = seq_positions[b][0]  # FIXME: Remove take-take pattern matching
+            attn_output, key_state, value_state = self.attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_position=b if self.phase == "decode" else batch_position,
+                seq_position=seq_position,
+                scale=self.scale,
+            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)
+
+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values
+
+
+class FlashAttentionOp(AttentionOp):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, kvcache_partition_len: int):
+        super().__init__(num_heads=num_heads, head_dim=head_dim, num_key_value_heads=num_key_value_heads)
+        self.kvcache_partition_size = kvcache_partition_len
+
+    def forward(
+        self,
+        query_state,
+        key_state,
+        value_state,
+        attn_mask,
+        batch_position,
+        past_key_state,
+        past_value_state,
+        seq_position,
+        scale,
+    ):
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                seq_position,
+                scale,
+                self.kvcache_partition_size,
+            )
+        else:
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                batch_position,
+                seq_position,
+                scale,
+                self.kvcache_partition_size,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+        return attn_output, key_state, value_state