optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5rc0__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries, as they appear in those registries; it is provided for informational purposes only.
- optimum/rbln/__init__.py +30 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +9 -4
- optimum/rbln/modeling.py +7 -5
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +32 -3
- optimum/rbln/transformers/models/__init__.py +37 -0
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +189 -90
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +186 -95
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +80 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +77 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -11
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
- optimum/rbln/transformers/models/siglip/__init__.py +20 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
- optimum/rbln/utils/submodule.py +13 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +46 -31
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0
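Most of the new surface area is the Gemma3, BLIP-2, OPT, and SigLIP model packages plus a rework of the shared decoder-only attention path shown in the hunks below. For orientation only, here is a hypothetical usage sketch of one of the newly added models; the class name RBLNGemma3ForConditionalGeneration and the rbln_batch_size kwarg are assumptions based on optimum-rbln's existing RBLN<Architecture> naming and rbln_-prefixed compile options, not something this diff confirms.

# Hypothetical usage sketch -- the class name is assumed from optimum-rbln's
# RBLN<Architecture> convention; the diff only confirms that modeling_gemma3.py,
# modeling_blip_2.py, modeling_opt.py and modeling_siglip.py were added.
from optimum.rbln import RBLNGemma3ForConditionalGeneration

model = RBLNGemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it",   # any Gemma3 checkpoint id (illustrative)
    export=True,              # compile the PyTorch weights for RBLN NPUs
    rbln_batch_size=1,        # rbln_-prefixed kwargs feed the RBLN compile config
)
model.save_pretrained("gemma3-rbln")  # reuse the compiled artifacts later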
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -146,7 +146,10 @@ class DecoderOnlyWrapper(nn.Module):
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
+        use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
+        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
     ):
@@ -154,13 +157,21 @@ class DecoderOnlyWrapper(nn.Module):
         self.config = causal_lm.config

         if use_rotary_emb:
-            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            if isinstance(rotary_embs, tuple):
+                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+            else:
+                self.rotary_emb = rotary_embs
         else:
             self.rotary_emb = None

         self.attn_impl = attn_impl
         self.kvcache_block_size = kvcache_block_size
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
+        self.use_inputs_embeds = use_inputs_embeds
+        self.use_learned_pos_emb = use_learned_pos_emb
+
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
         elif self.attn_impl == "eager":
@@ -188,7 +199,10 @@ class DecoderOnlyWrapper(nn.Module):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    self.use_position_ids,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -196,6 +210,7 @@ class DecoderOnlyWrapper(nn.Module):
                     kvcache_partition_len=self.kvcache_partition_len,
                     kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -209,6 +224,7 @@ class DecoderOnlyWrapper(nn.Module):
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
+            use_learned_pos_emb=self.use_learned_pos_emb,
         )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
@@ -222,24 +238,16 @@ class DecoderOnlyWrapper(nn.Module):
         self._phase = phase
         self.causal_lm.phase = phase

-    def forward_common(
-        self,
-        input_ids_or_inputs_embeds: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: torch.Tensor,
-        query_position: torch.Tensor,
-        block_tables: torch.Tensor,
-        rotary_emb: nn.Module,
-        *past_key_values,
-    ):
-        if input_ids_or_inputs_embeds.ndim == 2:
-            input_ids = input_ids_or_inputs_embeds
-            inputs_embeds = None
-        elif input_ids_or_inputs_embeds.ndim == 3:
-            input_ids = None
-            inputs_embeds = input_ids_or_inputs_embeds
-        else:
-            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        block_tables = args.pop(0)
+        query_position = args.pop(0) if self.phase == "prefill" else None
+        attention_mask = args.pop(0) if self.use_attention_mask else None
+        position_ids = args.pop(0) if self.use_position_ids else None
+        past_key_values = args

         if len(past_key_values) != 2 * self.num_hidden_layers:
             raise ValueError(
@@ -256,11 +264,37 @@ class DecoderOnlyWrapper(nn.Module):
             _past_key_values.append(past_key_value)
         past_key_values = _past_key_values

+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            self.rotary_emb,
+        )
+
+    def forward(self, *args):
+        (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            rotary_emb,
+        ) = self.prepare_forward_args(*args)
+
         logit = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            position_ids=position_ids,
             query_position=query_position,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
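The new prepare_forward_args replaces the phase-specific tuple unpacking removed in the next hunk: the compiled graph still passes one flat positional argument list, but the wrapper now names each slot from its feature flags (use_inputs_embeds, use_attention_mask, use_position_ids) and the current phase, and everything left over is treated as the flattened KV cache. A minimal standalone sketch of that flag-driven unpacking, using plain strings in place of tensors (the helper name unpack is ours, not the library's):

# Minimal sketch of the flag-driven argument unpacking introduced above.
def unpack(args, *, use_inputs_embeds, use_attention_mask, use_position_ids, phase):
    args = list(args)
    input_ids = None if use_inputs_embeds else args.pop(0)
    inputs_embeds = args.pop(0) if use_inputs_embeds else None
    cache_position = args.pop(0)
    block_tables = args.pop(0)
    query_position = args.pop(0) if phase == "prefill" else None
    attention_mask = args.pop(0) if use_attention_mask else None
    position_ids = args.pop(0) if use_position_ids else None
    # whatever remains is the flattened KV cache
    return (input_ids, inputs_embeds, cache_position, block_tables,
            query_position, attention_mask, position_ids, args)

print(unpack(["ids", "cache_pos", "blocks", "k0", "v0"],
             use_inputs_embeds=False, use_attention_mask=False,
             use_position_ids=False, phase="decode"))
# -> ('ids', None, 'cache_pos', 'blocks', None, None, None, ['k0', 'v0'])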
@@ -269,58 +303,6 @@ class DecoderOnlyWrapper(nn.Module):

         return logit

-    def forward(self, *args):
-        if self.phase == "decode":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-            query_position = None
-        elif self.phase == "prefill":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    query_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    query_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-
-        else:
-            raise ValueError(f"Unknown phase: {self.phase}")
-
-        return self.forward_common(
-            input_ids_or_inputs_embeds,
-            cache_position,
-            attention_mask,
-            query_position,
-            block_tables,
-            self.rotary_emb,
-            *past_key_values,
-        )
-

 class DecoderOnlyForCausalLM(nn.Module):
     """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
@@ -367,6 +349,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
         query_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
@@ -378,6 +361,7 @@
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
             block_tables=block_tables,
@@ -387,6 +371,13 @@
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

         logits = self.lm_head(hidden_states)
+
+        # Apply final logit softmaxing if configured, e.g. for Gemma2
+        if getattr(self.config, "final_logit_softcapping", None) is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
         return logits

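The added block applies final logit soft-capping, logits = cap * tanh(logits / cap), which keeps the head's outputs inside (-cap, cap) while staying nearly linear for small values; Gemma2 checkpoints set final_logit_softcapping = 30.0 in their config. A quick numeric illustration (input values chosen arbitrarily):

import torch

cap = 30.0  # Gemma2 configs ship final_logit_softcapping = 30.0
logits = torch.tensor([1.0, 25.0, 100.0])
capped = cap * torch.tanh(logits / cap)
print(capped)  # approximately [1.0, 20.5, 29.9] -- large logits saturate near the cap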
@@ -404,7 +395,13 @@ class DecoderOnlyModel(nn.Module):
     """

     def __init__(
-        self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None, kvcache_block_size=None
+        self,
+        model,
+        layers: List["DecoderOnlyLayer"],
+        partition_len=None,
+        max_seq_len=None,
+        kvcache_block_size=None,
+        use_learned_pos_emb=None,
     ):
         super().__init__()
         self._original_mod = model
@@ -413,6 +410,7 @@ class DecoderOnlyModel(nn.Module):
         self.partition_len = partition_len
         self.kvcache_block_size = kvcache_block_size
         self.max_seq_len = max_seq_len
+        self.use_learned_pos_emb = use_learned_pos_emb

     @property
     def phase(self):
@@ -457,11 +455,12 @@ class DecoderOnlyModel(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor = None,
-        inputs_embeds: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
-        rotary_emb: nn.Module = None,
+        rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         block_tables: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
@@ -477,24 +476,38 @@
             hidden_states = inputs_embeds * self.hidden_multiplier

         # get cos,sin vector if needed
+        position_ids = position_ids if position_ids is not None else cache_position
         if rotary_emb is not None:
             if isinstance(rotary_emb, torch.Tensor):
                 cos = rotary_emb[0]
                 sin = rotary_emb[1]
             else:
                 cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
-            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        elif self.use_learned_pos_emb:
+            batch_size = inputs_embeds.shape[0]
+            hidden_all = []
+            for i in range(batch_size):
+                positions_idx = position_ids[i]
+                position_weight = self.get_pos_embedding().weight[2:]
+                position = position_weight[positions_idx]
+                batch_hidden = position + inputs_embeds[i]
+                hidden_all.append(batch_hidden)
+            hidden_states = torch.stack(hidden_all, dim=0)
+            cos, sin = None, None
+
         else:
             batch_size = inputs_embeds.shape[0]
-            if cache_position.shape[0] > 1:
+            if position_ids.shape[0] > 1:
                 position_embeds = []
                 for b_idx in range(batch_size):
-                    position_embed = self.get_pos_embedding()(cache_position[b_idx])
+                    position_embed = self.get_pos_embedding()(position_ids[b_idx])
                     position_embeds.append(position_embed)

                 position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
             else:
-                position_embeds = self.get_pos_embedding()(cache_position)
+                position_embeds = self.get_pos_embedding()(position_ids)
             hidden_states = hidden_states + position_embeds
             cos, sin = None, None

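The new use_learned_pos_emb branch adds learned position vectors directly to the input embeddings instead of producing cos/sin tables; the weight[2:] slice matches the offset-by-2 layout used by OPT's learned positional embeddings, whose support this release also adds. A self-contained sketch of the same lookup-and-add pattern (shapes and names here are illustrative, not taken from the RBLN code):

import torch
import torch.nn as nn

hidden_size, max_positions, offset = 16, 32, 2
pos_table = nn.Embedding(max_positions + offset, hidden_size)  # first two rows reserved, OPT-style

inputs_embeds = torch.randn(2, 5, hidden_size)   # (batch, seq, hidden)
position_ids = torch.arange(5).expand(2, -1)     # (batch, seq)

position_weight = pos_table.weight[offset:]       # mirrors get_pos_embedding().weight[2:]
hidden_states = torch.stack(
    [position_weight[position_ids[i]] + inputs_embeds[i] for i in range(2)], dim=0
)
print(hidden_states.shape)  # torch.Size([2, 5, 16])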
@@ -612,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """

-    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
+    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -631,6 +644,7 @@ class DecoderOnlyAttention(nn.Module):
             self.num_key_value_heads = self.num_heads

         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
         self.attention = self.get_attention()
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
@@ -645,7 +659,9 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase

     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)
+        return AttentionOp(
+            self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+        )

     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -718,13 +734,16 @@ class DecoderOnlyAttention(nn.Module):


 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
+    def __init__(
+        self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+    ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids

     def forward(
         self,
@@ -757,7 +776,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
@@ -774,7 +794,7 @@ class AttentionOp(nn.Module):
         )

         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -798,10 +818,11 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
+                    mask=attn_mask if self.use_position_ids else None,
                 )

         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -825,6 +846,8 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
+                    is_bidirectional=True if self.phase == "image_prefill" else False, # FIXME, Hard-coded for Gemma3.
+                    mask=attn_mask if self.use_position_ids else None,
                 )

         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -926,10 +949,13 @@ class RotaryEmbedding(nn.Module):


 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
         self.kvcache_partition_size = kvcache_partition_len
         super().__init__(
-            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+            self_attn=self_attn,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            kvcache_block_size=kvcache_block_size,
         )

     def get_attention(self):
@@ -939,6 +965,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             self.num_key_value_heads,
             self.kvcache_partition_size,
             self.use_attention_mask,
+            self.use_position_ids,
         )

     def forward(
@@ -990,12 +1017,14 @@ class FlashAttentionOp(AttentionOp):
         num_key_value_heads: int,
         kvcache_partition_len: int,
         use_attention_mask: bool,
+        use_position_ids: bool,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
         )
         self.kvcache_partition_size = kvcache_partition_len

@@ -1015,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
@@ -1032,7 +1061,7 @@ class FlashAttentionOp(AttentionOp):
         )

         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -1058,9 +1087,10 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -1086,6 +1116,8 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,
+                    mask=attn_mask if self.use_position_ids else None,
                 )

         # reshape for removing repeat_kv
@@ -1094,3 +1126,70 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

         return attn_output
+
+
+class SlidingWindowAttentionOp(AttentionOp):
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: Tuple[torch.Tensor],
+        scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+
+        if self.phase == "decode":
+            batch_size = key_state.shape[0]
+        else:
+            batch_size = 1
+
+        query_state = query_state.view(
+            batch_size,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+            )
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+                is_bidirectional=True if self.phase == "image_prefill" else False,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+        return attn_output