optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl

This diff compares the contents of publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
@@ -20,10 +20,10 @@ from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

 from ....ops import (
-    register_rbln_custom_causal_masked_attention,
-    register_rbln_custom_flash_causal_masked_attention,
-    register_rbln_custom_flash_masked_attention,
-    register_rbln_custom_masked_attention,
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+    register_rbln_custom_paged_flash_attention,
+    register_rbln_custom_paged_flash_causal_attention,
 )
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -39,7 +39,7 @@ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768


 def validate_attention_method(
-    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
+    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_kvcache_block_size: int, rbln_max_seq_len: int
 ) -> Tuple[str, int]:
     if rbln_kvcache_partition_len is not None:
         if rbln_attn_impl == "eager":
@@ -98,7 +98,7 @@ def validate_attention_method(
             "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
         )

-    return rbln_attn_impl, rbln_kvcache_partition_len
+    return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size


 class DecoderOnlyWrapper(nn.Module):
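With this change, validate_attention_method both accepts and returns the KV-cache block size alongside the attention implementation and partition length. A minimal call sketch, with placeholder values rather than library defaults:

# Hypothetical values; only the three-value return contract is taken from the diff above.
attn_impl, kvcache_partition_len, kvcache_block_size = validate_attention_method(
    rbln_attn_impl="flash_attn",
    rbln_kvcache_partition_len=16_384,
    rbln_kvcache_block_size=16_384,
    rbln_max_seq_len=32_768,
)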
@@ -107,7 +107,7 @@ class DecoderOnlyWrapper(nn.Module):
     This wrapper is designed to:
     1. Convert Huggingface decoder models for RBLN compilation with static shapes
     2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
-    3. Manage different attention implementations (standard and flash attention)
+    3. Manage different attention implementations (standard/flash attention)
     4. Support both prefill and decode phases

     Notes:
@@ -135,6 +135,7 @@ class DecoderOnlyWrapper(nn.Module):
         attn_impl: str,
         use_attention_mask: bool,
         kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
     ):
         super().__init__()
         self.config = causal_lm.config
@@ -145,19 +146,20 @@ class DecoderOnlyWrapper(nn.Module):
         self.rotary_emb = None

         self.attn_impl = attn_impl
+        self.kvcache_block_size = kvcache_block_size
         self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
             if self.use_attention_mask:
-                register_rbln_custom_flash_masked_attention()
+                register_rbln_custom_paged_flash_attention()
             else:
-                register_rbln_custom_flash_causal_masked_attention()
+                register_rbln_custom_paged_flash_causal_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
             if self.use_attention_mask:
-                register_rbln_custom_masked_attention()
+                register_rbln_custom_paged_attention()
             else:
-                register_rbln_custom_causal_masked_attention()
+                register_rbln_custom_paged_causal_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")

@@ -179,11 +181,14 @@ class DecoderOnlyWrapper(nn.Module):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn, self.use_attention_mask)
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
                     layer.self_attn,
                     kvcache_partition_len=self.kvcache_partition_len,
+                    kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
                 )
             else:
@@ -192,7 +197,11 @@ class DecoderOnlyWrapper(nn.Module):
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = DecoderOnlyModel(
-            causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+            causal_lm.model,
+            new_layers,
+            partition_len=self.kvcache_partition_len,
+            max_seq_len=max_seq_len,
+            kvcache_block_size=self.kvcache_block_size,
         )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
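Taken together, these constructor changes thread a single kvcache_block_size value from DecoderOnlyWrapper through DecoderOnlyModel down to every DecoderOnlyAttention / DecoderOnlyFlashAttention instance, while __init__ registers one of the four paged custom ops per (attn_impl, use_attention_mask) pair. A compact restatement of that selection, for illustration only:

# Illustrative restatement of the registration choice in DecoderOnlyWrapper.__init__;
# the helper name is hypothetical and not part of the library.
def select_paged_op_register(attn_impl: str, use_attention_mask: bool):
    return {
        ("eager", True): register_rbln_custom_paged_attention,
        ("eager", False): register_rbln_custom_paged_causal_attention,
        ("flash_attn", True): register_rbln_custom_paged_flash_attention,
        ("flash_attn", False): register_rbln_custom_paged_flash_causal_attention,
    }[(attn_impl, use_attention_mask)]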
@@ -213,16 +222,17 @@ class DecoderOnlyWrapper(nn.Module):
                     input_ids_or_inputs_embeds,
                     cache_position,
                     attention_mask,
+                    block_tables,
                     *past_key_values,
                 ) = args
             else:
                 (
                     input_ids_or_inputs_embeds,
                     cache_position,
+                    block_tables,
                     *past_key_values,
                 ) = args
                 attention_mask = None
-            batch_position = torch.tensor(0, dtype=torch.int16)
             query_position = None
         elif self.phase == "prefill":
             if self.use_attention_mask:
@@ -230,16 +240,16 @@ class DecoderOnlyWrapper(nn.Module):
                     input_ids_or_inputs_embeds,
                     cache_position,
                     attention_mask,
-                    batch_position,
                     query_position,
+                    block_tables,
                     *past_key_values,
                 ) = args
             else:
                 (
                     input_ids_or_inputs_embeds,
                     cache_position,
-                    batch_position,
                     query_position,
+                    block_tables,
                     *past_key_values,
                 ) = args
                 attention_mask = None
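With batch_position removed, the wrapper's flattened inputs now carry block_tables in both phases and query_position only during prefill. A schematic of the layouts unpacked above, mirroring the diff:

# Flattened wrapper inputs when use_attention_mask=True:
#   decode : (input_ids_or_inputs_embeds, cache_position, attention_mask, block_tables, *past_key_values)
#   prefill: (input_ids_or_inputs_embeds, cache_position, attention_mask, query_position, block_tables, *past_key_values)
# When use_attention_mask=False, attention_mask is omitted from the inputs and set to None inside forward.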
@@ -276,10 +286,10 @@ class DecoderOnlyWrapper(nn.Module):
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             query_position=query_position,
             past_key_values=past_key_values,
             rotary_emb=self.rotary_emb,
+            block_tables=block_tables,
         )

         # ((key, value)) * n_layer -> [key, value] * n_layer
@@ -337,10 +347,10 @@ class DecoderOnlyForCausalLM(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         query_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # outputs
         hidden_states, present_key_values = self.model(
@@ -348,9 +358,9 @@ class DecoderOnlyForCausalLM(nn.Module):
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
+            block_tables=block_tables,
         )

         if self.phase == "prefill":
@@ -374,12 +384,15 @@ class DecoderOnlyModel(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """

-    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None):
+    def __init__(
+        self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None, kvcache_block_size=None
+    ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self._phase = "prefill"
         self.partition_len = partition_len
+        self.kvcache_block_size = kvcache_block_size
         self.max_seq_len = max_seq_len

     @property
@@ -401,9 +414,8 @@ class DecoderOnlyModel(nn.Module):
         return 1

     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
-        if self.attn_impl != "flash_attn":
+        if self.attn_impl not in ["flash_attn"]:
             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
-
         partition_len = self.partition_len
         num_partition = max_seq_len // partition_len

@@ -429,9 +441,9 @@ class DecoderOnlyModel(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -466,9 +478,8 @@ class DecoderOnlyModel(nn.Module):
         # (batch, seq_len) -> (batch,)
         if self.attn_impl == "flash_attn":
             seq_positions = cache_position[:, 0]
-            max_seq_len = past_key_values[0][0].shape[-2]
             seq_positions = self.convert_sequence_positions_for_flash_attn(
-                seq_positions=seq_positions, max_seq_len=max_seq_len
+                seq_positions=seq_positions, max_seq_len=self.max_seq_len
             )
         else:
             seq_positions = cache_position[:, :1]
@@ -479,10 +490,10 @@ class DecoderOnlyModel(nn.Module):
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 seq_positions=seq_positions,
-                batch_position=batch_position,
                 past_key_values=present_key_values,
                 cos=cos,
                 sin=sin,
+                block_tables=block_tables,
             )

         hidden_states = self.get_last_layernorm()(hidden_states)
@@ -540,10 +551,10 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -552,10 +563,10 @@ class DecoderOnlyLayer(nn.Module):
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states

@@ -579,7 +590,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """

-    def __init__(self, self_attn, use_attention_mask):
+    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -599,6 +610,7 @@ class DecoderOnlyAttention(nn.Module):

         self.use_attention_mask = use_attention_mask
         self.attention = self.get_attention()
+        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()

     @property
@@ -644,10 +656,10 @@ class DecoderOnlyAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()

@@ -673,9 +685,10 @@ class DecoderOnlyAttention(nn.Module):
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            block_size=self.kvcache_block_size,
         )
         key_states = key_state
         value_states = value_state
@@ -700,11 +713,12 @@ class AttentionOp(nn.Module):
         key_state: torch.Tensor,
         value_state: torch.Tensor,
         attn_mask: torch.Tensor,
-        batch_position: torch.Tensor,
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: torch.Tensor,
         scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.

@@ -713,7 +727,6 @@ class AttentionOp(nn.Module):
             key_state: Key tensor [1, num_heads, seq_len, head_dim]
             value_state: Value tensor [1, num_heads, seq_len, head_dim]
             attn_mask: Attention mask tensor ∈ {0, 1}
-            batch_position: Batch index for cache lookup
             past_key_state: Previous key cache states
             past_value_state: Previous value cache states
             seq_position: Current position in sequence
@@ -743,7 +756,7 @@ class AttentionOp(nn.Module):

         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.masked_attn_decode(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -752,9 +765,11 @@ class AttentionOp(nn.Module):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.causal_masked_attn_decode(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -762,31 +777,35 @@ class AttentionOp(nn.Module):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )

         else:
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.masked_attn_prefill(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     attn_mask,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.causal_masked_attn_prefill(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    block_size,
                 )

         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
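The custom ops above now locate cache entries through block_tables and block_size instead of a scalar batch_position. As a rough mental model only (the actual cache layout behind the RBLN paged ops is not visible in this diff), a block table maps each logical block of a sequence onto a physical cache block, vLLM-style:

import torch

# Conceptual sketch of paged KV-cache indexing; names, shapes, and values are
# illustrative and are not the RBLN ops' actual interface.
block_size = 4
block_tables = torch.tensor([[2, 0, 5]])  # batch entry 0 owns physical blocks 2, 0, 5
pos = 6                                   # token position within the sequence
physical_block = block_tables[0, pos // block_size].item()  # -> 0
slot_in_block = pos % block_size                            # -> 2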
@@ -889,10 +908,11 @@ class RotaryEmbedding(nn.Module):


 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
         self.kvcache_partition_size = kvcache_partition_len
-        # self.use_attention_mask = use_attention_mask
-        super().__init__(self_attn=self_attn, use_attention_mask=use_attention_mask)
+        super().__init__(
+            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+        )

     def get_attention(self):
         return FlashAttentionOp(
@@ -908,10 +928,10 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()

@@ -934,9 +954,10 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            kvcache_block_size=self.kvcache_block_size,
         )
         key_states = key_state
         value_states = value_state
@@ -970,11 +991,12 @@ class FlashAttentionOp(AttentionOp):
         key_state,
         value_state,
         attn_mask,
-        batch_position,
         past_key_state,
         past_value_state,
         seq_position,
         scale,
+        block_tables,
+        kvcache_block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -997,7 +1019,7 @@ class FlashAttentionOp(AttentionOp):

         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_masked_attn_decode(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -1006,10 +1028,12 @@ class FlashAttentionOp(AttentionOp):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_causal_masked_attn_decode(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -1017,32 +1041,36 @@ class FlashAttentionOp(AttentionOp):
                     past_value_state.unsqueeze(2),
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )
         else:
             if self.use_attention_mask:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_masked_attn_prefill(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     attn_mask,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_causal_masked_attn_prefill(
+                attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
                     past_key_state.unsqueeze(2),
                     past_value_state.unsqueeze(2),
-                    batch_position,
                     seq_position,
                     scale,
+                    block_tables,
+                    kvcache_block_size,
                     self.kvcache_partition_size,
                 )
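In the flash variants, the new paged-cache arguments ride alongside the existing kvcache_partition_size, so flash-attention sequence partitioning and block-granular cache addressing are configured independently. A hedged arithmetic sketch with placeholder sizes (not library defaults):

# Illustrative relationship between the two granularities for one hypothetical setup.
max_seq_len = 32_768
kvcache_partition_size = 16_384   # flash-attention partition length
kvcache_block_size = 16_384       # paged KV-cache block granularity
num_partition = max_seq_len // kvcache_partition_size    # -> 2, as in convert_sequence_positions_for_flash_attn
num_blocks_per_seq = max_seq_len // kvcache_block_size   # -> 2, assuming max_seq_len is a multiple of the block size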