optimum-rbln 0.7.2rc2__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +8 -0
- optimum/rbln/diffusers/modeling_diffusers.py +103 -117
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
- optimum/rbln/diffusers/pipelines/__init__.py +8 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
- optimum/rbln/modeling.py +4 -1
- optimum/rbln/modeling_base.py +16 -3
- optimum/rbln/ops/__init__.py +6 -2
- optimum/rbln/ops/attn.py +94 -85
- optimum/rbln/ops/flash_attn.py +46 -25
- optimum/rbln/ops/kv_cache_update.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
- optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
- optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (+264 -133):

```diff
@@ -19,7 +19,12 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

-from ....ops import
+from ....ops import (
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+    register_rbln_custom_paged_flash_attention,
+    register_rbln_custom_paged_flash_causal_attention,
+)
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS

@@ -34,7 +39,7 @@ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768


 def validate_attention_method(
-    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
+    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_kvcache_block_size: int, rbln_max_seq_len: int
 ) -> Tuple[str, int]:
     if rbln_kvcache_partition_len is not None:
         if rbln_attn_impl == "eager":
@@ -93,7 +98,19 @@ def validate_attention_method(
                 "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
             )

-
+    if rbln_kvcache_block_size is not None:
+        if rbln_attn_impl == "flash_attn" and rbln_kvcache_partition_len != rbln_kvcache_block_size:
+            raise ValueError(
+                f" When using 'flash attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
+                f"must always be set equal to the `rbln_kvcache_partition_len` {rbln_kvcache_partition_len}."
+            )
+        elif rbln_attn_impl == "eager" and rbln_kvcache_block_size != rbln_max_seq_len:
+            raise ValueError(
+                f" When using 'eager attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
+                f"must always be set equal to the `rbln_max_seq_len` {rbln_max_seq_len}."
+            )
+
+    return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size


 class DecoderOnlyWrapper(nn.Module):
```
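The block-size checks added to `validate_attention_method` tie `rbln_kvcache_block_size` to the chosen attention implementation: with `flash_attn` it must equal `rbln_kvcache_partition_len`, and with `eager` it must equal `rbln_max_seq_len`. A minimal standalone sketch of that rule (an illustrative helper, not the function shipped in the wheel):

```python
# Illustrative restatement of the block-size rule above; not the packaged function.
def check_kvcache_block_size(attn_impl, partition_len, block_size, max_seq_len):
    if block_size is None:
        return
    if attn_impl == "flash_attn" and block_size != partition_len:
        raise ValueError("flash_attn requires rbln_kvcache_block_size == rbln_kvcache_partition_len")
    if attn_impl == "eager" and block_size != max_seq_len:
        raise ValueError("eager requires rbln_kvcache_block_size == rbln_max_seq_len")


check_kvcache_block_size("flash_attn", partition_len=16_384, block_size=16_384, max_seq_len=32_768)  # passes
check_kvcache_block_size("eager", partition_len=None, block_size=4_096, max_seq_len=4_096)  # passes
```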
```diff
@@ -102,7 +119,7 @@ class DecoderOnlyWrapper(nn.Module):
     This wrapper is designed to:
     1. Convert Huggingface decoder models for RBLN compilation with static shapes
     2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
-    3. Manage different attention implementations (standard
+    3. Manage different attention implementations (standard/flash attention)
     4. Support both prefill and decode phases

     Notes:
@@ -128,7 +145,9 @@ class DecoderOnlyWrapper(nn.Module):
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
+        use_attention_mask: bool,
         kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
     ):
         super().__init__()
         self.config = causal_lm.config
@@ -139,12 +158,20 @@ class DecoderOnlyWrapper(nn.Module):
             self.rotary_emb = None

         self.attn_impl = attn_impl
+        self.kvcache_block_size = kvcache_block_size
+        self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
+            if self.use_attention_mask:
+                register_rbln_custom_paged_flash_attention()
+            else:
+                register_rbln_custom_paged_flash_causal_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
-
+            if self.use_attention_mask:
+                register_rbln_custom_paged_attention()
+            else:
+                register_rbln_custom_paged_causal_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")

```
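`DecoderOnlyWrapper.__init__` now also takes `use_attention_mask` and `kvcache_block_size`, and registers the matching paged custom op (masked vs. causal, eager vs. flash) at construction time. A hedged usage sketch, requiring optimum-rbln to be installed; the keyword names come from the signature above, while the model choice, the assumption that the wrapped model is the first positional argument, and the concrete values are illustrative:

```python
# Hypothetical wiring of the new constructor arguments; values are examples only.
from transformers import AutoModelForCausalLM

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper

causal_lm = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

wrapper = DecoderOnlyWrapper(
    causal_lm,                      # assumed to be the first positional argument
    max_seq_len=32_768,
    use_rotary_emb=True,
    attn_impl="flash_attn",
    use_attention_mask=False,       # picks register_rbln_custom_paged_flash_causal_attention()
    kvcache_partition_len=16_384,
    kvcache_block_size=16_384,      # flash_attn requires block size == partition length
)
```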
```diff
@@ -154,7 +181,7 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )

-        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
+        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)

         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
@@ -162,21 +189,32 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

-    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
+    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_attention_mask=self.use_attention_mask,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")

             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = DecoderOnlyModel(
+        new_model = DecoderOnlyModel(
+            causal_lm.model,
+            new_layers,
+            partition_len=self.kvcache_partition_len,
+            max_seq_len=max_seq_len,
+            kvcache_block_size=self.kvcache_block_size,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm

@@ -191,23 +229,43 @@ class DecoderOnlyWrapper(nn.Module):

     def forward(self, *args):
         if self.phase == "decode":
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
             query_position = None
         elif self.phase == "prefill":
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    query_position,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    query_position,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
+

         else:
             raise ValueError(f"Unknown phase: {self.phase}")
```
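The compiled graph receives a flat argument tuple, and the unpacking above defines its order for each phase; `block_tables` now takes the slot the old `batch_position` occupied, and the KV-cache tensors trail the call. A hedged sketch of assembling prefill inputs without an attention mask (the ordering mirrors the unpacking; shapes, dtypes, and sizes are assumptions):

```python
import torch

# Illustrative flat-argument layout for the prefill phase with use_attention_mask=False.
# Shapes and dtypes are assumptions for the sketch, not the compiled model's contract.
num_layers, num_kv_heads, max_seq_len, head_dim = 22, 4, 8192, 64

input_ids = torch.zeros(1, 128, dtype=torch.int64)                 # one prefill chunk
cache_position = torch.arange(128, dtype=torch.int32).reshape(1, 128)
query_position = torch.tensor(127, dtype=torch.int16)              # last valid token in the chunk
block_tables = torch.tensor([0], dtype=torch.int16)                # cache block(s) owned by this sequence
past_key_values = [
    torch.zeros(1, num_kv_heads, max_seq_len, head_dim) for _ in range(2 * num_layers)
]

args = (input_ids, cache_position, query_position, block_tables, *past_key_values)
# wrapper.phase = "prefill"
# logits = wrapper(*args)
```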
```diff
@@ -235,26 +293,18 @@ class DecoderOnlyWrapper(nn.Module):
             _past_key_values.append(past_key_value)
         past_key_values = _past_key_values

-        logit
+        logit = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             query_position=query_position,
             past_key_values=past_key_values,
             rotary_emb=self.rotary_emb,
+            block_tables=block_tables,
         )

-
-        _present_key_values = ()
-        for i in range(self.num_hidden_layers):
-            key_states = present_key_values[i][0]
-            value_states = present_key_values[i][1]
-            _present_key_values = _present_key_values + (key_states, value_states)
-        present_key_values = _present_key_values
-
-        return logit, present_key_values
+        return logit


 class DecoderOnlyForCausalLM(nn.Module):
@@ -301,28 +351,27 @@ class DecoderOnlyForCausalLM(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         query_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # outputs
-        hidden_states
+        hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
+            block_tables=block_tables,
         )

         if self.phase == "prefill":
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

         logits = self._original_mod.lm_head(hidden_states)
-
-        return output
+        return logits


 class DecoderOnlyModel(nn.Module):
@@ -338,12 +387,16 @@ class DecoderOnlyModel(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """

-    def __init__(
+    def __init__(
+        self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None, kvcache_block_size=None
+    ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self._phase = "prefill"
         self.partition_len = partition_len
+        self.kvcache_block_size = kvcache_block_size
+        self.max_seq_len = max_seq_len

     @property
     def phase(self):
@@ -364,9 +417,8 @@ class DecoderOnlyModel(nn.Module):
         return 1

     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
-        if self.attn_impl
+        if self.attn_impl not in ["flash_attn"]:
             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
-
         partition_len = self.partition_len
         num_partition = max_seq_len // partition_len

@@ -392,9 +444,9 @@ class DecoderOnlyModel(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -410,7 +462,7 @@ class DecoderOnlyModel(nn.Module):

         # get cos,sin vector if needed
         if rotary_emb is not None:
-            cos, sin = rotary_emb(hidden_states,
+            cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
         else:
             batch_size = inputs_embeds.shape[0]
@@ -429,27 +481,25 @@ class DecoderOnlyModel(nn.Module):
         # (batch, seq_len) -> (batch,)
         if self.attn_impl == "flash_attn":
             seq_positions = cache_position[:, 0]
-            max_seq_len = past_key_values[0][0].shape[-2]
             seq_positions = self.convert_sequence_positions_for_flash_attn(
-                seq_positions=seq_positions, max_seq_len=max_seq_len
+                seq_positions=seq_positions, max_seq_len=self.max_seq_len
             )
         else:
             seq_positions = cache_position[:, :1]

-        present_key_values = past_key_values
         for layer in self.layers:
-            hidden_states
+            hidden_states = layer(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 seq_positions=seq_positions,
-
-                past_key_values=present_key_values,
+                past_key_values=past_key_values,
                 cos=cos,
                 sin=sin,
+                block_tables=block_tables,
             )

         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        return hidden_states


 class DecoderOnlyLayer(nn.Module):
@@ -503,22 +553,22 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)

-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states

@@ -528,7 +578,7 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states = self._original_mod.mlp(hidden_states)
         hidden_states = residual + hidden_states

-        return hidden_states
+        return hidden_states


 class DecoderOnlyAttention(nn.Module):
@@ -542,7 +592,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """

-    def __init__(self, self_attn):
+    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -560,7 +610,9 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads

+        self.use_attention_mask = use_attention_mask
         self.attention = self.get_attention()
+        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()

     @property
@@ -573,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase

     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)
+        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)

     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -606,10 +658,10 @@ class DecoderOnlyAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()

@@ -628,32 +680,31 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and self.phase == "prefill":
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

-        attn_output
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state

         attn_outputs = self.o_proj(attn_output)
-
-        return attn_outputs, past_key_values
+        return attn_outputs


 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
+        self.use_attention_mask = use_attention_mask

     def forward(
         self,
@@ -661,11 +712,12 @@ class AttentionOp(nn.Module):
         key_state: torch.Tensor,
         value_state: torch.Tensor,
         attn_mask: torch.Tensor,
-        batch_position: torch.Tensor,
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: torch.Tensor,
         scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.

@@ -674,19 +726,19 @@ class AttentionOp(nn.Module):
             key_state: Key tensor [1, num_heads, seq_len, head_dim]
             value_state: Value tensor [1, num_heads, seq_len, head_dim]
             attn_mask: Attention mask tensor ∈ {0, 1}
-            batch_position: Batch index for cache lookup
             past_key_state: Previous key cache states
             past_value_state: Previous value cache states
             seq_position: Current position in sequence
             scale: Scale applied to attn weights

         Returns:
-
+            Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
         """
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -702,35 +754,64 @@ class AttentionOp(nn.Module):
         )

         if self.phase == "decode":
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )

         else:
-
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )

         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

-        return attn_output
+        return attn_output


 def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
```
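In `AttentionOp`, the per-batch `batch_position` is gone; the paged ops address the KV cache through `block_tables` plus a `block_size` instead. A toy illustration of that mapping (the exact table layout is an assumption inferred from the arguments above, not a documented contract): with eager attention `block_size == max_seq_len`, so each sequence owns one block, while flash attention splits the same budget into `max_seq_len // block_size` blocks per sequence.

```python
import torch

# Hypothetical block-table layout; treat the shapes as assumptions for illustration.
max_seq_len, batch_size = 8192, 3

# eager: block_size == max_seq_len -> one cache block per sequence
eager_block_tables = torch.arange(batch_size, dtype=torch.int16)   # tensor([0, 1, 2])

# flash_attn: block_size == partition_len (here 4096) -> 2 blocks per sequence
block_size = 4096
blocks_per_seq = max_seq_len // block_size
flash_block_tables = torch.arange(batch_size * blocks_per_seq, dtype=torch.int16).reshape(
    batch_size, blocks_per_seq
)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]], dtype=torch.int16)
```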
```diff
@@ -826,22 +907,30 @@ class RotaryEmbedding(nn.Module):


 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
         self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(
+        super().__init__(
+            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+        )

     def get_attention(self):
-        return FlashAttentionOp(
+        return FlashAttentionOp(
+            self.num_heads,
+            self.head_dim,
+            self.num_key_value_heads,
+            self.kvcache_partition_size,
+            self.use_attention_mask,
+        )

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()

@@ -857,29 +946,38 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
         if cos is not None and sin is not None:
             query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)

-        attn_output
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            kvcache_block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state

         attn_outputs = self.o_proj(attn_output)
-
-
-        return attn_outputs, past_key_values
+        return attn_outputs


 class FlashAttentionOp(AttentionOp):
-    def __init__(
-
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        kvcache_partition_len: int,
+        use_attention_mask: bool,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            num_key_value_heads=num_key_value_heads,
+            use_attention_mask=use_attention_mask,
+        )
         self.kvcache_partition_size = kvcache_partition_len

     def forward(
@@ -888,16 +986,18 @@ class FlashAttentionOp(AttentionOp):
         key_state,
         value_state,
         attn_mask,
-        batch_position,
         past_key_state,
         past_value_state,
         seq_position,
         scale,
+        block_tables,
+        kvcache_block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -913,34 +1013,65 @@ class FlashAttentionOp(AttentionOp):
         )

         if self.phase == "decode":
-
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )
         else:
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )

         # reshape for removing repeat_kv
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

-        return attn_output
+        return attn_output
```