optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
@@ -13,9 +13,10 @@
  # limitations under the License.

  import inspect
+ from collections import deque
  from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
@@ -50,17 +51,28 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  phase: str,
  batch_size: int,
  dec_attn_mask: torch.Tensor,
+ block_tables: torch.Tensor,
+ free_block_pool: Deque,
+ kvcache_block_size: int,
  use_attention_mask: bool,
+ attn_impl: str,
  **kwargs: Any,
  ) -> None:
  super().__init__(runtime, **kwargs)
  self.phase = phase
  self.batch_size = batch_size

+ # shared data structures between prefill and decode phase
  self.use_attention_mask = use_attention_mask

  # shared tensor between prefill and decode phase
  self.dec_attn_mask = dec_attn_mask
+ self.block_tables = block_tables
+ self.free_block_pool = free_block_pool
+
+ self.kvcache_block_size = kvcache_block_size
+ self.empty_block = -1
+ self.attn_impl = attn_impl

  if self.phase == "prefill":
  vocab_size = kwargs.pop("vocab_size")
@@ -71,6 +83,75 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
  )

+ def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+ """
+ Manages and returns the KV cache block tables.
+ Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+ Args:
+ cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+ batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+ Returns:
+ torch.Tensor: Updated block tables.
+ """
+
+ NO_BLOCKS_ERROR = (
+ "No memory blocks are available for allocation."
+ "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln."
+ "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html)."
+ "Using vllm-rbln should fix this issue and enhance inference performance."
+ )
+
+ def update_block(batch_idx: int, block_idx: int):
+ """
+ If the block is empty (empty_block), allocates a block from the free_block_pool.
+ """
+ if self.block_tables[batch_idx][block_idx] == self.empty_block:
+ if self.free_block_pool:
+ block = self.free_block_pool.popleft()
+ self.block_tables[batch_idx][block_idx] = block
+ else:
+ raise RuntimeError(NO_BLOCKS_ERROR)
+
+ def replace_empty_block(block_tables: torch.Tensor):
+ """
+ Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+ """
+ if not torch.any(block_tables == self.empty_block):
+ return block_tables.clone()
+ elif self.free_block_pool:
+ _free_block = self.free_block_pool[0]
+ return torch.where(block_tables == self.empty_block, _free_block, block_tables)
+ else:
+ raise RuntimeError(NO_BLOCKS_ERROR)
+
+ if self.phase == "prefill":
+ # Track previously used blocks and return them to the free_block_pool and
+ # reset the current batch's block table to empty blocks
+ prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+ self.free_block_pool.extend(prev_blocks)
+ self.block_tables[batch_idx].fill_(self.empty_block)
+
+ # Get the start (s) and end (e) positions from cache_position and
+ # iterate over the cache positions to allocate necessary blocks
+ s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+ for position in range(s, e + 1, self.kvcache_block_size):
+ block_idx = position // self.kvcache_block_size
+ if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+ raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+ update_block(batch_idx, block_idx)
+
+ return replace_empty_block(self.block_tables[batch_idx])
+ # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+ else:
+ for b_idx in range(self.batch_size):
+ position = cache_position[b_idx][0].item()
+ block_idx = position // self.kvcache_block_size
+ update_block(b_idx, block_idx)
+
+ return replace_empty_block(self.block_tables)
+
  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,
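The sketch below is not part of the diff; it is a minimal, self-contained illustration of the allocation scheme `get_block_tables` implements, with assumed sizes (block_size=4, max_seq_len=16, batch_size=2). Each sequence owns a row of block indices, -1 marks an unallocated slot, and the prefill path first returns the row's old blocks to the shared deque before claiming fresh ones.

    # Minimal sketch (assumed sizes, simplified from the method above): fill one
    # sequence's block-table row from a deque-based free pool during prefill.
    from collections import deque

    import torch

    EMPTY = -1
    block_size, max_seq_len, batch_size = 4, 16, 2

    block_tables = torch.full((batch_size, max_seq_len // block_size), EMPTY, dtype=torch.int16)
    free_block_pool = deque(range((max_seq_len // block_size) * batch_size))

    def allocate_prefill(batch_idx: int, prompt_len: int) -> torch.Tensor:
        # Return any blocks this slot already held, then claim one block per
        # block_size-sized span of the prompt.
        reclaimed = block_tables[batch_idx][block_tables[batch_idx] != EMPTY].tolist()
        free_block_pool.extend(reclaimed)
        block_tables[batch_idx].fill_(EMPTY)
        for pos in range(0, prompt_len, block_size):
            block_tables[batch_idx][pos // block_size] = free_block_pool.popleft()
        return block_tables[batch_idx]

    print(allocate_prefill(0, prompt_len=10))  # tensor([ 0,  1,  2, -1], dtype=torch.int16)
    # The real method additionally swaps any remaining -1 entries for a dummy
    # block from the pool before handing the row to the compiled runtime.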
@@ -78,6 +159,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  cache_position: torch.Tensor = None,
  attention_mask: Optional[torch.Tensor] = None,
  batch_idx: Optional[int] = None,
+ block_tables: Optional[torch.Tensor] = None,
  ):
  if input_ids is None and inputs_embeds is None:
  raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -89,19 +171,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  else:
  inputs = inputs_embeds

+ if block_tables is None:
+ block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
+ is_external_block_tables = False
+ else:
+ is_external_block_tables = True
+
  if self.phase == "decode":
  return self.decode_forward(
  inputs,
  cache_position,
+ block_tables,
+ is_external_block_tables,
  attention_mask=attention_mask,
  )
  else:
- return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+ return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)

  def decode_forward(
  self,
  inputs: torch.Tensor,
  cache_position: torch.Tensor = None,
+ block_tables: torch.Tensor = None,
+ is_external_block_tables: bool = None,
  attention_mask: Optional[torch.Tensor] = None,
  ) -> torch.FloatTensor:
  batch_size = inputs.shape[0]
@@ -120,7 +212,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  raise ValueError(
  f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
  )
- self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+ if is_external_block_tables:
+ self.dec_attn_mask[b_idx].fill_(0)
+ self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+ else:
+ self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+ attention_mask = self.dec_attn_mask

  attention_mask = self.dec_attn_mask

@@ -128,6 +227,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  inputs,
  cache_position,
  attention_mask if self.use_attention_mask else None,
+ block_tables,
  )

  return logits
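A small sketch (shapes assumed, not taken from the diff) of the two mask-update paths added above: with externally supplied block tables the whole valid prefix of the row is rewritten on every step, while the internal path only unmasks the newly generated position.

    # Sketch only; dec_attn_mask has shape [batch, 1, 1, max_seq_len] as in the
    # constructor above, but the sizes here are assumptions.
    import torch

    dec_attn_mask = torch.zeros(1, 1, 1, 8)
    decoding_step, is_external_block_tables = 3, True

    if is_external_block_tables:
        # An external scheduler may reorder or reset slots between calls, so the
        # valid prefix [0, decoding_step] is rebuilt from scratch.
        dec_attn_mask[0].fill_(0)
        dec_attn_mask[0, :, :, : decoding_step + 1] = 1
    else:
        # Internally tracked slots were already marked on earlier steps; only
        # the current position needs to be opened.
        dec_attn_mask[0, :, :, decoding_step] = 1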
@@ -138,6 +238,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  cache_position: torch.Tensor = None,
  attention_mask: Optional[torch.Tensor] = None,
  batch_idx: int = None,
+ block_tables: torch.Tensor = None,
+ is_external_block_tables: bool = None,
  ) -> torch.FloatTensor:
  """
  Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -145,11 +247,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
  """

- if batch_idx is None or batch_idx >= self.batch_size:
- raise RuntimeError(
- f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
- )
-
  # Handle continuous batching in a compiled graph by extracting valid inputs
  # If an attention mask is provided, select only the valid (non-masked) inputs
  inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -207,33 +304,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
  chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

- # Define batch position and query position
- batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+ # Define query position
  query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)

- if self.use_attention_mask:
- args = (
- input_chunk,
- cache_pos_chunk,
- chunked_attention_mask,
- batch_position,
- query_position,
- )
- else:
- args = (
- input_chunk,
- cache_pos_chunk,
- batch_position,
- query_position,
- )
  # Forward pass for the current chunk
  logits = super().forward(
- *args,
+ input_chunk,
+ cache_pos_chunk,
+ chunked_attention_mask if self.use_attention_mask else None,
+ query_position,
+ block_tables,
  out=out_buffers,
  )

- if self.use_attention_mask:
- # Update decoder attention mask with processed KV-cache length from prefill phase
+ # Update decoder attention mask with processed KV-cache length from prefill phase
+ if not is_external_block_tables and self.use_attention_mask:
  self.dec_attn_mask[batch_idx].fill_(0)
  self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

@@ -275,9 +360,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  self.batch_size = self.rbln_config.model_cfg["batch_size"]
  self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
  self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
+ self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
+ # FIXME get kvcache_num_blocks from compiled results.
+ self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
  self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-
+ attn_impl = self.rbln_config.model_cfg["attn_impl"]
  main_input_name = self.main_input_name
+
  if self.rbln_config.model_cfg["use_inputs_embeds"]:
  main_input_name = "inputs_embeds"
  artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
@@ -291,7 +380,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  else:
  self.embed_tokens = None

+ # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
  dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+ block_tables = torch.zeros(
+ self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+ ).fill_(-1)
+ free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+
  self.prefill_decoder = RBLNRuntimeModel(
  runtime=self.model[0],
  main_input_name=main_input_name,
@@ -299,10 +394,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  phase="prefill",
  batch_size=self.batch_size,
  dec_attn_mask=dec_attn_mask,
+ block_tables=block_tables,
+ free_block_pool=free_block_pool,
+ kvcache_block_size=self.kvcache_block_size,
  vocab_size=self.config.vocab_size,
- max_seq_len=self.max_seq_len,
  prefill_chunk_size=self.prefill_chunk_size,
+ max_seq_len=self.max_seq_len,
  use_attention_mask=self.use_attention_mask,
+ attn_impl=attn_impl,
  )
  self.decoder = RBLNRuntimeModel(
  runtime=self.model[1],
@@ -311,7 +410,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  phase="decode",
  batch_size=self.batch_size,
  dec_attn_mask=dec_attn_mask,
+ block_tables=block_tables,
+ free_block_pool=free_block_pool,
+ kvcache_block_size=self.kvcache_block_size,
  use_attention_mask=self.use_attention_mask,
+ attn_impl=attn_impl,
  )

  @classmethod
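A short sketch, not part of the diff, of why the same `block_tables` tensor and `free_block_pool` deque are passed to both runtimes above: the prefill phase mutates them in place, so the decode phase sees every allocation without any explicit hand-off. All sizes below are assumptions.

    # Sketch with assumed sizes; mirrors the shared-resource wiring above.
    from collections import deque

    import torch

    batch_size, max_seq_len, kvcache_block_size, kvcache_num_blocks = 2, 16, 4, 8

    block_tables = torch.zeros(batch_size, max_seq_len // kvcache_block_size, dtype=torch.int16).fill_(-1)
    free_block_pool = deque(range(kvcache_num_blocks))

    prefill_state = {"block_tables": block_tables, "free_block_pool": free_block_pool}
    decode_state = {"block_tables": block_tables, "free_block_pool": free_block_pool}

    # Prefill claims a block in place ...
    prefill_state["block_tables"][0, 0] = prefill_state["free_block_pool"].popleft()
    # ... and decode observes it through the very same objects.
    assert decode_state["block_tables"][0, 0].item() == 0
    assert len(decode_state["free_block_pool"]) == kvcache_num_blocks - 1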
@@ -409,6 +512,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
  wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
  wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+ wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
  wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
  wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")

@@ -474,6 +578,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
  rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
  rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+ rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
  rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
  rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)

@@ -500,12 +605,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
  rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

- rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+ rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
  rbln_attn_impl=rbln_attn_impl,
  rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+ rbln_kvcache_block_size=rbln_kvcache_block_size,
  rbln_max_seq_len=rbln_max_seq_len,
  )

+ if rbln_kvcache_block_size is None:
+ if rbln_attn_impl == "eager":
+ rbln_kvcache_block_size = rbln_max_seq_len
+ else:
+ rbln_kvcache_block_size = rbln_kvcache_partition_len
+
+ # FIXME temporal num_blocks
+ rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+
  num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
  num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
  num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
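A worked example of the defaults introduced above, with all numbers assumed: under eager attention one block covers the whole sequence, under flash attention the block size matches the partition length, and the provisional block count budgets enough blocks for every sequence at full length. The same numbers also preview the paged `past_key_values_*` shape that appears in the next hunks.

    # Worked example with assumed values; the formulas match the diff above.
    rbln_max_seq_len = 4096
    rbln_batch_size = 2
    rbln_kvcache_partition_len = 512
    num_key_value_heads, head_dim = 8, 64  # illustrative only

    for rbln_attn_impl in ("eager", "flash_attn"):
        if rbln_attn_impl == "eager":
            rbln_kvcache_block_size = rbln_max_seq_len            # one block spans the whole sequence
        else:
            rbln_kvcache_block_size = rbln_kvcache_partition_len  # one block per flash-attention partition
        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
        print(rbln_attn_impl, rbln_kvcache_block_size, rbln_kvcache_num_blocks)
        # eager      -> block_size=4096, num_blocks=2  (one block per sequence)
        # flash_attn -> block_size=512,  num_blocks=16 (8 blocks per sequence)

        old_shape = [rbln_batch_size, num_key_value_heads, rbln_max_seq_len, head_dim]
        new_shape = [rbln_kvcache_num_blocks, num_key_value_heads, rbln_kvcache_block_size, head_dim]
        # Same number of cached token slots, just chopped into relocatable blocks
        # addressed through block_tables.
        assert rbln_batch_size * rbln_max_seq_len == rbln_kvcache_num_blocks * rbln_kvcache_block_size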
@@ -542,19 +657,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  if query_length > 1:
  input_info.extend(
  [
- ("batch_position", [], "int16"),
  ("query_position", [], "int16"),
  ]
  )

+ max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
+
+ if query_length > 1:
+ input_info.extend([("block_tables", [max_block_cnt], "int16")])
+ else:
+ input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
  input_info.extend(
  [
  (
  f"past_key_values_{i}",
  [
- rbln_batch_size,
+ rbln_kvcache_num_blocks,
  num_key_value_heads,
- rbln_max_seq_len,
+ rbln_kvcache_block_size,
  head_dim,
  ],
  "float32",
@@ -595,7 +716,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  "use_attention_mask": rbln_use_attention_mask,
  "use_inputs_embeds": rbln_use_inputs_embeds,
  "kvcache_partition_len": rbln_kvcache_partition_len,
+ "kvcache_block_size": rbln_kvcache_block_size,
  "attn_impl": rbln_attn_impl,
+ "kvcache_num_blocks": rbln_kvcache_num_blocks,
  }
  )

@@ -40,10 +40,15 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
  new_layers = []
  for layer in causal_lm.transformer.h:
  if self.attn_impl == "eager":
- new_self_attn = ExaoneAttention(layer.attn.attention, self.use_attention_mask)
+ new_self_attn = ExaoneAttention(
+ layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+ )
  elif self.attn_impl == "flash_attn":
  new_self_attn = ExaoneFlashAttention(
- layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+ layer.attn.attention,
+ kvcache_partition_len=self.kvcache_partition_len,
+ use_attention_mask=self.use_attention_mask,
+ kvcache_block_size=self.kvcache_block_size,
  )
  else:
  raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -33,10 +33,15 @@ class GemmaWrapper(DecoderOnlyWrapper):
  new_layers = []
  for layer in causal_lm.model.layers:
  if self.attn_impl == "eager":
- new_self_attn = DecoderOnlyAttention(layer.self_attn, self.use_attention_mask)
+ new_self_attn = DecoderOnlyAttention(
+ layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+ )
  elif self.attn_impl == "flash_attn":
  new_self_attn = DecoderOnlyFlashAttention(
- layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+ layer.self_attn,
+ kvcache_partition_len=self.kvcache_partition_len,
+ use_attention_mask=self.use_attention_mask,
+ kvcache_block_size=self.kvcache_block_size,
  )
  else:
  raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -37,7 +37,9 @@ class GPT2Wrapper(DecoderOnlyWrapper):
  raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
  new_layers = []
  for layer in causal_lm.transformer.h:
- new_self_attn = GPT2Attention(layer.attn, self.use_attention_mask)
+ new_self_attn = GPT2Attention(
+ layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+ )
  new_layer = GPT2Layer(layer, new_self_attn)
  new_layers.append(new_layer)
  new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -60,7 +60,9 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
  raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
  new_layers = []
  for layer in causal_lm.transformer.h:
- new_self_attn = MidmAttention(layer.attn, self.use_attention_mask)
+ new_self_attn = MidmAttention(
+ layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+ )
  new_layer = MidmLayer(layer, new_self_attn)
  new_layers.append(new_layer)
  new_model = MidmModel(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -36,7 +36,9 @@ class PhiWrapper(DecoderOnlyWrapper):
  new_layers = []
  for layer in causal_lm.model.layers:
  if self.attn_impl == "eager":
- new_self_attn = PhiAttention(layer.self_attn, self.use_attention_mask)
+ new_self_attn = PhiAttention(
+ layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+ )
  elif self.attn_impl == "flash_attn":
  raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
  else:
@@ -81,10 +83,10 @@ class PhiLayer(DecoderOnlyLayer):
  hidden_states: torch.Tensor,
  attention_mask: torch.Tensor,
  seq_positions: torch.LongTensor,
- batch_position: torch.Tensor,
  past_key_values: Tuple[Tuple[torch.Tensor]],
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
+ block_tables: Optional[torch.Tensor] = None,
  ):
  residual = hidden_states

@@ -94,10 +96,10 @@ class PhiLayer(DecoderOnlyLayer):
  hidden_states=hidden_states,
  attention_mask=attention_mask,
  seq_positions=seq_positions,
- batch_position=batch_position,
  past_key_values=past_key_values,
  cos=cos,
  sin=sin,
+ block_tables=block_tables,
  )

  feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
@@ -50,11 +50,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
  runtime: rebel.Runtime,
  batch_size: int,
  dec_max_seq_len: int,
+ use_attention_mask: Optional[bool] = None,
  **kwargs: Any,
  ) -> None:
  super().__init__(runtime, **kwargs)
  self.batch_size = batch_size
  self.dec_max_seq_len = dec_max_seq_len
+ self.use_attention_mask = use_attention_mask
+ self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)

  def forward(
  self,
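A sketch (batch size assumed, not from the diff) of the fallback introduced above: when no block table is supplied, the seq2seq decoder maps sequence b straight to block b, which is the slot its own encoder pass wrote into.

    # Sketch only; batch_size is an assumption.
    import torch

    batch_size = 3
    default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
    # tensor([[0],
    #         [1],
    #         [2]], dtype=torch.int16)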
@@ -62,6 +65,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
  attention_mask: Optional[torch.FloatTensor] = None,
  decoder_attention_mask: Optional[torch.BoolTensor] = None,
  cache_position: Optional[torch.Tensor] = None,
+ block_tables: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
  batch_size = decoder_input_ids.shape[0]
@@ -73,19 +77,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
  if batch_size != cache_position.shape[0]:
  raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

- for b_idx in range(self.batch_size):
- decoding_step = cache_position[b_idx].item()
- if not (0 <= decoding_step < self.dec_max_seq_len):
- raise ValueError(
- f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
- )
- decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+ if self.use_attention_mask:
+ for b_idx in range(self.batch_size):
+ decoding_step = cache_position[b_idx].item()
+ if not (0 <= decoding_step < self.dec_max_seq_len):
+ raise ValueError(
+ f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+ )
+ decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+ if block_tables is None:
+ block_tables = self.default_block_tables

  lm_logits = super().forward(
  decoder_input_ids,
- decoder_attention_mask,
+ decoder_attention_mask if self.use_attention_mask else None,
  attention_mask,
  cache_position,
+ block_tables,
  )

  return Seq2SeqLMOutput(logits=lm_logits)
@@ -110,12 +119,18 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  def __post_init__(self, **kwargs):
  batch_size = self.rbln_config.model_cfg["batch_size"]
  dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+ self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+
  self.encoder = RBLNRuntimeEncoder(
  runtime=self.model[0],
  main_input_name="input_ids",
  )
  self.decoder = RBLNRuntimeDecoder(
- runtime=self.model[1], main_input_name="input_ids", batch_size=batch_size, dec_max_seq_len=dec_max_seq_len
+ runtime=self.model[1],
+ main_input_name="input_ids",
+ batch_size=batch_size,
+ dec_max_seq_len=dec_max_seq_len,
+ use_attention_mask=self.use_attention_mask,
  )

  @classmethod
@@ -171,6 +186,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
  rbln_batch_size = rbln_kwargs.get("batch_size", None)
  rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+ rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+
+ if rbln_use_attention_mask is None:
+ rbln_use_attention_mask = False
+ rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+ if rbln_npu == "RBLN-CA02":
+ rbln_use_attention_mask = True

  n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
  n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -232,18 +254,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  ],
  "float32",
  ),
- ("batch_position", [], "int16"),
+ ("block_tables", [1], "int16"),
  ]

  dec_input_info = [
  ("input_ids", [rbln_batch_size, 1], "int64"),
- ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
  ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
  (
  "cache_position",
  [rbln_batch_size, 1],
  "int32",
  ),
+ (
+ "block_tables",
+ [rbln_batch_size, 1],
+ "int16",
+ ),
  ]
  dec_input_info.extend(
  [
@@ -275,6 +301,10 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  for i in range(n_layer * 2)
  ]
  )
+
+ if rbln_use_attention_mask:
+ dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
  enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
  dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

@@ -290,6 +320,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  "dec_max_seq_len": rbln_dec_max_seq_len,
  "batch_size": rbln_batch_size,
  "pad_token_id": rbln_pad_token_id,
+ "use_attention_mask": rbln_use_attention_mask,
  }
  )

@@ -400,9 +431,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  encoder_kwargs["output_attentions"] = False

  for b in range(batch_size):
- batch_position = torch.tensor(b, dtype=torch.int16)
+ block_tables = torch.tensor([b], dtype=torch.int16)
  encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
  encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
- model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_position=batch_position)
+ model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)

  return model_kwargs
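Not part of the diff: a sketch of how the per-sequence encoder calls above line up with the decoder's default block table; sequence b's encoder pass targets block b, which is exactly the row the decoder falls back to when no table is passed (batch size assumed).

    # Sketch only; ties the encoder loop above to default_block_tables.
    import torch

    batch_size = 2
    for b in range(batch_size):
        encoder_block_tables = torch.tensor([b], dtype=torch.int16)  # where sequence b's cache is written
    decoder_default = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
    assert decoder_default[1, 0].item() == 1  # sequence 1 reads back from block 1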