optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/ops/__init__.py +4 -4
- optimum/rbln/ops/attn.py +44 -84
- optimum/rbln/ops/flash_attn.py +25 -25
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +79 -51
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +157 -34
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +7 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +7 -2
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +3 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -3
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +44 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +50 -19
- optimum/rbln/transformers/models/t5/modeling_t5.py +211 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +69 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +19 -24
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/RECORD +22 -22
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -20,10 +20,10 @@ from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

 from ....ops import (
-
-
-
-
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+    register_rbln_custom_paged_flash_attention,
+    register_rbln_custom_paged_flash_causal_attention,
 )
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS

@@ -39,7 +39,7 @@ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768


 def validate_attention_method(
-    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
+    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_kvcache_block_size: int, rbln_max_seq_len: int
 ) -> Tuple[str, int]:
     if rbln_kvcache_partition_len is not None:
         if rbln_attn_impl == "eager":

@@ -98,7 +98,7 @@ def validate_attention_method(
             "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
         )

-    return rbln_attn_impl, rbln_kvcache_partition_len
+    return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size


 class DecoderOnlyWrapper(nn.Module):
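For orientation, a minimal, hypothetical sketch of calling the updated `validate_attention_method` follows; the keyword names come from the new signature above, but the argument values are illustrative only, and the actual validation rules live in the parts of the function body not shown in this diff.

```python
# Hypothetical usage sketch (values are illustrative, not requirements).
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    validate_attention_method,
)

# The function now returns the KV-cache block size alongside the attention
# implementation and the partition length.
attn_impl, kvcache_partition_len, kvcache_block_size = validate_attention_method(
    rbln_attn_impl="flash_attn",
    rbln_kvcache_partition_len=16_384,
    rbln_kvcache_block_size=16_384,
    rbln_max_seq_len=32_768,
)
```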
@@ -107,7 +107,7 @@ class DecoderOnlyWrapper(nn.Module):
     This wrapper is designed to:
     1. Convert Huggingface decoder models for RBLN compilation with static shapes
     2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
-    3. Manage different attention implementations (standard
+    3. Manage different attention implementations (standard/flash attention)
     4. Support both prefill and decode phases

     Notes:

@@ -135,6 +135,7 @@ class DecoderOnlyWrapper(nn.Module):
         attn_impl: str,
         use_attention_mask: bool,
         kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
     ):
         super().__init__()
         self.config = causal_lm.config

@@ -145,19 +146,20 @@ class DecoderOnlyWrapper(nn.Module):
             self.rotary_emb = None

         self.attn_impl = attn_impl
+        self.kvcache_block_size = kvcache_block_size
         self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
             if self.use_attention_mask:
-
+                register_rbln_custom_paged_flash_attention()
             else:
-
+                register_rbln_custom_paged_flash_causal_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
             if self.use_attention_mask:
-
+                register_rbln_custom_paged_attention()
             else:
-
+                register_rbln_custom_paged_causal_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")

@@ -179,11 +181,14 @@ class DecoderOnlyWrapper(nn.Module):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
                     layer.self_attn,
                     kvcache_partition_len=self.kvcache_partition_len,
+                    kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
                 )
             else:

@@ -192,7 +197,11 @@ class DecoderOnlyWrapper(nn.Module):
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = DecoderOnlyModel(
-            causal_lm.model,
+            causal_lm.model,
+            new_layers,
+            partition_len=self.kvcache_partition_len,
+            max_seq_len=max_seq_len,
+            kvcache_block_size=self.kvcache_block_size,
         )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm

@@ -213,16 +222,17 @@ class DecoderOnlyWrapper(nn.Module):
                     input_ids_or_inputs_embeds,
                     cache_position,
                     attention_mask,
+                    block_tables,
                     *past_key_values,
                 ) = args
             else:
                 (
                     input_ids_or_inputs_embeds,
                     cache_position,
+                    block_tables,
                     *past_key_values,
                 ) = args
                 attention_mask = None
-                batch_position = torch.tensor(0, dtype=torch.int16)
             query_position = None
         elif self.phase == "prefill":
             if self.use_attention_mask:
@@ -230,16 +240,16 @@ class DecoderOnlyWrapper(nn.Module):
                     input_ids_or_inputs_embeds,
                     cache_position,
                     attention_mask,
-                    batch_position,
                     query_position,
+                    block_tables,
                     *past_key_values,
                 ) = args
             else:
                 (
                     input_ids_or_inputs_embeds,
                     cache_position,
-                    batch_position,
                     query_position,
+                    block_tables,
                     *past_key_values,
                 ) = args
                 attention_mask = None
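Taken together, these two hunks replace the scalar `batch_position` input with a `block_tables` tensor in both the decode and prefill argument layouts. The sketch below shows one way such a tensor could be laid out, purely for illustration; the dtype, shape, and block-to-sequence mapping actually expected by the RBLN runtime are not specified in this diff.

```python
import torch

# Illustrative only: a block table maps each sequence's logical KV-cache blocks
# to physical block indices. Shapes and dtype here are assumptions for the sketch.
batch_size = 2
blocks_per_seq = 4  # e.g. something like max_seq_len // kvcache_block_size

block_tables = torch.arange(batch_size * blocks_per_seq, dtype=torch.int16).reshape(
    batch_size, blocks_per_seq
)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]], dtype=torch.int16)
```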
@@ -276,10 +286,10 @@ class DecoderOnlyWrapper(nn.Module):
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             query_position=query_position,
             past_key_values=past_key_values,
             rotary_emb=self.rotary_emb,
+            block_tables=block_tables,
         )

         # ((key, value)) * n_layer -> [key, value] * n_layer

@@ -337,10 +347,10 @@ class DecoderOnlyForCausalLM(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         query_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # outputs
         hidden_states, present_key_values = self.model(

@@ -348,9 +358,9 @@ class DecoderOnlyForCausalLM(nn.Module):
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
+            block_tables=block_tables,
         )

         if self.phase == "prefill":

@@ -374,12 +384,15 @@ class DecoderOnlyModel(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """

-    def __init__(
+    def __init__(
+        self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None, kvcache_block_size=None
+    ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self._phase = "prefill"
         self.partition_len = partition_len
+        self.kvcache_block_size = kvcache_block_size
         self.max_seq_len = max_seq_len

     @property
@@ -401,9 +414,8 @@ class DecoderOnlyModel(nn.Module):
             return 1

     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
-        if self.attn_impl
+        if self.attn_impl not in ["flash_attn"]:
             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
-
         partition_len = self.partition_len
         num_partition = max_seq_len // partition_len


@@ -429,9 +441,9 @@ class DecoderOnlyModel(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):

@@ -466,9 +478,8 @@ class DecoderOnlyModel(nn.Module):
         # (batch, seq_len) -> (batch,)
         if self.attn_impl == "flash_attn":
             seq_positions = cache_position[:, 0]
-            max_seq_len = past_key_values[0][0].shape[-2]
             seq_positions = self.convert_sequence_positions_for_flash_attn(
-                seq_positions=seq_positions, max_seq_len=max_seq_len
+                seq_positions=seq_positions, max_seq_len=self.max_seq_len
             )
         else:
             seq_positions = cache_position[:, :1]
@@ -479,10 +490,10 @@ class DecoderOnlyModel(nn.Module):
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 seq_positions=seq_positions,
-                batch_position=batch_position,
                 past_key_values=present_key_values,
                 cos=cos,
                 sin=sin,
+                block_tables=block_tables,
             )

         hidden_states = self.get_last_layernorm()(hidden_states)

@@ -540,10 +551,10 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)

@@ -552,10 +563,10 @@ class DecoderOnlyLayer(nn.Module):
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states

@@ -579,7 +590,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """

-    def __init__(self, self_attn, use_attention_mask):
+    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx

@@ -599,6 +610,7 @@ class DecoderOnlyAttention(nn.Module):

         self.use_attention_mask = use_attention_mask
         self.attention = self.get_attention()
+        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()

     @property

@@ -644,10 +656,10 @@ class DecoderOnlyAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()


@@ -673,9 +685,10 @@ class DecoderOnlyAttention(nn.Module):
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            block_size=self.kvcache_block_size,
         )
         key_states = key_state
         value_states = value_state
@@ -700,11 +713,12 @@ class AttentionOp(nn.Module):
         key_state: torch.Tensor,
         value_state: torch.Tensor,
         attn_mask: torch.Tensor,
-        batch_position: torch.Tensor,
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: torch.Tensor,
         scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.


@@ -713,7 +727,6 @@ class AttentionOp(nn.Module):
             key_state: Key tensor [1, num_heads, seq_len, head_dim]
             value_state: Value tensor [1, num_heads, seq_len, head_dim]
             attn_mask: Attention mask tensor ∈ {0, 1}
-            batch_position: Batch index for cache lookup
             past_key_state: Previous key cache states
             past_value_state: Previous value cache states
             seq_position: Current position in sequence

@@ -743,7 +756,7 @@ class AttentionOp(nn.Module):

         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_decode(
                     query_state,
                     key_state,
                     value_state,

@@ -752,9 +765,11 @@ class AttentionOp(nn.Module):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -762,31 +777,35 @@ class AttentionOp(nn.Module):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )

         else:
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     attn_mask,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )

         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
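The paged custom ops above receive the whole cache plus `block_tables` and `block_size` instead of a `batch_position` index. As a point of reference, the snippet below shows the generic block-table gather that paged attention implies, in plain PyTorch; it is not the RBLN kernel, and the cache layout is an assumption made for the sketch.

```python
import torch

# Plain-PyTorch reference for the block-table gather that paged attention implies.
# This is NOT the rbln_custom_ops kernel; the cache layout is assumed for the sketch.
num_blocks, num_heads, block_size, head_dim = 8, 4, 128, 64
key_cache = torch.randn(num_blocks, num_heads, block_size, head_dim)

block_table = torch.tensor([5, 2, 7])  # physical blocks backing one sequence
cached_len = 300                       # tokens currently stored for that sequence

# Gather the sequence's blocks and flatten them into a contiguous per-head view.
gathered = key_cache[block_table]                                   # (3, num_heads, block_size, head_dim)
keys = gathered.transpose(0, 1).reshape(num_heads, -1, head_dim)[:, :cached_len]
print(keys.shape)  # torch.Size([4, 300, 64])
```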
@@ -889,10 +908,11 @@ class RotaryEmbedding(nn.Module):


 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
         self.kvcache_partition_size = kvcache_partition_len
-
-
+        super().__init__(
+            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+        )

     def get_attention(self):
         return FlashAttentionOp(

@@ -908,10 +928,10 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()


@@ -934,9 +954,10 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            kvcache_block_size=self.kvcache_block_size,
         )
         key_states = key_state
         value_states = value_state

@@ -970,11 +991,12 @@ class FlashAttentionOp(AttentionOp):
         key_state,
         value_state,
         attn_mask,
-        batch_position,
         past_key_state,
         past_value_state,
         seq_position,
         scale,
+        block_tables,
+        kvcache_block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -997,7 +1019,7 @@ class FlashAttentionOp(AttentionOp):

         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     query_state,
                     key_state,
                     value_state,

@@ -1006,10 +1028,12 @@ class FlashAttentionOp(AttentionOp):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,

@@ -1017,32 +1041,36 @@ class FlashAttentionOp(AttentionOp):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )
         else:
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     attn_mask,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )

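For reference, the eight `torch.ops.rbln_custom_ops` entry points selected across `AttentionOp` and `FlashAttentionOp` in this diff can be summarized by attention implementation, mask usage, and phase. The mapping below simply restates the dispatch branches shown above (the non-decode branch is labeled "prefill" here); it is a readability aid, not an API defined by the package.

```python
# Dispatch summary of the paged custom ops used in this diff,
# keyed by (attn_impl, use_attention_mask, phase).
PAGED_ATTENTION_OPS = {
    ("eager", True, "decode"): "paged_attn_decode",
    ("eager", False, "decode"): "paged_causal_attn_decode",
    ("eager", True, "prefill"): "paged_attn_prefill",
    ("eager", False, "prefill"): "paged_causal_attn_prefill",
    ("flash_attn", True, "decode"): "paged_flash_attn_decode",
    ("flash_attn", False, "decode"): "paged_flash_causal_attn_decode",
    ("flash_attn", True, "prefill"): "paged_flash_attn_prefill",
    ("flash_attn", False, "prefill"): "paged_flash_causal_attn_prefill",
}
```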