optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +22 -12
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
- optimum/rbln/diffusers/pipelines/__init__.py +22 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +164 -758
- optimum/rbln/modeling_diffusers.py +51 -122
- optimum/rbln/transformers/__init__.py +0 -2
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
- optimum/rbln/utils/decorator_utils.py +10 -6
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +15 -1
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +1 -1
- optimum/rbln/utils/submodule.py +114 -0
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
+++ b/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -22,20 +22,23 @@
 # from Rebellions Inc.

 import math
-from typing import
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
-from transformers import PretrainedConfig
-from transformers.
-    BaseModelOutputWithPast,
-)
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4

 from ....utils import logging
-from ...cache_utils import RebelDynamicCache
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS


+if is_torch_greater_or_equal_than_2_4:
+    register_fake = torch.library.register_fake
+else:
+    register_fake = torch.library.impl_abstract
+
+
 logger = logging.get_logger(__name__)
 """
 ##############################################################################
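This hunk and the two that follow register "fake" (meta) implementations for the RBLN custom ops through the new `register_fake` alias: `torch.library.register_fake` on torch >= 2.4, `torch.library.impl_abstract` on older releases. A minimal, self-contained sketch of that pattern is shown below; the op name `my_ns::dummy_add` is hypothetical and not part of optimum-rbln.

```python
import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4

# Define a toy custom op so the fake registration below has something to attach to.
torch.library.define("my_ns::dummy_add", "(Tensor x, Tensor y) -> Tensor")

if is_torch_greater_or_equal_than_2_4:
    register_fake = torch.library.register_fake  # torch >= 2.4 spelling
else:
    register_fake = torch.library.impl_abstract  # older torch releases


@register_fake("my_ns::dummy_add")
def dummy_add_fake(x, y):
    # Only shapes and dtypes matter here; this runs during tracing/compilation, not eagerly.
    return torch.empty_like(x)
```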
@@ -83,7 +86,7 @@ def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition)
     return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)


-@
+@register_fake("rbln_custom_ops::flash_attn_decode")
 def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
     return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)

@@ -129,7 +132,7 @@ def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition)
     return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)


-@
+@register_fake("rbln_custom_ops::flash_attn_prefill")
 def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
     return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)

@@ -144,501 +147,599 @@ def rbln_cache_update_cpu(cache, value, batch, seq)
     return updated_cache


-@
+@register_fake("rbln_custom_ops::rbln_cache_update")
 def rbln_cache_update_abstract(cache, value, batch, seq):
     return torch.empty_like(cache)


-class
-
-# reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-key_state = key_state.unsqueeze(2)
-value_state = value_state.unsqueeze(2)
-attn_mask = attn_mask.unsqueeze(2)
+class DecoderOnlyWrapper(nn.Module):
+    """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.

-
-
-
-
-
-self.head_dim,
-)
+    This wrapper is designed to:
+    1. Convert Huggingface decoder models for RBLN compilation with static shapes
+    2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
+    3. Manage different attention implementations (standard and flash attention)
+    4. Support both prefill and decode phases

-
-
-
+    Notes:
+    - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
+    - Wrapper should not contain neural network graph operations (including memory view handling)

-
-
-
-
+    Args:
+        causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
+        max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
+        use_rotary_emb (bool): Whether to use rotary position embeddings
+        kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
+            If provided, uses flash attention; if None, uses standard attention
+    """

-
-
-
+    def __init__(self, causal_lm: PreTrainedModel, max_seq_len, use_rotary_emb: bool, kvcache_partition_len=None):
+        super().__init__()
+        self.config = causal_lm.config

-
+        if use_rotary_emb:
+            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+        else:
+            self.rotary_emb = None

-
-
-
-
-
-
-
-
-cos: Optional[torch.Tensor] = None,
-sin: Optional[torch.Tensor] = None,
-**kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-bsz, q_len, _ = hidden_states.size()
-query_states = self.q_proj(hidden_states)
-key_states = self.k_proj(hidden_states)
-value_states = self.v_proj(hidden_states)
+        if kvcache_partition_len is not None:
+            # WORKAROUND : for passing partition length as a value to the rbln compiler.
+            # What is actually used is the shape of this tensor.
+            self.attn_impl = "flash_attn"
+            logger.info(f"Using flash-attention. (partition length : {kvcache_partition_len})")
+        else:
+            self.attn_impl = "eager"
+        self.kvcache_partition_len = kvcache_partition_len

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-past_key_value,
-batch_idx=b,
-is_prefill=False,
+        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
+
+        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
+        self._phase = "prefill"
+
+    def get_rotary_emb(self, max_seq_len):
+        return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
                 )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = DecoderOnlyModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm

-
-
-
+    @property
+    def phase(self):
+        return self._phase

-
-
-
-
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.causal_lm.phase = phase
+
+    def forward(
+        self,
+        input_ids_or_inputs_embeds,
+        attention_mask,
+        cache_position,
+        batch_position,
+        query_position,
+        *past_key_values,
+    ):
+        if input_ids_or_inputs_embeds.ndim == 2:
+            # It is input_ids
+            input_ids = input_ids_or_inputs_embeds
+            inputs_embeds = None
+        elif input_ids_or_inputs_embeds.ndim == 3:
+            # It is inputs_embeds
+            input_ids = None
+            inputs_embeds = input_ids_or_inputs_embeds
         else:
-
-
-
-
-
-attention_mask,
-past_key_value,
-batch_idx=batch_index,
-is_prefill=True,
+            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {self.num_hidden_layers}"
             )

-
+        seq_len = input_ids_or_inputs_embeds.shape[1]
+        if seq_len == 1:
+            self.phase = "decode"
+        else:
+            self.phase = "prefill"
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        logit, present_key_values = self.causal_lm(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            batch_position=batch_position,
+            query_position=query_position,
+            past_key_values=past_key_values,
+            rotary_emb=self.rotary_emb,
+        )
+
+        # ((key, value)) * n_layer -> [key, value] * n_layer
+        _present_key_values = ()
+        for i in range(self.num_hidden_layers):
+            key_states = present_key_values[i][0]
+            value_states = present_key_values[i][1]
+            _present_key_values = _present_key_values + (key_states, value_states)
+        present_key_values = _present_key_values

-
-
+        # batch_position + query_position is dummy output node to keep the number of outputs
+        return logit, present_key_values, batch_position + query_position

-return attn_output, attn_weight, key_states, value_states

+class DecoderOnlyForCausalLM(nn.Module):
+    """A specialized wrapper for Causal Language Models optimized for RBLN compilation.

-class
-
-
-
-attention_mask: Optional[torch.Tensor] = None,
-position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[RebelDynamicCache] = None,
-batch_index: Optional[torch.Tensor] = None,
-output_attentions: bool = False,
-cos: Optional[torch.Tensor] = None,
-sin: Optional[torch.Tensor] = None,
-cache_pos_for_partitions: Optional[torch.Tensor] = None,
-kvcache_partition_size: Optional[torch.Tensor] = None,
-**kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-bsz, q_len, _ = hidden_states.size()
-query_states = self.q_proj(hidden_states)
-key_states = self.k_proj(hidden_states)
-value_states = self.v_proj(hidden_states)
+    This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
+    1. Managing model phases (prefill/decode) throughout the computation graph
+    2. Handling output shape alignments for static compilation
+    3. Coordinating between the original model and RBLN-optimized components

-
-
-value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-# Decoder (bsz > 1)
-if bsz > 1:
-all_key_states = []
-all_value_states = []
-all_attn_output = []
-
-for b in range(bsz):
-query_state = query_states[b].unsqueeze(0)
-attn_mask = attention_mask[b].unsqueeze(0)
-key_state = key_states[b].unsqueeze(0)
-value_state = value_states[b].unsqueeze(0)
-
-# reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-key_state = key_state.unsqueeze(2)
-value_state = value_state.unsqueeze(2)
-attn_mask = attn_mask.unsqueeze(2)
-
-query_state = query_state.view(
-1,
-self.num_key_value_heads,
-self.num_heads // self.num_key_value_heads,
-q_len,
-self.head_dim,
-)
+    The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
+    focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.

-
-
-
-query_state,
-key_state,
-value_state,
-attn_mask,
-past_key_value.key_cache[self.layer_idx].unsqueeze(2),
-past_key_value.value_cache[self.layer_idx].unsqueeze(2),
-sidx,
-kvcache_partition_size,
-)
+    Args:
+        causal_lm (PreTrainedModel): Original Huggingface causal language model
+        model (DecoderOnlyModel): RBLN-optimized model instance

-
-
-
-
+    Attributes:
+        config: Configuration from the original causal language model
+        _original_mod: Reference to the original model for components like lm_head
+        model: RBLN-optimized decoder model instance
+        _phase: Current processing phase ("prefill" or "decode")
+    """

-
-
-
+    def __init__(self, causal_lm: PreTrainedModel, model):
+        super().__init__()
+        self.config = causal_lm.config
+        self._original_mod = causal_lm
+        self.model = model
+        self._phase = "prefill"

-
-
-
+    @property
+    def phase(self):
+        return self._phase

-
-
-
-
-attention_mask = attention_mask.unsqueeze(2)
-query_states = query_states.view(
-1,
-self.num_key_value_heads,
-self.num_heads // self.num_key_value_heads,
-q_len,
-self.head_dim,
-)
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.model.phase = phase

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_position: torch.Tensor = None,
+        query_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: nn.Module = None,
+    ):
+        # outputs
+        hidden_states, present_key_values = self.model(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
+            rotary_emb=rotary_emb,
+        )

-
-
-attn_output = attn_output.transpose(1, 2).contiguous()
-attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+        if self.phase == "prefill":
+            hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

-
+        logits = self._original_mod.lm_head(hidden_states)
+        output = (logits, present_key_values)
+        return output

-if not output_attentions:
-attn_weight = None

-
+class DecoderOnlyModel(nn.Module):
+    """A modified decoder-only model implementation optimized for RBLN compilation.

+    Args:
+        model: Original Huggingface model to adapt
+        layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN

-
-
-
-
-
+    Attributes:
+        _original_mod: Reference to original Huggingface model
+        layers: ModuleList of RBLN-optimized transformer layers
+        _phase: Current processing phase ("prefill" or "decode")
+    """

+    mask_fmin = torch.finfo(torch.float16).min

-
-def __init__(self, model, max_seq_len, kvcache_partition_len=None):
+    def __init__(self, model, layers: List["DecoderOnlyLayer"]):
         super().__init__()
-self.
-self.
-self.
-self.max_seq_len = max_seq_len
-self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+        self._original_mod = model
+        self.layers = nn.ModuleList(layers)
+        self._phase = "prefill"

-
-
-
-
-
-
-
-
-
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        for layer in self.layers:
+            layer.phase = phase

-
-
-
-
-
-
-
+    @property
+    def hidden_multiplier(self):
+        return 1
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.norm
+
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.embed_tokens
+
+    def get_pos_embedding(self) -> nn.Embedding:
+        raise NotImplementedError(
+            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
+        )

     def forward(
         self,
-
-
-
-
-
-
+        input_ids: torch.Tensor = None,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: nn.Module = None,
     ):
-
-
-
-
-
-# inputs_embeds
-input_ids = None
-inputs_embeds = input_ids_or_inputs_embeds
-else:
-raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )

-#
-
-
-self.config.num_hidden_layers,
-*past_key_values,
-)
+        # embed positions
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embedding()(input_ids)

-
-
-
-if self.attn_implementation == "eager":
-cache_pos_for_partitions = None
-elif self.attn_implementation == "flash_attn_rbln":
-p_len = self.kvcache_partition_size.size()[0]
-num_partition = self.max_seq_len // p_len
-if self.max_seq_len % p_len > 0:
-raise ValueError(
-f"The partition length({p_len}) must be exactly divisible by the max_seq_len({self.max_seq_len})."
-)
-cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
+        hidden_states = inputs_embeds * self.hidden_multiplier
+        attention_mask = (1 - attention_mask) * self.mask_fmin

-
-
-
-
-for p_idx in range(num_partition):
-input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
-input_1 = torch.tensor(p_len, dtype=torch.int32)
-min = torch.minimum(input_0, input_1)
-cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
-cache_pos_for_partitions[b_idx][p_idx] = cache_pos_for_partition
-else: # prefill
-cache_pos = cache_position[0][0]
-for p_idx in range(num_partition):
-input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
-input_1 = torch.tensor(p_len, dtype=torch.int32)
-min = torch.minimum(input_0, input_1)
-cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
-cache_pos_for_partitions[0][p_idx] = cache_pos_for_partition
+        # get cos,sin vector if needed
+        if rotary_emb is not None:
+            cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1])  # dtype carrier, max_seq_len
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
         else:
-
+            batch_size = inputs_embeds.shape[0]
+            if cache_position.shape[0] > 1:
+                position_embeds = []
+                for b_idx in range(batch_size):
+                    position_embed = self.get_pos_embedding()(cache_position[b_idx])
+                    position_embeds.append(position_embed)
+
+                position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+            else:
+                position_embeds = self.get_pos_embedding()(cache_position)
+            hidden_states = hidden_states + position_embeds
+            cos, sin = None, None
+
+        # (batch, seq_len) -> (batch,)
+        current_steps = cache_position[:, 0]
+
+        present_key_values = past_key_values
+        for layer in self.layers:
+            hidden_states, present_key_values = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                current_steps=current_steps,
+                batch_position=batch_position,
+                past_key_values=present_key_values,
+                cos=cos,
+                sin=sin,
+            )

-
-
-
-
-
-
-
-
-
-
-
-
-forward_dict=forward_dict,
-)
+        hidden_states = self.get_last_layernorm()(hidden_states)
+        return hidden_states, present_key_values
+
+
+class DecoderOnlyLayer(nn.Module):
+    """A single transformer layer adapted for RBLN compilation with static shapes.
+
+    This layer implements a modified transformer block that includes:
+    1. Self-attention mechanism (either standard or flash attention)
+    2. Feed-forward network (FFN)
+    3. Layer normalization
+    4. Residual connections

-
-
-
+    The layer is specifically designed to:
+    - Support compilation to RBLN custom ops
+    - Maintain static tensor shapes throughout computations
+    - Handle both prefill and decode phases efficiently
+    - Manage attention state transitions properly

-
+    Args:
+        layer: Original transformer layer module to wrap
+        self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN

-
+    Attributes:
+        _original_mod: Reference to original layer for accessing components
+        self_attn: Modified attention mechanism mapped to RBLN ops at compile time
+        phase: Current operation phase ("prefill" or "decode")
+    """
+
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+        super().__init__()
+        self._original_mod = layer
+        self.self_attn = self_attn
+        self._phase = "prefill"
+
+    @property
+    def phase(self):
+        return self._phase

-
+    @phase.setter
+    def phase(self, phase: str):
+        self._phase = phase
+        self.self_attn.phase = phase

+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.input_layernorm
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.post_attention_layernorm

-class DecoderOnlyDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-
-
-
-
-output_attentions: Optional[bool] = None,
-use_cache: Optional[bool] = None,
-batch_ids: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-
-kvcache_partition_size: Optional[torch.Tensor] = None,
-forward_dict: Optional[Dict[str, classmethod]] = None,
-**kwargs,
-) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ):
         residual = hidden_states

-hidden_states = self.
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)

-hidden_states,
-self.self_attn,
+        hidden_states, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
-
-batch_index=batch_ids,
-use_cache=use_cache,
+            current_steps=current_steps,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
             cos=cos,
             sin=sin,
-cache_pos_for_partitions=cache_pos_for_partitions,
-kvcache_partition_size=kvcache_partition_size,
-**kwargs,
         )
-past_key_value.assign(k, v, layer_idx)
-
         hidden_states = residual + hidden_states

         # Fully Connected
         residual = hidden_states
-hidden_states = self.
-hidden_states = self.mlp(hidden_states)
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = self._original_mod.mlp(hidden_states)
         hidden_states = residual + hidden_states

-
+        return hidden_states, present_key_values

-if output_attentions:
-outputs += (self_attn_weight,)

-
-
+class DecoderOnlyAttention(nn.Module):
+    """Attention implementation for decoder-only models optimized for RBLN compilation.

-
+    This class implements a modified version of the standard attention mechanism that:
+    1. Supports static shape requirements for RBLN compilation
+    2. Handles explicit batch and position management

+    Args:
+        self_attn: Original attention module from the base model
+    """

-
-
+    def __init__(self, self_attn):
+        super().__init__()
+        self._original_mod = self_attn
+        self.layer_idx = self_attn.layer_idx
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.head_dim
+        self.phase = "prefill"
+        self.__post_init__()
+
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.o_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Projects input hidden states into query, key, and value representations.
+
+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+
+        Returns:
+            Tuple of (query_states, key_states, value_states)
+        """
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
+
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    def rbln_attention(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        query_state,
+        key_state,
+        value_state,
+        attn_mask,
+        batch_idx,
+        past_key_state,
+        past_value_state,
+        current_step,
+        # below are designed for Midm, GPT which requires to support scaling for attention weights
+        # TODO(jongho): Merge and manage scales generally
+        layer_idx=None,
+        scale_attn_weights: bool = None,
+        scale_attn_by_inverse_layer_idx: bool = None,
+        scale_qk_by_inverse_layer_idx: bool = None,
+    ):
+        """Compute attention with static shapes and explicit cache management.
+
+        Args:
+            query_state: Query tensor [1, num_heads, 1, head_dim]
+            key_state: Key tensor [1, num_heads, seq_len, head_dim]
+            value_state: Value tensor [1, num_heads, seq_len, head_dim]
+            attn_mask: Attention mask tensor
+            batch_idx: Batch index for cache lookup
+            past_key_state: Previous key cache states
+            past_value_state: Previous value cache states
+            current_step: Current position in sequence
+
+        Returns:
+            Tuple of (attention_output, key_state, value_state)
+        """
+        # Implementation details.
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-kvcache_partition_size=kvcache_partition_size,
-forward_dict=forward_dict,
-)
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )  #
+
+        kend = current_step + key_state.shape[-2]
+        vend = current_step + value_state.shape[-2]
+
+        key_state = (
+            past_key_state[batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(key_state, dim=-2, start=current_step, end=kend)
+        )
+        value_state = (
+            past_value_state[batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(value_state, dim=-2, start=current_step, end=vend)
+        )
+
+        attn_weight = torch.matmul(query_state, key_state.transpose(3, 4))
+        attn_weight = attn_weight / math.sqrt(self.head_dim)
+
+        if layer_idx is not None and (scale_attn_by_inverse_layer_idx or scale_qk_by_inverse_layer_idx):
+            attn_weight = attn_weight / float(layer_idx + 1)
+
+        attn_weight += attn_mask

-
+        if layer_idx is not None and scale_qk_by_inverse_layer_idx:
+            attn_weight = attn_weight * float(layer_idx + 1)

-
+        attn_weight = nn.functional.softmax(attn_weight, dim=-1)

-
-all_self_attns += (layer_outputs[1],)
+        attn_output = torch.matmul(attn_weight, value_state)

-
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)

-
-if output_hidden_states:
-all_hidden_states += (hidden_states,)
+        return attn_output, key_state, value_state

-
-
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,  # (batch, 1, prefill_size, head_dim)
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()

-
-
-
-
-
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
         )
+        # b, num_head, query, head_dim
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        if batch_size > 1 and self.phase == "prefill":
+            raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            current_step = current_steps[b]
+            attn_output, key_state, value_state = self.rbln_attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0)
+                if self.phase == "decode"
+                else attention_mask,  # TODO(jongho): fix when msoftmax is supported
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_idx=b if self.phase == "decode" else batch_position,
+                current_step=current_step,
+            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)

+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values

-
-
-
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
         cos_all = []
         sin_all = []
-for i in range(
-cos_all.append(cos[
-sin_all.append(sin[
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
         cos = torch.cat(cos_all, dim=0)
         sin = torch.cat(sin_all, dim=0)
     else:
-cos = cos[
-sin = sin[
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)

     return cos, sin

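As a side note on the hunk above: `DecoderOnlyWrapper.forward` receives the per-layer KV caches flattened into one positional list (`[key_0, value_0, key_1, value_1, ...]`) and regroups them before calling the causal LM, then flattens the present caches again on the way out. A small illustrative sketch of that convention follows; the helper names are ours, not part of the package.

```python
from typing import List, Tuple

import torch


def group_kv(flat: Tuple[torch.Tensor, ...], n_layer: int) -> List[List[torch.Tensor]]:
    # [key_0, value_0, key_1, value_1, ...] -> [[key_0, value_0], [key_1, value_1], ...]
    if len(flat) != 2 * n_layer:
        raise ValueError(f"expected {2 * n_layer} cache tensors, got {len(flat)}")
    return [[flat[2 * i], flat[2 * i + 1]] for i in range(n_layer)]


def flatten_kv(per_layer: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, ...]:
    # Inverse direction, as done when returning present_key_values.
    out: Tuple[torch.Tensor, ...] = ()
    for key, value in per_layer:
        out = out + (key, value)
    return out


# cache shape used in the wrapper: (batch, n_heads, 1, max_seq_len, head_dim)
caches = tuple(torch.zeros(1, 8, 1, 16, 64) for _ in range(2 * 3))  # 3 layers
assert all(a is b for a, b in zip(flatten_kv(group_kv(caches, n_layer=3)), caches))
```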
@@ -658,6 +759,26 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed


+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states
+
+
 class RotaryEmbedding(nn.Module):
     """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

@@ -674,14 +795,14 @@ class RotaryEmbedding(nn.Module):
             rope_type = "default"

         inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-
-
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]

         if rope_type == "dynamic":
-freqs =
+            freqs = cache_position_expanded.float() * inv_freq.float()
         else:
             inv_freq_expanded = inv_freq[None, :]
-freqs =
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()

         emb = torch.cat((freqs, freqs), dim=-1)

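The `RotaryEmbedding.__init__` change above builds the cos/sin tables from a `cache_position` index vector. A tiny standalone illustration of that construction, with arbitrary sizes chosen here for demonstration only:

```python
import torch

max_seq_len_cached, head_dim = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
cache_position_expanded = cache_position[:, None]    # (seq_len, 1)
freqs = cache_position_expanded @ inv_freq[None, :]  # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)              # (seq_len, head_dim)
cos_cached, sin_cached = emb.cos(), emb.sin()
print(cos_cached.shape)  # torch.Size([8, 4])
```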
@@ -696,3 +817,142 @@ class RotaryEmbedding(nn.Module):
             self._cos_cached[:seq_len].to(dtype=x.dtype),
             self._sin_cached[:seq_len].to(dtype=x.dtype),
         )
+
+
+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, self_attn, kvcache_partition_len):
+        super().__init__(self_attn=self_attn)
+        self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+
+    def get_cache_pos_for_partitions(self, current_steps, batch_size, max_seq_len):
+        partition_len = self.kvcache_partition_size.size()[0]
+        num_partition = max_seq_len // partition_len
+        cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
+        if self.phase == "decode":
+            for b_idx in range(batch_size):
+                cache_pos = current_steps[b_idx]
+                for p_idx in range(num_partition):
+                    cache_pos_for_partitions[b_idx][p_idx] = torch.clamp(
+                        cache_pos - partition_len * p_idx, 0, partition_len
+                    )
+        else:  # prefill
+            cache_pos = current_steps[0]
+            for p_idx in range(num_partition):
+                cache_pos_for_partitions[0][p_idx] = torch.clamp(cache_pos - partition_len * p_idx, 0, partition_len)
+
+        return cache_pos_for_partitions
+
+    def rbln_flash_attention(
+        self,
+        query_state,
+        key_state,
+        value_state,
+        attn_mask,
+        batch_idx,
+        past_key_state,
+        past_value_state,
+        cache_pos_for_partitions,
+    ):
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        # RBLN custom flash attention(decode), dummy batch index
+        if self.phase == "decode":
+            sidx = cache_pos_for_partitions[batch_idx][0]
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                sidx,
+                self.kvcache_partition_size,
+            )
+        else:
+            sidx = cache_pos_for_partitions[0][0]
+            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
+                query_state,
+                key_state,
+                value_state,
+                attn_mask,
+                past_key_state.unsqueeze(2),
+                past_value_state.unsqueeze(2),
+                batch_idx,
+                sidx,
+                self.kvcache_partition_size,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+        return attn_output, key_state, value_state
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
+        )
+        # b, num_head, query, head_dim
+
+        max_seq_len = past_key_values[self.layer_idx][0].shape[-2]
+
+        if cos is not None and sin is not None:
+            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+        cache_pos_for_partitions = self.get_cache_pos_for_partitions(
+            current_steps, batch_size=batch_size, max_seq_len=max_seq_len
+        )  # batch_size, num_partitions
+
+        _key_states = []
+        _value_states = []
+        _attn_outputs = []
+        for b in range(batch_size):
+            attn_output, key_state, value_state = self.rbln_flash_attention(
+                query_states[b].unsqueeze(0),
+                key_states[b].unsqueeze(0),
+                value_states[b].unsqueeze(0),
+                attention_mask[b].unsqueeze(0)
+                if self.phase == "decode"
+                else attention_mask,  # TODO(jongho): fix when msoftmax is supported
+                past_key_state=past_key_values[self.layer_idx][0],
+                past_value_state=past_key_values[self.layer_idx][1],
+                batch_idx=b if self.phase == "decode" else batch_position,
+                cache_pos_for_partitions=cache_pos_for_partitions,
+            )
+            _key_states.append(key_state)
+            _value_states.append(value_state)
+            _attn_outputs.append(attn_output)
+        key_states = torch.cat(_key_states, dim=0)
+        value_states = torch.cat(_value_states, dim=0)
+        attn_outputs = torch.cat(_attn_outputs, dim=0)
+
+        attn_outputs = self.o_proj(attn_outputs)
+        past_key_values[self.layer_idx] = key_states, value_states
+        return attn_outputs, past_key_values
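To make `get_cache_pos_for_partitions` in the hunk above concrete, here is a tiny worked example with hypothetical numbers showing how the per-partition cache position is clamped to the range [0, partition_len]:

```python
import torch

partition_len, num_partition = 4, 4  # i.e. max_seq_len = 16
cache_pos = torch.tensor(6)          # six tokens already written to the KV cache

cache_pos_for_partitions = torch.tensor(
    [int(torch.clamp(cache_pos - partition_len * p_idx, 0, partition_len)) for p_idx in range(num_partition)],
    dtype=torch.int32,
)
print(cache_pos_for_partitions)  # tensor([4, 2, 0, 0], dtype=torch.int32)
# Partition 0 is full, partition 1 holds two entries, partitions 2 and 3 are still empty.
```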