optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -22,74 +22,209 @@
 # from Rebellions Inc.

 import math
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
+
+from ....utils import logging
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+
+if is_torch_greater_or_equal_than_2_4:
+    register_fake = torch.library.register_fake
+else:
+    register_fake = torch.library.impl_abstract
+
+
+logger = logging.get_logger(__name__)
+
+"""
+##############################################################################
+# RBLN custom operation (python interface)
+# torch.compile custom operation
+# torch.library.define - kernel declaration
+# torch.library.impl - kernel implementation
+# torch.library.impl_abstract - symbolic trace
+##############################################################################
+"""
+
+# RBLN custom op(flash attention decode)
+torch.library.define(
+    "rbln_custom_ops::flash_attn_decode",
+    "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
 )

-from ...cache_utils import RebelDynamicCache

+@torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
+def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
+    """
+    WORKAROUND:
+    Partition is declared as an argument to the function, even though it is
+    not actually used in the CPU implementation, this allows the rbln compiler
+    to perform flash attention operations with partition as an argument.
+    """
+    assert kcache.dim() == k.dim()
+    assert vcache.dim() == v.dim()
+    assert k.size(-2) == v.size(-2)
+    assert partition.dim() == 1
+    b = 0
+    if seq.dim() == 1:
+        s = seq[0]
+    elif seq.dim() == 0:
+        s = seq
+    else:
+        assert False
+    e = s + k.size(-2)
+    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
+    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
+    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
+    attn_weight = attn_weight + mask
+    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weight, updated_v)
+    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+@register_fake("rbln_custom_ops::flash_attn_decode")
+def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+    return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+# RBLN custom op(flash attention prefill)
+torch.library.define(
+    "rbln_custom_ops::flash_attn_prefill",
+    "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+)
+
+
+@torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
+def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition):
+    """
+    WORKAROUND:
+    Partition is declared as an argument to the function, even though it is
+    not actually used in the CPU implementation, this allows the rbln compiler
+    to perform flash attention operations with partition as an argument.
+    """
+    assert kcache.dim() == k.dim()
+    assert vcache.dim() == v.dim()
+    assert k.size(-2) == v.size(-2)
+    assert partition.dim() == 1
+    if batch.dim() == 1:
+        b = batch[0]
+    elif batch.dim() == 0:
+        b = batch
+    else:
+        assert False
+    if seq.dim() == 1:
+        s = seq[0]
+    elif seq.dim() == 0:
+        s = seq
+    else:
+        assert False
+    e = s + k.size(-2)
+    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
+    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
+    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
+    attn_weight = attn_weight + mask
+    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weight, updated_v)
+    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+@register_fake("rbln_custom_ops::flash_attn_prefill")
+def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
+    return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+# RBLN custom op(cache update)
+torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
+
+
+@torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
+def rbln_cache_update_cpu(cache, value, batch, seq):
+    updated_cache = cache[batch].slice_scatter(value, dim=-2, start=batch[0], end=batch[0] + seq[0])
+    return updated_cache
+
+
+@register_fake("rbln_custom_ops::rbln_cache_update")
+def rbln_cache_update_abstract(cache, value, batch, seq):
+    return torch.empty_like(cache)
+
+
+class DecoderOnlyWrapper(nn.Module):
+    """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
+
+    This wrapper is designed to:
+    1. Convert Huggingface decoder models for RBLN compilation with static shapes
+    2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
+    3. Manage different attention implementations (standard and flash attention)
+    4. Support both prefill and decode phases
+
+    Notes:
+    - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
+    - Wrapper should not contain neural network graph operations (including memory view handling)
+
+    Args:
+        causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
+        max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
+        use_rotary_emb (bool): Whether to use rotary position embeddings
+        kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
+            If provided, uses flash attention; if None, uses standard attention
+    """

-    def __init__(self, model, max_seq_len):
+    def __init__(self, causal_lm: PreTrainedModel, max_seq_len, use_rotary_emb: bool, kvcache_partition_len=None):
         super().__init__()
-        self.head_dim = (
-            self.config.head_dim
-            if hasattr(self.config, "head_dim")
-            else self.config.hidden_size // self.config.num_attention_heads
-        )
-        self.max_position_embeddings = (
-            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-        )
-        self.max_seq_len = max_seq_len
-        self.rope_scaling = getattr(self.config, "rope_scaling", None)
-        self.rotary_emb = self._init_rope()
-
-    def _init_rope(self):
-        if self.rope_scaling is None:
-            rotary_emb = RotaryEmbedding(
-                self.head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
-            )
+        self.config = causal_lm.config
+
+        if use_rotary_emb:
+            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
         else:
+            self.rotary_emb = None
+
+        if kvcache_partition_len is not None:
+            # WORKAROUND : for passing partition length as a value to the rbln compiler.
+            # What is actually used is the shape of this tensor.
+            self.attn_impl = "flash_attn"
+            logger.info(f"Using flash-attention. (partition length : {kvcache_partition_len})")
+        else:
+            self.attn_impl = "eager"
+        self.kvcache_partition_len = kvcache_partition_len
+
+        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
+
+        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
+        self._phase = "prefill"
+
+    def get_rotary_emb(self, max_seq_len):
+        return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
                 )
             else:
-                raise
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = DecoderOnlyModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm

+    @property
+    def phase(self):
+        return self._phase

-            "decoder_layer": DecoderOnlyAttention.forward,
-        }
-        return forward_dict
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.causal_lm.phase = phase

     def forward(
         self,
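The CPU kernels registered above are reference implementations; the compiled graph calls the same ops through `torch.ops`. Below is a minimal sketch (not part of the package) of exercising the decode kernel on CPU, assuming these hunks belong to `optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py` from the file list above. All tensor shapes are illustrative assumptions (batch 1, 4 KV heads, repeat factor 1, head_dim 128, a 16-slot cache).

```python
import torch

# Importing the module runs the torch.library registrations above.
import optimum.rbln.transformers.models.decoderonly.decoderonly_architecture  # noqa: F401

q = torch.randn(1, 4, 1, 1, 128)        # query for one decode step
k = torch.randn(1, 4, 1, 1, 128)        # new key to write into the cache
v = torch.randn(1, 4, 1, 1, 128)        # new value to write into the cache
mask = torch.zeros(1, 1, 1, 1, 16)      # additive attention mask over cache slots
kcache = torch.zeros(1, 4, 1, 16, 128)  # key cache: batch, kv_heads, repeat, max_seq, head_dim
vcache = torch.zeros(1, 4, 1, 16, 128)  # value cache, same layout
seq = torch.tensor([3], dtype=torch.int32)     # current cache position
partition = torch.zeros(4, dtype=torch.int32)  # only its shape is consumed by the compiler

attn_out, _, _ = torch.ops.rbln_custom_ops.flash_attn_decode(
    q, k, v, mask, kcache, vcache, seq, partition
)
print(attn_out.shape)  # torch.Size([1, 4, 1, 1, 128])
```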
@@ -97,324 +232,514 @@ class DecoderOnlyWrapper(torch.nn.Module):
         attention_mask,
         cache_position,
         batch_position,
+        query_position,
         *past_key_values,
     ):
-        if input_ids_or_inputs_embeds.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-
         if input_ids_or_inputs_embeds.ndim == 2:
-            # input_ids
+            # It is input_ids
             input_ids = input_ids_or_inputs_embeds
             inputs_embeds = None
         elif input_ids_or_inputs_embeds.ndim == 3:
-            # inputs_embeds
+            # It is inputs_embeds
             input_ids = None
             inputs_embeds = input_ids_or_inputs_embeds
         else:
             raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")

-            *past_key_values,
-        )
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {self.num_hidden_layers}"
+            )

+        seq_len = input_ids_or_inputs_embeds.shape[1]
+        if seq_len == 1:
+            self.phase = "decode"
+        else:
+            self.phase = "prefill"
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        logit, present_key_values = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
+            cache_position=cache_position,
+            batch_position=batch_position,
+            query_position=query_position,
             past_key_values=past_key_values,
-            rotary_pos_emb=self.rotary_emb,
-            forward_dict=forward_dict,
+            rotary_emb=self.rotary_emb,
         )

+        # ((key, value)) * n_layer -> [key, value] * n_layer
+        _present_key_values = ()
+        for i in range(self.num_hidden_layers):
+            key_states = present_key_values[i][0]
+            value_states = present_key_values[i][1]
+            _present_key_values = _present_key_values + (key_states, value_states)
+        present_key_values = _present_key_values

+        # batch_position + query_position is dummy output node to keep the number of outputs
+        return logit, present_key_values, batch_position + query_position

-        output = (logits,) + outputs[1:]

+class DecoderOnlyForCausalLM(nn.Module):
+    """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
+
+    This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
+    1. Managing model phases (prefill/decode) throughout the computation graph
+    2. Handling output shape alignments for static compilation
+    3. Coordinating between the original model and RBLN-optimized components
+
+    The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
+    focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.
+
+    Args:
+        causal_lm (PreTrainedModel): Original Huggingface causal language model
+        model (DecoderOnlyModel): RBLN-optimized model instance
+
+    Attributes:
+        config: Configuration from the original causal language model
+        _original_mod: Reference to the original model for components like lm_head
+        model: RBLN-optimized decoder model instance
+        _phase: Current processing phase ("prefill" or "decode")
+    """
+
+    def __init__(self, causal_lm: PreTrainedModel, model):
+        super().__init__()
+        self.config = causal_lm.config
+        self._original_mod = causal_lm
+        self.model = model
+        self._phase = "prefill"
+
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.model.phase = phase

-class DecoderOnlyAttention:
     def forward(
         self,
+        input_ids: torch.Tensor = None,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_position: torch.Tensor = None,
+        query_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: nn.Module = None,
+    ):
+        # outputs
+        hidden_states, present_key_values = self.model(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
+            rotary_emb=rotary_emb,
+        )

-        value_states = self.v_proj(hidden_states)
+        if self.phase == "prefill":
+            hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query_state = query_states[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key_state = key_states[b].unsqueeze(0)
-                value_state = value_states[b].unsqueeze(0)
-
-                # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-                key_state = key_state.unsqueeze(2)
-                value_state = value_state.unsqueeze(2)
-                attn_mask = attn_mask.unsqueeze(2)
-
-                query_state = query_state.view(
-                    1,
-                    self.num_key_value_heads,
-                    self.num_heads // self.num_key_value_heads,
-                    q_len,
-                    self.head_dim,
-                )
+        logits = self._original_mod.lm_head(hidden_states)
+        output = (logits, present_key_values)
+        return output

-                key_state, value_state = past_key_value.update(
-                    key_state,
-                    value_state,
-                    self.layer_idx,
-                    b,
-                )

+class DecoderOnlyModel(nn.Module):
+    """A modified decoder-only model implementation optimized for RBLN compilation.
+
+    Args:
+        model: Original Huggingface model to adapt
+        layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+
+    Attributes:
+        _original_mod: Reference to original Huggingface model
+        layers: ModuleList of RBLN-optimized transformer layers
+        _phase: Current processing phase ("prefill" or "decode")
+    """

-        attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
+    mask_fmin = torch.finfo(torch.float16).min

+    def __init__(self, model, layers: List["DecoderOnlyLayer"]):
+        super().__init__()
+        self._original_mod = model
+        self.layers = nn.ModuleList(layers)
+        self._phase = "prefill"

+    @property
+    def phase(self):
+        return self._phase

+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        for layer in self.layers:
+            layer.phase = phase

+    @property
+    def hidden_multiplier(self):
+        return 1
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.norm
+
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.embed_tokens
+
+    def get_pos_embedding(self) -> nn.Embedding:
+        raise NotImplementedError(
+            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: nn.Module = None,
+    ):
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )

+        # embed positions
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embedding()(input_ids)
+
+        hidden_states = inputs_embeds * self.hidden_multiplier
+        attention_mask = (1 - attention_mask) * self.mask_fmin
+
+        # get cos,sin vector if needed
+        if rotary_emb is not None:
+            cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1])  # dtype carrier, max_seq_len
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+        else:
+            batch_size = inputs_embeds.shape[0]
+            if cache_position.shape[0] > 1:
+                position_embeds = []
+                for b_idx in range(batch_size):
+                    position_embed = self.get_pos_embedding()(cache_position[b_idx])
+                    position_embeds.append(position_embed)
+
+                position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+            else:
+                position_embeds = self.get_pos_embedding()(cache_position)
+            hidden_states = hidden_states + position_embeds
+            cos, sin = None, None
+
+        # (batch, seq_len) -> (batch,)
+        current_steps = cache_position[:, 0]
+
+        present_key_values = past_key_values
+        for layer in self.layers:
+            hidden_states, present_key_values = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                current_steps=current_steps,
+                batch_position=batch_position,
+                past_key_values=present_key_values,
+                cos=cos,
+                sin=sin,
             )

+        hidden_states = self.get_last_layernorm()(hidden_states)
+        return hidden_states, present_key_values

-        # upcast attention to fp32
-        attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weight, value_states)

-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+class DecoderOnlyLayer(nn.Module):
+    """A single transformer layer adapted for RBLN compilation with static shapes.
+
+    This layer implements a modified transformer block that includes:
+    1. Self-attention mechanism (either standard or flash attention)
+    2. Feed-forward network (FFN)
+    3. Layer normalization
+    4. Residual connections
+
+    The layer is specifically designed to:
+    - Support compilation to RBLN custom ops
+    - Maintain static tensor shapes throughout computations
+    - Handle both prefill and decode phases efficiently
+    - Manage attention state transitions properly
+
+    Args:
+        layer: Original transformer layer module to wrap
+        self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
+
+    Attributes:
+        _original_mod: Reference to original layer for accessing components
+        self_attn: Modified attention mechanism mapped to RBLN ops at compile time
+        phase: Current operation phase ("prefill" or "decode")
+    """
+
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+        super().__init__()
+        self._original_mod = layer
+        self.self_attn = self_attn
+        self._phase = "prefill"
+
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.self_attn.phase = phase
+
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.input_layernorm
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.post_attention_layernorm

-class DecoderOnlyDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-        output_attentions: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ):
         residual = hidden_states

+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)

+        hidden_states, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            batch_index=batch_ids,
-            use_cache=use_cache,
+            current_steps=current_steps,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
             cos=cos,
             sin=sin,
-            **kwargs,
         )
-        past_key_value.assign(k, v, layer_idx)
         hidden_states = residual + hidden_states

         # Fully Connected
         residual = hidden_states
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = self._original_mod.mlp(hidden_states)
         hidden_states = residual + hidden_states

+        return hidden_states, present_key_values

-        if output_attentions:
-            outputs += (self_attn_weight,)

+class DecoderOnlyAttention(nn.Module):
+    """Attention implementation for decoder-only models optimized for RBLN compilation.
+
+    This class implements a modified version of the standard attention mechanism that:
+    1. Supports static shape requirements for RBLN compilation
+    2. Handles explicit batch and position management
+
+    Args:
+        self_attn: Original attention module from the base model
+    """

+    def __init__(self, self_attn):
+        super().__init__()
+        self._original_mod = self_attn
+        self.layer_idx = self_attn.layer_idx
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.head_dim
+        self.phase = "prefill"
+        self.__post_init__()
+
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.o_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Projects input hidden states into query, key, and value representations.
+
+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+
+        Returns:
+            Tuple of (query_states, key_states, value_states)
+        """
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
+
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    def rbln_attention(
         self,
+        query_state,
+        key_state,
+        value_state,
+        attn_mask,
+        batch_idx,
+        past_key_state,
+        past_value_state,
+        current_step,
+        # below are designed for Midm, GPT which requires to support scaling for attention weights
+        # TODO(jongho): Merge and manage scales generally
+        layer_idx=None,
+        scale_attn_weights: bool = None,
+        scale_attn_by_inverse_layer_idx: bool = None,
+        scale_qk_by_inverse_layer_idx: bool = None,
+    ):
+        """Compute attention with static shapes and explicit cache management.
+
+        Args:
+            query_state: Query tensor [1, num_heads, 1, head_dim]
+            key_state: Key tensor [1, num_heads, seq_len, head_dim]
+            value_state: Value tensor [1, num_heads, seq_len, head_dim]
+            attn_mask: Attention mask tensor
+            batch_idx: Batch index for cache lookup
+            past_key_state: Previous key cache states
+            past_value_state: Previous value cache states
+            current_step: Current position in sequence
+
+        Returns:
+            Tuple of (attention_output, key_state, value_state)
+        """
+        # Implementation details.
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )  #
+
+        kend = current_step + key_state.shape[-2]
+        vend = current_step + value_state.shape[-2]
+
+        key_state = (
+            past_key_state[batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(key_state, dim=-2, start=current_step, end=kend)
+        )
+        value_state = (
+            past_value_state[batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(value_state, dim=-2, start=current_step, end=vend)
+        )

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                forward_dict=forward_dict,
-            )
+        attn_weight = torch.matmul(query_state, key_state.transpose(3, 4))
+        attn_weight = attn_weight / math.sqrt(self.head_dim)
+
+        if layer_idx is not None and (scale_attn_by_inverse_layer_idx or scale_qk_by_inverse_layer_idx):
+            attn_weight = attn_weight / float(layer_idx + 1)
+
+        attn_weight += attn_mask
+
+        if layer_idx is not None and scale_qk_by_inverse_layer_idx:
+            attn_weight = attn_weight * float(layer_idx + 1)

+        attn_weight = nn.functional.softmax(attn_weight, dim=-1)

+        attn_output = torch.matmul(attn_weight, value_state)

+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)

+        return attn_output, key_state, value_state

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,  # (batch, 1, prefill_size, head_dim)
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()

-        next_cache = updated_cache.to_legacy_cache()
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)

-            attentions=all_self_attns,
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
         )
+        # b, num_head, query, head_dim
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        if batch_size > 1 and self.phase == "prefill":
+            raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            current_step = current_steps[b]
+            attn_output, key_state, value_state = self.rbln_attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0)
+                if self.phase == "decode"
+                else attention_mask,  # TODO(jongho): fix when msoftmax is supported
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_idx=b if self.phase == "decode" else batch_position,
+                current_step=current_step,
+            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)
+
+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values


+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
         cos_all = []
         sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
         cos = torch.cat(cos_all, dim=0)
         sin = torch.cat(sin_all, dim=0)
     else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)

     return cos, sin

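The rewritten `DecoderOnlyWrapper.forward` in the hunk above receives the KV cache as a flat positional sequence of `2 * num_hidden_layers` tensors and regroups it into per-layer `[key, value]` pairs before calling the inner model. A small sketch of that convention, with hypothetical sizes chosen only for illustration:

```python
import torch

# Hypothetical cache geometry matching the comment in the hunk:
# batch, n_heads, 1, max_seq_len, head_dim
num_layers, batch, heads, max_seq, head_dim = 2, 1, 4, 16, 64

# Flat layout expected by the wrapper: [k_0, v_0, k_1, v_1, ...]
flat_cache = []
for _ in range(num_layers):
    flat_cache.append(torch.zeros(batch, heads, 1, max_seq, head_dim))  # key cache of one layer
    flat_cache.append(torch.zeros(batch, heads, 1, max_seq, head_dim))  # value cache of one layer

assert len(flat_cache) == 2 * num_layers  # the wrapper raises ValueError otherwise

# The same regrouping the wrapper performs internally before calling causal_lm:
per_layer = [[flat_cache[2 * i], flat_cache[2 * i + 1]] for i in range(num_layers)]
print(len(per_layer), per_layer[0][0].shape)
```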
@@ -434,34 +759,58 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed


+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states
+
+
 class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
     def __init__(
         self,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
     ):
         super().__init__()

+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
+
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]

+        if rope_type == "dynamic":
+            freqs = cache_position_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()

-        positions_ids = torch.arange(self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(positions_ids, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)

+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)

     def forward(self, x, seq_len):
         return (
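The rewritten `RotaryEmbedding` above precomputes the full cos/sin tables once at construction and only indexes them at run time (see `slice_and_unsqueeze_cos_sin` in the previous hunk). Below is a minimal sketch of that precompute-then-slice pattern, assuming the plain "default" RoPE parameterization; in the package itself `inv_freq` and `attention_scaling` come from `ROPE_INIT_FUNCTIONS` in `modeling_rope_utils`.

```python
import torch

# Default RoPE inverse frequencies: base^(-2i/d). Sizes here are illustrative.
head_dim, base, max_seq_len = 64, 10000.0, 32
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

positions = torch.arange(max_seq_len, dtype=torch.float32)[:, None]  # (max_seq_len, 1)
freqs = positions @ inv_freq[None, :]                                # (max_seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()                        # cached once, as buffers above

# At run time the tables are only sliced, e.g. for a single decode position:
cache_position = torch.tensor([5])
cos = cos_cached[cache_position].unsqueeze(1)                        # (1, 1, head_dim)
sin = sin_cached[cache_position].unsqueeze(1)
print(cos.shape, sin.shape)
```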
@@ -470,71 +819,140 @@ class RotaryEmbedding(nn.Module):
         )


+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, self_attn, kvcache_partition_len):
+        super().__init__(self_attn=self_attn)
+        self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+
+    def get_cache_pos_for_partitions(self, current_steps, batch_size, max_seq_len):
+        partition_len = self.kvcache_partition_size.size()[0]
+        num_partition = max_seq_len // partition_len
+        cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
+        if self.phase == "decode":
+            for b_idx in range(batch_size):
+                cache_pos = current_steps[b_idx]
+                for p_idx in range(num_partition):
+                    cache_pos_for_partitions[b_idx][p_idx] = torch.clamp(
+                        cache_pos - partition_len * p_idx, 0, partition_len
+                    )
+        else:  # prefill
+            cache_pos = current_steps[0]
+            for p_idx in range(num_partition):
+                cache_pos_for_partitions[0][p_idx] = torch.clamp(cache_pos - partition_len * p_idx, 0, partition_len)
+
+        return cache_pos_for_partitions
+
+    def rbln_flash_attention(
         self,
+        query_state,
+        key_state,
+        value_state,
+        attn_mask,
+        batch_idx,
+        past_key_state,
+        past_value_state,
+        cache_pos_for_partitions,
     ):
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
         )
-        # difference to the original RoPE: a scaling factor is aplied to the position ids
-        if max_seq_len > max_position_embeddings:
-            positions_ids = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-            positions_ids = positions_ids / self.scaling_factor
-            freqs = torch.outer(positions_ids, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()

+        # RBLN custom flash attention(decode), dummy batch index
+        if self.phase == "decode":
+            sidx = cache_pos_for_partitions[batch_idx][0]
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                sidx,
+                self.kvcache_partition_size,
+            )
+        else:
+            sidx = cache_pos_for_partitions[0][0]
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                batch_idx,
+                sidx,
+                self.kvcache_partition_size,
+            )

+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)

-        """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+        return attn_output, key_state, value_state

+    def forward(
         self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
     ):
+        batch_size, query_length, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
         )
+        # b, num_head, query, head_dim
+
+        max_seq_len = past_key_values[self.layer_idx][0].shape[-2]
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        cache_pos_for_partitions = self.get_cache_pos_for_partitions(
+            current_steps, batch_size=batch_size, max_seq_len=max_seq_len
+        )  # batch_size, num_partitions
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            attn_output, key_state, value_state = self.rbln_flash_attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0)
+                if self.phase == "decode"
+                else attention_mask,  # TODO(jongho): fix when msoftmax is supported
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_idx=b if self.phase == "decode" else batch_position,
+                cache_pos_for_partitions=cache_pos_for_partitions,
+            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)
+
+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values