optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +5 -1
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
- optimum/rbln/diffusers/models/controlnet.py +36 -56
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
- optimum/rbln/modeling_base.py +12 -5
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +2 -2
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
Diff of optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py:

@@ -26,141 +26,241 @@ from typing import Dict, Optional, Tuple
 
 import torch
 from torch import nn
+from transformers import PretrainedConfig
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
 )
 
+from ....utils import logging
 from ...cache_utils import RebelDynamicCache
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+
+logger = logging.get_logger(__name__)
+"""
+##############################################################################
+# RBLN custom operation (python interface)
+# torch.compile custom operation
+# torch.library.define - kernel declaration
+# torch.library.impl - kernel implementation
+# torch.library.impl_abstract - symbolic trace
+##############################################################################
+"""
+
+# RBLN custom op(flash attention decode)
+torch.library.define(
+    "rbln_custom_ops::flash_attn_decode",
+    "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+)
 
 
-
-
-
-
-
-
+@torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
+def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
+    """
+    WORKAROUND:
+    Partition is declared as an argument to the function, even though it is
+    not actually used in the CPU implementation, this allows the rbln compiler
+    to perform flash attention operations with partition as an argument.
+    """
+    assert kcache.dim() == k.dim()
+    assert vcache.dim() == v.dim()
+    assert k.size(-2) == v.size(-2)
+    assert partition.dim() == 1
+    b = 0
+    if seq.dim() == 1:
+        s = seq[0]
+    elif seq.dim() == 0:
+        s = seq
+    else:
+        assert False
+    e = s + k.size(-2)
+    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
+    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
+    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
+    attn_weight = attn_weight + mask
+    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weight, updated_v)
+    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+@torch.library.impl_abstract("rbln_custom_ops::flash_attn_decode")
+def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+    return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+# RBLN custom op(flash attention prefill)
+torch.library.define(
+    "rbln_custom_ops::flash_attn_prefill",
+    "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+)
+
+
+@torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
+def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition):
+    """
+    WORKAROUND:
+    Partition is declared as an argument to the function, even though it is
+    not actually used in the CPU implementation, this allows the rbln compiler
+    to perform flash attention operations with partition as an argument.
+    """
+    assert kcache.dim() == k.dim()
+    assert vcache.dim() == v.dim()
+    assert k.size(-2) == v.size(-2)
+    assert partition.dim() == 1
+    if batch.dim() == 1:
+        b = batch[0]
+    elif batch.dim() == 0:
+        b = batch
+    else:
+        assert False
+    if seq.dim() == 1:
+        s = seq[0]
+    elif seq.dim() == 0:
+        s = seq
+    else:
+        assert False
+    e = s + k.size(-2)
+    updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
+    updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
+    attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
+    attn_weight = attn_weight + mask
+    attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weight, updated_v)
+    return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
+
 
-
-
-
-
+@torch.library.impl_abstract("rbln_custom_ops::flash_attn_prefill")
+def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
+    return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+# RBLN custom op(cache update)
+torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
+
+
+@torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
+def rbln_cache_update_cpu(cache, value, batch, seq):
+    updated_cache = cache[batch].slice_scatter(value, dim=-2, start=batch[0], end=batch[0] + seq[0])
+    return updated_cache
+
+
+@torch.library.impl_abstract("rbln_custom_ops::rbln_cache_update")
+def rbln_cache_update_abstract(cache, value, batch, seq):
+    return torch.empty_like(cache)
+
+
+class DecoderOnlyAttention:
+    def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,
+            self.head_dim,
         )
-
-
+
+        key_state, value_state = past_key_value.update(
+            key_state, value_state, self.layer_idx, batch_idx, read_first_step=is_prefill
         )
-        self.max_seq_len = max_seq_len
-        self.rope_scaling = getattr(self.config, "rope_scaling", None)
-        self.rotary_emb = self._init_rope()
 
-
-
-
-
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
-            )
-        else:
-            scaling_type = self.rope_scaling["type"]
-            scaling_factor = self.rope_scaling["factor"]
-            if scaling_type == "linear":
-                rotary_emb = LinearScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            elif scaling_type == "dynamic":
-                rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
+        attn_weight += attn_mask
+        attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_state.dtype)
+        attn_output = torch.matmul(attn_weight, value_state)
 
-
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
 
-
-        forward_dict = {
-            "wrapper": DecoderOnlyModel.forward,
-            "model": DecoderOnlyDecoderLayer.forward,
-            "decoder_layer": DecoderOnlyAttention.forward,
-        }
-        return forward_dict
+        return attn_output, key_state, value_state
 
     def forward(
         self,
-
-        attention_mask,
-
-
-
-
-
-
-
-
-
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache] = None,
+        batch_index: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
-
-
-
-            inputs_embeds = None
-        elif input_ids_or_inputs_embeds.ndim == 3:
-            # inputs_embeds
-            input_ids = None
-            inputs_embeds = input_ids_or_inputs_embeds
-        else:
-            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-
-        past_key_values = RebelDynamicCache.from_input_format(
-            cache_position,
-            self.config.num_hidden_layers,
-            *past_key_values,
-        )
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-
-
-
-
-
-
-
-
-
-
-
-
+        # Decoder (bsz > 1)
+        if bsz > 1:
+            iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
+            for b in range(bsz):
+                attn_output, key_state, value_state = DecoderOnlyAttention._attn(
+                    self,
+                    query_states[b].unsqueeze(0),
+                    key_states[b].unsqueeze(0),
+                    value_states[b].unsqueeze(0),
+                    attention_mask[b].unsqueeze(0),
+                    past_key_value,
+                    batch_idx=b,
+                    is_prefill=False,
+                )
 
-
-
-
+                iterate_results["key_states"].append(key_state)
+                iterate_results["value_states"].append(value_state)
+                iterate_results["attn_output"].append(attn_output)
 
-
+            key_states = torch.cat(iterate_results["key_states"], dim=0)
+            value_states = torch.cat(iterate_results["value_states"], dim=0)
+            attn_output = torch.cat(iterate_results["attn_output"], dim=0)
+        # Prefill & Decoder (bsz == 1)
+        else:
+            attn_output, key_states, value_states = DecoderOnlyAttention._attn(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                past_key_value,
+                batch_idx=batch_index,
+                is_prefill=True,
+            )
 
-
+        attn_output = self.o_proj(attn_output)
 
-
+        if not output_attentions:
+            attn_weight = None
 
+        return attn_output, attn_weight, key_states, value_states
 
-
+
+class DecoderOnlyFlashAttention:
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[RebelDynamicCache] = None,
-        batch_index: Optional[
+        batch_index: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        cache_pos_for_partitions: Optional[torch.Tensor] = None,
+        kvcache_partition_size: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
-
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
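Note on the registration pattern above: each RBLN op is declared with `torch.library.define`, given a CPU reference kernel with `torch.library.impl`, and given a shape-only kernel with `torch.library.impl_abstract` so the op can be traced symbolically. Here is a minimal, self-contained sketch of that pattern, loosely mirroring `rbln_cache_update`; the op name, shapes, and numbers are illustrative and are not part of the package.

```python
import torch

# Declare the schema of a hypothetical cache-update op (illustrative, not optimum-rbln's).
torch.library.define("demo_ops::cache_update", "(Tensor cache, Tensor value, Tensor start) -> Tensor")


@torch.library.impl("demo_ops::cache_update", "cpu")
def cache_update_cpu(cache, value, start):
    # CPU reference kernel: write `value` into the sequence axis of `cache` at offset `start`.
    s = int(start[0])
    return cache.slice_scatter(value, dim=-2, start=s, end=s + value.size(-2))


@torch.library.impl_abstract("demo_ops::cache_update")
def cache_update_abstract(cache, value, start):
    # Shape-only ("fake") kernel: lets torch.compile / symbolic tracing run without real data.
    return torch.empty_like(cache)


cache = torch.zeros(1, 8, 4)  # (batch, seq, dim)
value = torch.ones(1, 2, 4)   # two new positions to write
out = torch.ops.demo_ops.cache_update(cache, value, torch.tensor([3]))
print(out[0, :, 0])           # positions 3 and 4 along the sequence axis are now 1.0
```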
@@ -171,8 +271,8 @@ class DecoderOnlyAttention:
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # Decoder
-        if
+        # Decoder (bsz > 1)
+        if bsz > 1:
             all_key_states = []
             all_value_states = []
             all_attn_output = []
@@ -196,25 +296,21 @@ class DecoderOnlyAttention:
                     self.head_dim,
                 )
 
-
+                # RBLN custom flash attention(decode), dummy batch index
+                sidx = cache_pos_for_partitions[b][0]
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
+                    query_state,
                     key_state,
                     value_state,
-
-
+                    attn_mask,
+                    past_key_value.key_cache[self.layer_idx].unsqueeze(2),
+                    past_key_value.value_cache[self.layer_idx].unsqueeze(2),
+                    sidx,
+                    kvcache_partition_size,
                 )
 
-                # reshape for removing repeat_kv
-                attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
-
-                attn_weight = attn_weight + attn_mask
-
-                # upcast attention to fp32
-                attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
-                attn_output = torch.matmul(attn_weight, value_state)
-
                 # reshape for removing repeat_kv
                 attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-
                 attn_output = attn_output.transpose(1, 2).contiguous()
                 attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)
 
@@ -227,9 +323,6 @@ class DecoderOnlyAttention:
             attn_output = torch.cat(all_attn_output, dim=0)
 
         else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
             # reshape for removing repeat_kv
             key_states = key_states.unsqueeze(2)
             value_states = value_states.unsqueeze(2)
@@ -242,21 +335,22 @@ class DecoderOnlyAttention:
                 self.head_dim,
             )
 
-
+            assert batch_index.dim() == 0
+            assert not output_attentions
+            bidx = batch_index
+            sidx = cache_pos_for_partitions[0][0]
+            attn_output, key_states, value_states = torch.ops.rbln_custom_ops.flash_attn_prefill(
+                query_states,
                 key_states,
                 value_states,
-
-
-
+                attention_mask,
+                past_key_value.key_cache[self.layer_idx].unsqueeze(2),
+                past_key_value.value_cache[self.layer_idx].unsqueeze(2),
+                bidx,
+                sidx,
+                kvcache_partition_size,
             )
 
-            attn_weight = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
-            attn_weight = attn_weight + attention_mask
-
-            # upcast attention to fp32
-            attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.matmul(attn_weight, value_states)
-
             # reshape for removing repeat_kv
             attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
             attn_output = attn_output.transpose(1, 2).contiguous()
@@ -270,6 +364,128 @@ class DecoderOnlyAttention:
         return attn_output, attn_weight, key_states, value_states
 
 
+DECODERONLY_ATTENTION_CLASSES = {
+    "eager": DecoderOnlyAttention,
+    "flash_attn_rbln": DecoderOnlyFlashAttention,
+    # "sdpa": DecoderOnlySdpaAttention,
+}
+
+
+class DecoderOnlyWrapper(torch.nn.Module):
+    def __init__(self, model, max_seq_len, kvcache_partition_len=None):
+        super().__init__()
+        self.config = model.config
+        self.model = model.model
+        self.lm_head = model.lm_head
+        self.max_seq_len = max_seq_len
+        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+        if kvcache_partition_len is not None:
+            # WORKAROUND : for passing partition length as a value to the rbln compiler.
+            # What is actually used is the shape of this tensor.
+            self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+            self.attn_implementation = "flash_attn_rbln"
+            logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
+        else:
+            self.kvcache_partition_size = None
+            self.attn_implementation = "eager"
+
+    def get_forward_dict(self):
+        forward_dict = {
+            "wrapper": DecoderOnlyModel.forward,
+            "model": DecoderOnlyDecoderLayer.forward,
+            "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
+        }
+        return forward_dict
+
+    def forward(
+        self,
+        input_ids_or_inputs_embeds,
+        attention_mask,
+        cache_position,
+        batch_position,
+        query_idx,
+        *past_key_values,
+    ):
+        if input_ids_or_inputs_embeds.ndim == 2:
+            # input_ids
+            input_ids = input_ids_or_inputs_embeds
+            inputs_embeds = None
+        elif input_ids_or_inputs_embeds.ndim == 3:
+            # inputs_embeds
+            input_ids = None
+            inputs_embeds = input_ids_or_inputs_embeds
+        else:
+            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+        # Formatting list of past_kv to DynamicCache class.
+        past_key_values = RebelDynamicCache.from_input_format(
+            cache_position,
+            self.config.num_hidden_layers,
+            *past_key_values,
+        )
+
+        batch_size = input_ids_or_inputs_embeds.size()[0]
+        seq_len = input_ids_or_inputs_embeds.size()[1]
+
+        if self.attn_implementation == "eager":
+            cache_pos_for_partitions = None
+        elif self.attn_implementation == "flash_attn_rbln":
+            p_len = self.kvcache_partition_size.size()[0]
+            num_partition = self.max_seq_len // p_len
+            if self.max_seq_len % p_len > 0:
+                raise ValueError(
+                    f"The partition length({p_len}) must be exactly divisible by the max_seq_len({self.max_seq_len})."
+                )
+            cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
+
+            if batch_size > 1:  # decode
+                for b_idx in range(batch_size):
+                    decoding_step = cache_position[b_idx]
+                    cache_pos = decoding_step
+                    for p_idx in range(num_partition):
+                        input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
+                        input_1 = torch.tensor(p_len, dtype=torch.int32)
+                        min = torch.minimum(input_0, input_1)
+                        cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
+                        cache_pos_for_partitions[b_idx][p_idx] = cache_pos_for_partition
+            else:  # prefill
+                cache_pos = cache_position[0][0]
+                for p_idx in range(num_partition):
+                    input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
+                    input_1 = torch.tensor(p_len, dtype=torch.int32)
+                    min = torch.minimum(input_0, input_1)
+                    cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
+                    cache_pos_for_partitions[0][p_idx] = cache_pos_for_partition
+        else:
+            raise NotImplementedError(f"Unknown attn_implementation: {self.attn_implementation}")
+
+        forward_dict = self.get_forward_dict()
+        outputs = forward_dict["wrapper"](
+            self.model,
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=cache_position,
+            past_key_values=past_key_values,
+            batch_ids=batch_position,
+            rotary_pos_emb=self.rotary_emb,
+            cache_pos_for_partitions=cache_pos_for_partitions,
+            kvcache_partition_size=self.kvcache_partition_size,
+            forward_dict=forward_dict,
+        )
+
+        hidden_states = outputs[0]
+        if seq_len != 1:
+            hidden_states = hidden_states[:, query_idx.to(torch.int).unsqueeze(0)]
+
+        logits = self.lm_head(hidden_states)
+
+        output = (logits,) + outputs[1:]
+
+        return output, batch_position + query_idx
+
+
 class DecoderOnlyDecoderLayer:
     def forward(
         self,
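Note on the per-partition bookkeeping in `DecoderOnlyWrapper.forward` above: for each KV-cache partition it reduces to clamping the current cache position into `[0, p_len]`. A plain-Python restatement follows; the helper name and the numbers are illustrative, and, as in the code above, `max_seq_len` is assumed to be a multiple of the partition length.

```python
def cache_pos_for_partitions(cache_pos: int, max_seq_len: int, p_len: int) -> list:
    # For partition p_idx the write offset is cache_pos - p_idx * p_len, clamped into
    # [0, p_len]: already-full partitions saturate at p_len, untouched ones stay at 0.
    assert max_seq_len % p_len == 0
    return [max(min(cache_pos - p_idx * p_len, p_len), 0) for p_idx in range(max_seq_len // p_len)]


# Example: max_seq_len=8192 split into two 4096-token partitions, decoding step 5000:
print(cache_pos_for_partitions(5000, 8192, 4096))  # [4096, 904]
```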
@@ -280,9 +496,11 @@ class DecoderOnlyDecoderLayer:
         past_key_value: Optional[RebelDynamicCache] = None,
         output_attentions: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.
+        batch_ids: Optional[torch.Tensor] = None,
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        cache_pos_for_partitions: Optional[torch.Tensor] = None,
+        kvcache_partition_size: Optional[torch.Tensor] = None,
         forward_dict: Optional[Dict[str, classmethod]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -301,6 +519,8 @@ class DecoderOnlyDecoderLayer:
             use_cache=use_cache,
             cos=cos,
             sin=sin,
+            cache_pos_for_partitions=cache_pos_for_partitions,
+            kvcache_partition_size=kvcache_partition_size,
             **kwargs,
         )
         past_key_value.assign(k, v, layer_idx)
@@ -331,11 +551,13 @@ class DecoderOnlyModel:
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RebelDynamicCache] = None,
-        batch_ids: Optional[torch.
+        batch_ids: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = True,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        cache_pos_for_partitions: Optional[torch.Tensor] = None,
+        kvcache_partition_size: Optional[torch.Tensor] = None,
         forward_dict: Optional[Dict[str, classmethod]] = None,
         rotary_pos_emb=None,
     ) -> BaseModelOutputWithPast:
@@ -375,6 +597,8 @@ class DecoderOnlyModel:
                 batch_ids=batch_ids,
                 cos=cos,
                 sin=sin,
+                cache_pos_for_partitions=cache_pos_for_partitions,
+                kvcache_partition_size=kvcache_partition_size,
                 forward_dict=forward_dict,
             )
 
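Note on why these signature changes are threaded through every level: the patched `forward` functions are not bound methods. `get_forward_dict` (earlier in this file) collects plain functions from `DecoderOnlyModel`, `DecoderOnlyDecoderLayer`, and the selected attention class, and each level passes `forward_dict` down so the next level can dispatch. A minimal sketch of that style of dispatch, with illustrative class names that are not part of the package:

```python
import torch


class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


# Plain (non-nn.Module) classes only hold replacement forward functions.
class EagerForward:
    def forward(self, x):
        return self.linear(x)


class ScaledForward:
    def forward(self, x):
        return self.linear(x) * 0.5


FORWARD_CLASSES = {"eager": EagerForward, "scaled": ScaledForward}

module = TinyModule()
forward_fn = FORWARD_CLASSES["scaled"].forward  # unbound function, not a bound method
out = forward_fn(module, torch.randn(1, 4))     # the existing module instance becomes `self`
```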
@@ -435,106 +659,40 @@ def apply_rotary_pos_emb(q, k, cos, sin):
 
 
 class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
     def __init__(
         self,
-
-
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
     ):
         super().__init__()
 
-
-
-
-
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
 
-
-
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        position_ids = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        position_ids_expanded = position_ids[:, None]
+
+        if rope_type == "dynamic":
+            freqs = position_ids_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = position_ids_expanded.float() @ inv_freq_expanded.float()
 
-        positions_ids = torch.arange(self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(positions_ids, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
 
-
-
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)
 
     def forward(self, x, seq_len):
         return (
             self._cos_cached[:seq_len].to(dtype=x.dtype),
             self._sin_cached[:seq_len].to(dtype=x.dtype),
         )
-
-
-class LinearScalingRotaryEmbedding(RotaryEmbedding):
-    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        max_seq_len=2048,
-    ):
-        super().__init__(
-            dim,
-            max_position_embeddings=max_position_embeddings,
-            base=base,
-            scaling_factor=scaling_factor,
-        )
-        # difference to the original RoPE: a scaling factor is aplied to the position ids
-        if max_seq_len > max_position_embeddings:
-            positions_ids = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-            positions_ids = positions_ids / self.scaling_factor
-            freqs = torch.outer(positions_ids, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-            self._cos_cached = torch.cat([self._cos_cached, cos[max_position_embeddings:]], dim=0)
-            self._sin_cached = torch.cat([self._sin_cached, sin[max_position_embeddings:]], dim=0)
-
-
-class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        max_seq_len=2048,
-    ):
-        super().__init__(
-            dim,
-            max_position_embeddings=max_position_embeddings,
-            base=base,
-            scaling_factor=scaling_factor,
-        )
-        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
-        device = self.inv_freq.device
-        dtype = self.inv_freq.dtype
-        if max_seq_len > max_position_embeddings:
-            position_ids = torch.arange(max_position_embeddings, max_seq_len, dtype=dtype).view(-1, 1)
-            seq_len = position_ids + 1
-            base = self.base * (
-                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
-            ) ** (self.dim / (self.dim - 2))
-
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-
-            freqs = position_ids * inv_freq
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-            self._cos_cached = torch.cat([self._cos_cached, cos], dim=0)
-            self._sin_cached = torch.cat([self._sin_cached, sin], dim=0)