optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,9 +13,10 @@
 # limitations under the License.

 import inspect
+from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union

 import rebel
 import torch
@@ -50,17 +51,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         phase: str,
         batch_size: int,
         dec_attn_mask: torch.Tensor,
+        block_tables: torch.Tensor,
+        free_block_pool: Deque,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
         use_attention_mask: bool,
+        attn_impl: str,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size

+        # shared data structures between prefill and decode phase
         self.use_attention_mask = use_attention_mask

         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
+        self.block_tables = block_tables
+        self.free_block_pool = free_block_pool
+
+        self.kvcache_block_size = kvcache_block_size
+        self.empty_block = kvcache_num_blocks - 1
+        self.attn_impl = attn_impl

         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
@@ -71,6 +84,72 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
             )

+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+        """
+        Manages and returns the KV cache block tables.
+        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+        Args:
+            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+        Returns:
+            torch.Tensor: Updated block tables.
+        """
+
+        def update_block(batch_idx, block_idx):
+            """
+            Helper function to update the block table for a given batch index and block index.
+            If the block is empty (empty_block), allocates a block from the free_block_pool.
+
+            Args:
+                batch_idx (int): Batch index.
+                block_idx (int): Block index.
+
+            Raises:
+                RuntimeError: Raised if no available blocks are found in the free_block_pool.
+            """
+            if self.block_tables[batch_idx][block_idx] == self.empty_block:
+                if self.free_block_pool:
+                    block = self.free_block_pool.popleft()
+                    self.block_tables[batch_idx][block_idx] = block
+                else:
+                    raise RuntimeError("Not available blocks")
+
+        if self.attn_impl == "eager":
+            if self.phase == "prefill":
+                return self.block_tables[batch_idx]
+            else:
+                return self.block_tables
+        # Case for 'flash_attn' attention implementation
+        else:
+            if self.phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.empty_block)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.kvcache_block_size):
+                    block_idx = position // self.kvcache_block_size
+                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                    update_block(batch_idx, block_idx)
+
+                return self.block_tables[batch_idx]
+
+            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(self.batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.kvcache_block_size
+                    update_block(b_idx, block_idx)
+
+                return self.block_tables
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
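The hunk above introduces paged KV-cache bookkeeping: each sequence's cache is split into fixed-size blocks, `block_tables` maps each logical block index to a physical block ID, and a `deque` of free block IDs acts as the allocator, with the last block ID reserved as the "empty" sentinel. The following standalone sketch illustrates that allocation pattern; the sizes and the `allocate` helper are illustrative assumptions, not code from the package.

from collections import deque

import torch

# Simplified sketch of the block-table bookkeeping introduced above (not the packaged code).
# Assumed sizes: 2 sequences, max_seq_len 32, kvcache_block_size 8 -> 4 logical blocks per sequence.
batch_size, max_seq_len, block_size, num_blocks = 2, 32, 8, 8
empty_block = num_blocks - 1  # sentinel meaning "no physical block assigned yet"

block_tables = torch.full((batch_size, max_seq_len // block_size), empty_block, dtype=torch.int16)
free_block_pool = deque(range(num_blocks - 1))  # the last block ID is reserved as the sentinel

def allocate(batch_idx: int, position: int) -> None:
    """Map the logical block covering `position` to a physical block, if it is not mapped yet."""
    block_idx = position // block_size
    if block_tables[batch_idx, block_idx] == empty_block:
        if not free_block_pool:
            raise RuntimeError("no free KV-cache blocks left")
        block_tables[batch_idx, block_idx] = free_block_pool.popleft()

# Prefill: sequence 0 holds 20 tokens, so it needs the blocks covering positions 0, 8 and 16.
for pos in range(0, 20, block_size):
    allocate(0, pos)

# Decode: each step only touches the block covering the current cache position (here a no-op,
# since position 20 falls in a block that prefill already mapped).
allocate(0, 20)

print(block_tables)     # row 0 now maps three logical blocks to physical blocks 0, 1, 2
print(free_block_pool)  # remaining physical blocks available for other sequences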
@@ -78,6 +157,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -89,19 +169,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds

+        if block_tables is None:
+            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
+            is_external_block_tables = False
+        else:
+            is_external_block_tables = True
+
         if self.phase == "decode":
             return self.decode_forward(
                 inputs,
                 cache_position,
+                block_tables,
+                is_external_block_tables,
                 attention_mask=attention_mask,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)

     def decode_forward(
         self,
         inputs: torch.Tensor,
         cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
@@ -120,7 +210,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )
-                self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+                if is_external_block_tables:
+                    self.dec_attn_mask[b_idx].fill_(0)
+                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                else:
+                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+            attention_mask = self.dec_attn_mask

         attention_mask = self.dec_attn_mask

@@ -128,6 +225,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
             attention_mask if self.use_attention_mask else None,
+            block_tables,
         )

         return logits
@@ -138,6 +236,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -145,11 +245,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """

-        if batch_idx is None or batch_idx >= self.batch_size:
-            raise RuntimeError(
-                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
-            )
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -207,33 +302,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
                 chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

-            # Define batch position and query position
-            batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+            # Define query position
             query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)

-            if self.use_attention_mask:
-                args = (
-                    input_chunk,
-                    cache_pos_chunk,
-                    chunked_attention_mask,
-                    batch_position,
-                    query_position,
-                )
-            else:
-                args = (
-                    input_chunk,
-                    cache_pos_chunk,
-                    batch_position,
-                    query_position,
-                )
             # Forward pass for the current chunk
             logits = super().forward(
-                *args,
+                input_chunk,
+                cache_pos_chunk,
+                chunked_attention_mask if self.use_attention_mask else None,
+                query_position,
+                block_tables,
                 out=out_buffers,
             )

-        if self.use_attention_mask:
-            # Update decoder attention mask with processed KV-cache length from prefill phase
+        # Update decoder attention mask with processed KV-cache length from prefill phase
+        if not is_external_block_tables and self.use_attention_mask:
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

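For orientation, the prefill loop above walks the padded query in `prefill_chunk_size` steps and passes `query_position`, the offset of the last valid token inside the final chunk. A tiny sketch of that split, with assumed values rather than values from the package:

# Sketch of the chunked-prefill split (prefill_chunk_size and query_length are assumed values).
prefill_chunk_size = 128
query_length = 300

# Pad the query up to a multiple of the chunk size, then walk it chunk by chunk.
padded = query_length + (-query_length % prefill_chunk_size)    # 384
chunks = [(step, step + prefill_chunk_size) for step in range(0, padded, prefill_chunk_size)]
print(chunks)                                   # [(0, 128), (128, 256), (256, 384)]
print((query_length - 1) % prefill_chunk_size)  # 43 -> query_position of the last valid token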
@@ -275,9 +358,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
+        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
+        # FIXME get kvcache_num_blocks from compiled results.
+        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
         self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-
+        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name
+
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
             main_input_name = "inputs_embeds"
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
@@ -291,7 +378,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             self.embed_tokens = None

+        # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        if attn_impl == "eager":
+            block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).reshape(self.batch_size, 1)
+            free_block_pool = None
+        else:
+            block_tables = torch.zeros(
+                self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+            ).fill_(self.kvcache_num_blocks - 1)
+            free_block_pool = deque(x for x in range(self.kvcache_num_blocks - 1))
+
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
@@ -299,10 +396,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="prefill",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
+            kvcache_num_blocks=self.kvcache_num_blocks,
             vocab_size=self.config.vocab_size,
-            max_seq_len=self.max_seq_len,
             prefill_chunk_size=self.prefill_chunk_size,
+            max_seq_len=self.max_seq_len,
             use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
@@ -311,7 +413,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="decode",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
+            kvcache_num_blocks=self.kvcache_num_blocks,
             use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )

     @classmethod
@@ -409,6 +516,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
         wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
         wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
         wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")

@@ -474,6 +582,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
         rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
         rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
         rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)

@@ -500,12 +609,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

-        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
             rbln_attn_impl=rbln_attn_impl,
             rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_kvcache_block_size=rbln_kvcache_block_size,
             rbln_max_seq_len=rbln_max_seq_len,
         )

+        if rbln_kvcache_block_size is None:
+            if rbln_attn_impl == "eager":
+                rbln_kvcache_block_size = rbln_max_seq_len
+            else:
+                rbln_kvcache_block_size = rbln_kvcache_partition_len
+
+        # FIXME temporal num_blocks
+        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
@@ -542,19 +661,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             if query_length > 1:
                 input_info.extend(
                     [
-                        ("batch_position", [], "int16"),
                         ("query_position", [], "int16"),
                     ]
                 )

+            max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
+
+            if query_length > 1:
+                input_info.extend([("block_tables", [max_block_cnt], "int16")])
+            else:
+                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
             input_info.extend(
                 [
                     (
                         f"past_key_values_{i}",
                         [
-                            rbln_batch_size,
+                            rbln_kvcache_num_blocks,
                             num_key_value_heads,
-                            rbln_max_seq_len,
+                            rbln_kvcache_block_size,
                             head_dim,
                         ],
                         "float32",
@@ -595,7 +720,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "use_attention_mask": rbln_use_attention_mask,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
                 "kvcache_partition_len": rbln_kvcache_partition_len,
+                "kvcache_block_size": rbln_kvcache_block_size,
                 "attn_impl": rbln_attn_impl,
+                "kvcache_num_blocks": rbln_kvcache_num_blocks,
             }
         )

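With these changes, each per-layer KV cache is laid out as [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim] instead of [batch_size, num_key_value_heads, max_seq_len, head_dim], and rbln_kvcache_num_blocks is provisionally derived as (max_seq_len // block_size) * batch_size. A small worked example of that arithmetic, with assumed numbers rather than values from the package:

# Illustrative numbers only (not from the package): how the new per-layer KV-cache shape is derived.
batch_size, max_seq_len = 2, 4096
kvcache_block_size = 2048                                                # flash_attn: defaults to the partition length
kvcache_num_blocks = (max_seq_len // kvcache_block_size) * batch_size    # (4096 // 2048) * 2 = 4
num_key_value_heads, head_dim = 8, 128

old_shape = [batch_size, num_key_value_heads, max_seq_len, head_dim]                 # [2, 8, 4096, 128]
new_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]  # [4, 8, 2048, 128]
assert batch_size * max_seq_len == kvcache_num_blocks * kvcache_block_size           # same total capacity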
@@ -40,10 +40,15 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(layer.attn.attention, self.use_attention_mask)
+                new_self_attn = ExaoneAttention(
+                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                    layer.attn.attention,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -33,10 +33,15 @@ class GemmaWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn, self.use_attention_mask)
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -37,7 +37,9 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = GPT2Attention(layer.attn, self.use_attention_mask)
+            new_self_attn = GPT2Attention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -60,7 +60,9 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = MidmAttention(layer.attn, self.use_attention_mask)
+            new_self_attn = MidmAttention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = MidmLayer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = MidmModel(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -36,7 +36,9 @@ class PhiWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = PhiAttention(layer.self_attn, self.use_attention_mask)
+                new_self_attn = PhiAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
             else:
@@ -81,10 +83,10 @@ class PhiLayer(DecoderOnlyLayer):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states

@@ -94,10 +96,10 @@ class PhiLayer(DecoderOnlyLayer):
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )

         feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
@@ -50,11 +50,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         runtime: rebel.Runtime,
         batch_size: int,
         dec_max_seq_len: int,
+        use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.batch_size = batch_size
         self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)

     def forward(
         self,
@@ -62,6 +65,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         batch_size = decoder_input_ids.shape[0]
@@ -73,19 +77,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

-        for b_idx in range(self.batch_size):
-            decoding_step = cache_position[b_idx].item()
-            if not (0 <= decoding_step < self.dec_max_seq_len):
-                raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        if block_tables is None:
+            block_tables = self.default_block_tables

         lm_logits = super().forward(
             decoder_input_ids,
-            decoder_attention_mask,
+            decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
+            block_tables,
         )

         return Seq2SeqLMOutput(logits=lm_logits)
@@ -110,12 +119,18 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+
         self.encoder = RBLNRuntimeEncoder(
             runtime=self.model[0],
             main_input_name="input_ids",
         )
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1], main_input_name="input_ids", batch_size=batch_size, dec_max_seq_len=dec_max_seq_len
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=batch_size,
+            dec_max_seq_len=dec_max_seq_len,
+            use_attention_mask=self.use_attention_mask,
         )

     @classmethod
@@ -171,6 +186,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True

         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -232,18 +254,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 ],
                 "float32",
             ),
-            ("batch_position", [], "int16"),
+            ("block_tables", [1], "int16"),
         ]

         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
                 [rbln_batch_size, 1],
                 "int32",
             ),
+            (
+                "block_tables",
+                [rbln_batch_size, 1],
+                "int16",
+            ),
         ]
         dec_input_info.extend(
             [
@@ -275,6 +301,10 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 for i in range(n_layer * 2)
             ]
         )
+
+        if rbln_use_attention_mask:
+            dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

@@ -290,6 +320,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 "dec_max_seq_len": rbln_dec_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "pad_token_id": rbln_pad_token_id,
+                "use_attention_mask": rbln_use_attention_mask,
             }
         )

@@ -400,9 +431,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         encoder_kwargs["output_attentions"] = False

         for b in range(batch_size):
-            batch_position = torch.tensor(b, dtype=torch.int16)
+            block_tables = torch.tensor([b], dtype=torch.int16)
            encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
            encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_position=batch_position)
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)

         return model_kwargs
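In the Seq2Seq path, the scalar batch_position input is likewise replaced by a small block_tables tensor: the encoder is called with a one-element table per sequence, and the decoder falls back to an identity mapping of batch slots to blocks when no table is passed in. A minimal sketch of those defaults (the batch size here is an assumption):

import torch

# Minimal sketch of the default seq2seq block tables (illustrative only).
batch_size = 3

# Decoder default: batch slot i reads/writes physical block i -> shape [batch_size, 1].
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)

# Encoder prefill: one single-element table per sequence, replacing the old scalar batch_position.
encoder_block_tables = [torch.tensor([b], dtype=torch.int16) for b in range(batch_size)]

print(default_block_tables.tolist())  # [[0], [1], [2]]
print(encoder_block_tables[1])        # tensor([1], dtype=torch.int16)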