optimum-rbln 0.7.2rc2__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +8 -0
  4. optimum/rbln/diffusers/modeling_diffusers.py +103 -117
  5. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
  6. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
  7. optimum/rbln/diffusers/pipelines/__init__.py +8 -0
  8. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
  9. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
  10. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
  11. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
  12. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
  13. optimum/rbln/modeling.py +4 -1
  14. optimum/rbln/modeling_base.py +16 -3
  15. optimum/rbln/ops/__init__.py +6 -2
  16. optimum/rbln/ops/attn.py +94 -85
  17. optimum/rbln/ops/flash_attn.py +46 -25
  18. optimum/rbln/ops/kv_cache_update.py +4 -4
  19. optimum/rbln/transformers/modeling_generic.py +3 -3
  20. optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
  21. optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
  22. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
  24. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
  25. optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
  26. optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
  27. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
  28. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
  29. optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
  30. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
  31. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
  32. optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
  33. optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
  34. optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
  35. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
  36. optimum/rbln/utils/import_utils.py +7 -0
  37. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
  38. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
  39. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
  40. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
@@ -13,9 +13,11 @@
 # limitations under the License.
 
 import inspect
+import math
+from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
@@ -50,14 +52,28 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         phase: str,
         batch_size: int,
         dec_attn_mask: torch.Tensor,
+        block_tables: torch.Tensor,
+        free_block_pool: Deque,
+        kvcache_block_size: int,
+        use_attention_mask: bool,
+        attn_impl: str,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
 
+        # shared data structures between prefill and decode phase
+        self.use_attention_mask = use_attention_mask
+
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
+        self.block_tables = block_tables
+        self.free_block_pool = free_block_pool
+
+        self.kvcache_block_size = kvcache_block_size
+        self.empty_block = -1
+        self.attn_impl = attn_impl
 
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
@@ -68,6 +84,75 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
 
+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+        """
+        Manages and returns the KV cache block tables.
+        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+        Args:
+            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+        Returns:
+            torch.Tensor: Updated block tables.
+        """
+
+        NO_BLOCKS_ERROR = (
+            "No memory blocks are available for allocation. "
+            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
+            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
+            "Using vllm-rbln should fix this issue and enhance inference performance."
+        )
+
+        def update_block(batch_idx: int, block_idx: int):
+            """
+            If the block is empty (empty_block), allocates a block from the free_block_pool.
+            """
+            if self.block_tables[batch_idx][block_idx] == self.empty_block:
+                if self.free_block_pool:
+                    block = self.free_block_pool.popleft()
+                    self.block_tables[batch_idx][block_idx] = block
+                else:
+                    raise RuntimeError(NO_BLOCKS_ERROR)
+
+        def replace_empty_block(block_tables: torch.Tensor):
+            """
+            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+            """
+            if not torch.any(block_tables == self.empty_block):
+                return block_tables.clone()
+            elif self.free_block_pool:
+                _free_block = self.free_block_pool[0]
+                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
+            else:
+                raise RuntimeError(NO_BLOCKS_ERROR)
+
+        if self.phase == "prefill":
+            # Track previously used blocks and return them to the free_block_pool and
+            # reset the current batch's block table to empty blocks
+            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+            self.free_block_pool.extend(prev_blocks)
+            self.block_tables[batch_idx].fill_(self.empty_block)
+
+            # Get the start (s) and end (e) positions from cache_position and
+            # iterate over the cache positions to allocate necessary blocks
+            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+            for position in range(s, e + 1, self.kvcache_block_size):
+                block_idx = position // self.kvcache_block_size
+                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                update_block(batch_idx, block_idx)
+
+            return replace_empty_block(self.block_tables[batch_idx])
+        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+        else:
+            for b_idx in range(self.batch_size):
+                position = cache_position[b_idx][0].item()
+                block_idx = position // self.kvcache_block_size
+                update_block(b_idx, block_idx)
+
+            return replace_empty_block(self.block_tables)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
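The get_block_tables logic added above is the core of the paged KV-cache bookkeeping: each batch entry owns a logical block table initialized to -1, and physical blocks are drawn from a shared free pool on demand. The sketch below is not part of the diff; all names and sizes are illustrative, but it reproduces the same pattern in isolation.

from collections import deque

import torch

# Toy sizes; the real values come from the RBLN config.
EMPTY_BLOCK = -1
kvcache_block_size = 4
batch_size, max_seq_len, num_blocks = 2, 16, 8

block_tables = torch.full((batch_size, max_seq_len // kvcache_block_size), EMPTY_BLOCK, dtype=torch.int16)
free_block_pool = deque(range(num_blocks))

def allocate(batch_idx: int, position: int) -> None:
    # Allocate a physical block for `position` only if its logical slot is still empty.
    block_idx = position // kvcache_block_size
    if block_tables[batch_idx][block_idx] == EMPTY_BLOCK:
        block_tables[batch_idx][block_idx] = free_block_pool.popleft()

# Prefill of a 10-token prompt for batch 0 touches logical blocks 0..2.
for pos in range(0, 10, kvcache_block_size):
    allocate(0, pos)
print(block_tables[0].tolist())  # [0, 1, 2, -1]

# A later decode step at cache position 10 is already covered by block 2, so nothing changes.
allocate(0, 10)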
@@ -75,6 +160,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -86,19 +172,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds
 
+        if block_tables is None:
+            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
+            is_external_block_tables = False
+        else:
+            is_external_block_tables = True
+
         if self.phase == "decode":
             return self.decode_forward(
                 inputs,
                 cache_position,
+                block_tables,
+                is_external_block_tables,
                 attention_mask=attention_mask,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
 
     def decode_forward(
         self,
         inputs: torch.Tensor,
         cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
@@ -110,19 +206,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if attention_mask is None:
+        if self.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )
-                self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+                if is_external_block_tables:
+                    self.dec_attn_mask[b_idx].fill_(0)
+                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                else:
+                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+                attention_mask = self.dec_attn_mask
+
+            attention_mask = self.dec_attn_mask
 
         logits = super().forward(
             inputs,
-            self.dec_attn_mask if attention_mask is None else attention_mask,
             cache_position,
+            attention_mask if self.use_attention_mask else None,
+            block_tables,
         )
 
         return logits
@@ -133,6 +239,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -140,11 +248,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """
 
-        if batch_idx is None or batch_idx >= self.batch_size:
-            raise RuntimeError(
-                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
-            )
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -156,7 +259,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
         # Initialize attention mask for chunked processing
-        chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+        if self.use_attention_mask:
+            chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
 
         # Buffer for storing output logits
         out_buffers = [
@@ -195,28 +299,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
 
-            # Update attention mask to ensure proper causal behavior
-            if step >= self.prefill_chunk_size:
-                chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-            chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+            if self.use_attention_mask:
+                # Update attention mask to ensure proper causal behavior
+                if step >= self.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
-            # Define batch position and query position
-            batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+            # Define query position
             query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
 
             # Forward pass for the current chunk
             logits = super().forward(
                 input_chunk,
-                chunked_attention_mask,
                 cache_pos_chunk,
-                batch_position,
+                chunked_attention_mask if self.use_attention_mask else None,
                 query_position,
+                block_tables,
                 out=out_buffers,
             )
 
         # Update decoder attention mask with processed KV-cache length from prefill phase
-        self.dec_attn_mask[batch_idx].fill_(0)
-        self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+        if not is_external_block_tables and self.use_attention_mask:
+            self.dec_attn_mask[batch_idx].fill_(0)
+            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
         return logits
 
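Putting the prefill changes above together: the prompt is still processed in prefill_chunk_size windows, but the causal mask is only maintained when use_attention_mask is set, and the compiled graph now receives block_tables instead of batch_position. Below is a rough, self-contained sketch of the chunking loop itself (illustrative names only; run_chunk stands in for the compiled forward call).

import torch

prefill_chunk_size = 4
input_ids = torch.arange(1, 11).unsqueeze(0)  # a 10-token prompt
query_length = input_ids.shape[-1]

# Pad the prompt to a multiple of the chunk size, as the real prefill path does.
pad = -query_length % prefill_chunk_size
inputs = torch.nn.functional.pad(input_ids, (0, pad))
cache_position = torch.arange(inputs.shape[-1]).unsqueeze(0)

def run_chunk(chunk, cache_pos):
    print(f"tokens {chunk.tolist()} at cache positions {cache_pos.tolist()}")

for step in range(0, inputs.shape[-1], prefill_chunk_size):
    run_chunk(
        inputs[:, step : step + prefill_chunk_size],
        cache_position[:, step : step + prefill_chunk_size],
    )

# Only the logits at index (query_length - 1) % prefill_chunk_size of the final
# chunk are used to sample the next token, which is what query_position encodes.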
@@ -256,8 +361,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
-
+        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
+        # FIXME get kvcache_num_blocks from compiled results.
+        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
+        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
+        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name
+
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
             main_input_name = "inputs_embeds"
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
@@ -271,7 +381,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             self.embed_tokens = None
 
+        # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        block_tables = torch.zeros(
+            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+        ).fill_(-1)
+        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
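For a concrete picture of the shared state created here (toy numbers, not taken from the diff): with a batch of 2, a maximum sequence length of 8192, and 2048-token blocks, the runtimes start from a 2 x 4 table of -1 entries and a pool of free physical block indices.

from collections import deque

import torch

batch_size, max_seq_len, kvcache_block_size, kvcache_num_blocks = 2, 8192, 2048, 8

block_tables = torch.zeros(batch_size, max_seq_len // kvcache_block_size, dtype=torch.int16).fill_(-1)
free_block_pool = deque(range(kvcache_num_blocks))

print(block_tables.tolist())   # [[-1, -1, -1, -1], [-1, -1, -1, -1]]
print(list(free_block_pool))   # [0, 1, 2, 3, 4, 5, 6, 7]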
@@ -279,9 +395,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="prefill",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            max_seq_len=self.max_seq_len,
             prefill_chunk_size=self.prefill_chunk_size,
+            max_seq_len=self.max_seq_len,
+            use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
@@ -290,6 +411,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="decode",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
+            use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
 
     @classmethod
@@ -363,7 +489,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
 
-        val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
+        val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
@@ -387,7 +513,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
         wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
         wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
+        wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
 
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
@@ -438,6 +566,71 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return compile_model(quantize_config=quantize_config)
 
+    @classmethod
+    def get_maximum_num_blocks(
+        cls,
+        config: PretrainedConfig,
+        tensor_parallel_size: int,
+        kvcache_block_size: int,
+        nbits_per_param: int,
+        n_model_params: int,
+    ) -> int:
+        def align(x: int, nbytes: int) -> int:
+            return int(math.ceil(x / nbytes) * nbytes)
+
+        def align_2MB(x: int) -> int:
+            return align(x, 2 * 1024 * 1024)
+
+        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
+        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
+        vocab_size = config.vocab_size
+        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
+        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+        # TODO(jongho): Update if target npu is REBEL.
+        ATOM_DRAM_NBYTES = 16 * 2**30
+        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
+        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
+
+        # Get estimated kernel size (approximated)
+        lm_heads_params = align(vocab_size, 64) * hidden_size
+        lm_heads_nbytes = (
+            align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+        )
+        params = n_model_params - lm_heads_params
+        layer_nbytes = (
+            align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+            * num_layers
+            * tensor_parallel_size
+        )
+        kernel_size = layer_nbytes + lm_heads_nbytes
+
+        available_dram -= kernel_size
+
+        # TODO: Accurate buffer estimation
+        buffer = 2**30  # 1GB Buffer
+        if tensor_parallel_size <= 4:
+            buffer /= 4
+
+        available_dram -= buffer
+
+        # Estimate nbytes per a single kvcache block
+        nbytes_per_block = (
+            align_2MB(
+                kvcache_block_size
+                * head_dim
+                * math.ceil(num_key_value_heads / tensor_parallel_size)  # Shard
+                * 2  # (fp16)
+            )
+            * num_layers
+            * 2  # (k, v)
+            * tensor_parallel_size
+        )
+        n_blocks = available_dram // nbytes_per_block
+
+        return n_blocks, nbytes_per_block
+
     @classmethod
     def _get_rbln_config(
         cls,
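As a rough worked example of the estimate above (all model numbers below are made up, and the kernel-size and buffer subtractions are skipped for brevity), the per-block footprint and the resulting upper bound on block count can be computed like this:

import math

def align(x, nbytes):
    return int(math.ceil(x / nbytes) * nbytes)

def align_2MB(x):
    return align(x, 2 * 1024 * 1024)

tensor_parallel_size = 4
kvcache_block_size = 4096
num_layers, num_key_value_heads, head_dim = 32, 8, 128  # illustrative GQA config

available_dram = tensor_parallel_size * (16 * 2**30 - 288 * 2**20)

nbytes_per_block = (
    align_2MB(
        kvcache_block_size
        * head_dim
        * math.ceil(num_key_value_heads / tensor_parallel_size)  # heads per shard
        * 2  # fp16
    )
    * num_layers
    * 2  # key and value caches
    * tensor_parallel_size
)
print(available_dram // nbytes_per_block)  # upper bound before kernel/buffer deductions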
@@ -448,11 +641,19 @@
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
         rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
         rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
         rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
 
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
+
         if rbln_prefill_chunk_size is None:
             rbln_prefill_chunk_size = 128
         elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
@@ -470,12 +671,42 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
 
-        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
             rbln_attn_impl=rbln_attn_impl,
             rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_kvcache_block_size=rbln_kvcache_block_size,
             rbln_max_seq_len=rbln_max_seq_len,
         )
 
+        if rbln_kvcache_block_size is None:
+            if rbln_attn_impl == "eager":
+                rbln_kvcache_block_size = rbln_max_seq_len
+            else:
+                rbln_kvcache_block_size = rbln_kvcache_partition_len
+
+        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+        if rbln_attn_impl == "flash_attn":
+            max_num_blocks, _ = cls.get_maximum_num_blocks(
+                config=model_config,
+                tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
+                kvcache_block_size=rbln_kvcache_block_size,
+                nbits_per_param=16 if rbln_quantization is None else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=rbln_kwargs["n_model_params"],
+            )
+            rbln_kvcache_num_blocks = min(rbln_kvcache_num_blocks, max_num_blocks)
+
+            required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
+            if rbln_kvcache_num_blocks < required_blocks:
+                rbln_kvcache_num_blocks = required_blocks
+
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+
+        if rbln_kvcache_num_blocks < rbln_batch_size:
+            raise RuntimeError(
+                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                "Ensure the number of blocks is at least equal to the batch size."
+            )
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
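To illustrate the defaulting above with toy values (the flash-attention branch is additionally capped by get_maximum_num_blocks and floored at max_seq_len // block_size + 1, which is omitted here):

rbln_max_seq_len, rbln_batch_size = 8192, 2
rbln_attn_impl, rbln_kvcache_partition_len = "flash_attn", 4096

# Block size defaults to the whole sequence for eager attention,
# or to one partition for flash attention.
rbln_kvcache_block_size = rbln_max_seq_len if rbln_attn_impl == "eager" else rbln_kvcache_partition_len

# One full sequence per batch entry, expressed in blocks.
rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
print(rbln_kvcache_block_size, rbln_kvcache_num_blocks)  # 4096 4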
@@ -495,29 +726,42 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
             input_info = [
                 main_input,
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                 (
                     "cache_position",
                     [batch_size, query_length],
                     "int32",
                 ),
             ]
+
+            if rbln_use_attention_mask:
+                input_info.extend(
+                    [
+                        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
+                    ]
+                )
+
             if query_length > 1:
                 input_info.extend(
                     [
-                        ("batch_position", [], "int16"),
                         ("query_position", [], "int16"),
                     ]
                 )
 
+            max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
+
+            if query_length > 1:
+                input_info.extend([("block_tables", [max_block_cnt], "int16")])
+            else:
+                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
             input_info.extend(
                 [
                     (
                         f"past_key_values_{i}",
                         [
-                            rbln_batch_size,
+                            rbln_kvcache_num_blocks,
                             num_key_value_heads,
-                            rbln_max_seq_len,
+                            rbln_kvcache_block_size,
                             head_dim,
                         ],
                         "float32",
@@ -555,9 +799,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "max_seq_len": rbln_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "prefill_chunk_size": rbln_prefill_chunk_size,
+                "use_attention_mask": rbln_use_attention_mask,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
                 "kvcache_partition_len": rbln_kvcache_partition_len,
+                "kvcache_block_size": rbln_kvcache_block_size,
                 "attn_impl": rbln_attn_impl,
+                "kvcache_num_blocks": rbln_kvcache_num_blocks,
             }
         )
 
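These new model_cfg entries are derived from the rbln_-prefixed keyword arguments read earlier in _get_rbln_config. A call that exercises them might look like the sketch below; the argument names mirror the rbln_kwargs keys above and should be treated as assumptions rather than a documented example.

from optimum.rbln import RBLNLlamaForCausalLM

# Hypothetical export call; kwargs correspond to the rbln_kwargs keys read above.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=8192,
    rbln_attn_impl="flash_attn",
    rbln_kvcache_partition_len=4096,
)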
@@ -36,21 +36,28 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(layer.attn.attention)
+                new_self_attn = ExaoneAttention(
+                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                    layer.attn.attention,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
 
             new_layer = ExaoneLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = ExaoneModel(causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len)
+        new_model = ExaoneModel(
+            causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
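The Exaone change above, and the Gemma, GPT-2, and Midm changes below, all follow the same conversion pattern: each Hugging Face decoder layer keeps its weights, but its attention module is swapped for an RBLN-specific one that now also carries the attention-mask policy and the KV-cache block size. A toy illustration of the pattern follows; the class and function names here are made up.

class ToyEagerAttention:
    def __init__(self, hf_attn, use_attention_mask, kvcache_block_size):
        self.hf_attn = hf_attn
        self.use_attention_mask = use_attention_mask
        self.kvcache_block_size = kvcache_block_size

class ToyFlashAttention(ToyEagerAttention):
    def __init__(self, hf_attn, kvcache_partition_len, use_attention_mask, kvcache_block_size):
        super().__init__(hf_attn, use_attention_mask, kvcache_block_size)
        self.kvcache_partition_len = kvcache_partition_len

def convert_attentions(hf_attns, attn_impl, use_attention_mask, kvcache_block_size, kvcache_partition_len=None):
    converted = []
    for hf_attn in hf_attns:
        if attn_impl == "eager":
            converted.append(ToyEagerAttention(hf_attn, use_attention_mask, kvcache_block_size))
        elif attn_impl == "flash_attn":
            converted.append(
                ToyFlashAttention(hf_attn, kvcache_partition_len, use_attention_mask, kvcache_block_size)
            )
        else:
            raise NotImplementedError(f"Unknown attn: {attn_impl}")
    return converted

print(len(convert_attentions(["attn0", "attn1"], "eager", True, 4096)))  # 2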
@@ -29,20 +29,27 @@ if TYPE_CHECKING:
 
 
 class GemmaWrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GemmaModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
+        new_model = GemmaModel(
+            causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
@@ -32,15 +32,17 @@ if TYPE_CHECKING:
 
 
 class GPT2Wrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel", max_seq_len: int):
         if self.attn_impl != "eager":
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = GPT2Attention(layer.attn)
+            new_self_attn = GPT2Attention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(causal_lm.transformer, new_layers)
+        new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
@@ -55,15 +55,17 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
         self.config.partial_rotary_factor = self.config.rotary_percentage
         return super().get_rotary_emb(max_seq_len=max_seq_len)
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel", max_seq_len: int):
         if self.attn_impl != "eager":
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = MidmAttention(layer.attn)
+            new_self_attn = MidmAttention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
            new_layer = MidmLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_model = MidmModel(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm