optimum-rbln 0.7.4a9__py3-none-any.whl → 0.7.5a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +11 -7
  4. optimum/rbln/diffusers/models/controlnet.py +1 -1
  5. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  6. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
  7. optimum/rbln/modeling.py +7 -5
  8. optimum/rbln/ops/__init__.py +1 -0
  9. optimum/rbln/ops/attn.py +10 -0
  10. optimum/rbln/ops/flash_attn.py +8 -0
  11. optimum/rbln/ops/sliding_window_attn.py +111 -0
  12. optimum/rbln/transformers/__init__.py +22 -3
  13. optimum/rbln/transformers/models/__init__.py +23 -0
  14. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  15. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
  16. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
  17. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +42 -6
  18. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
  19. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +251 -135
  20. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
  21. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  22. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  23. optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
  24. optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
  25. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  26. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
  27. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
  28. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
  29. optimum/rbln/transformers/models/siglip/__init__.py +20 -0
  30. optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
  31. optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
  33. optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
  34. optimum/rbln/utils/import_utils.py +23 -6
  35. optimum/rbln/utils/submodule.py +13 -1
  36. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
  37. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +39 -28
  38. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -30,7 +30,7 @@ from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...utils.rbln_quantization import QuantizationManager
+from ...utils.rbln_quantization import prepare_model_for_quantization
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
@@ -59,6 +59,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         kvcache_block_size: int,
         use_attention_mask: bool,
         attn_impl: str,
+        use_position_ids: bool,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
@@ -72,6 +73,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
+        self.use_position_ids = use_position_ids
 
         self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
@@ -164,6 +166,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -189,10 +192,16 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 is_external_block_tables,
                 attention_mask=attention_mask,
                 position_embed=position_embed,
+                position_ids=position_ids,
             )
         else:
             return self.prefill_forward(
-                inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+                inputs,
+                cache_position,
+                attention_mask,
+                batch_idx,
+                block_tables,
+                position_embed=position_embed,
             )
 
     def decode_forward(
@@ -203,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -229,32 +239,33 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
             attention_mask = self.dec_attn_mask
 
+        if self.batch_size < block_tables.shape[0]:
+            block_tables = block_tables[: self.batch_size]
+
+        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
+            attention_mask = attention_mask[: self.batch_size]
+
         logits = super().forward(
             inputs,
             cache_position,
-            attention_mask if self.use_attention_mask else None,
             block_tables,
             position_embed,
+            attention_mask if self.use_attention_mask else None,
+            position_ids if self.use_position_ids else None,
         )
 
-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits)
 
-    def prefill_forward(
+    def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
+        cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ):
         """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        Prepare inputs for prefill phase.
         """
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -270,8 +281,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
 
         # Initialize attention mask for chunked processing
-        if self.use_attention_mask:
-            chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+        chunked_attention_mask = (
+            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+            if self.use_attention_mask
+            else None
+        )
 
         # Buffer for storing output logits
         out_buffers = [
@@ -282,36 +296,80 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
         ]
 
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
-            # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-            if (step + self.prefill_chunk_size) > query_length:
-                padding_size = step + self.prefill_chunk_size - query_length
-                # inputs_embeds
-                if inputs.dim() == 3:
-                    inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-                # inputs_ids
-                else:
-                    inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+        if query_length % self.prefill_chunk_size != 0:
+            padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+            # inputs_embeds
+            if inputs.dim() == 3:
+                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+            # inputs_ids
+            else:
+                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
 
-            cache_position = torch.cat(
-                [
-                    cache_position,
-                    torch.arange(
-                        query_length,
-                        step + self.prefill_chunk_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
+            cache_position = torch.cat(
+                [
+                    cache_position,
+                    torch.arange(
+                        query_length,
+                        query_length + padding_size,
+                        dtype=torch.int32,
+                    ).unsqueeze(0),
+                ],
+                dim=-1,
+            )
+
+            if position_embed is not None:
+                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
+        # Overwrite position_ids and padded_cache_lengths
+        position_ids = None
+        padded_cache_lengths = 0
+
+        return (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        )
 
-            if position_embed is not None:
-                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        ) = self._prepare_prefill_inputs(inputs, cache_position, attention_mask, position_embed)
 
+        # Process input in chunks of size `prefill_chunk_size`
+        for step in range(0, query_length, self.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            position_ids_chunk = (
+                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+            )
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
 
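For reference, the padding rule that moved into `_prepare_prefill_inputs` can be checked in isolation. A minimal sketch, with an illustrative chunk size (the real value comes from the runtime's `prefill_chunk_size`):

    # Pad the prompt up to the next multiple of the prefill chunk size.
    prefill_chunk_size = 128   # illustrative
    query_length = 300         # raw prompt length

    if query_length % prefill_chunk_size != 0:
        padding_size = prefill_chunk_size - query_length % prefill_chunk_size
    else:
        padding_size = 0

    print(padding_size)                                            # 84
    print((query_length + padding_size) // prefill_chunk_size)     # 3 chunks of 128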
@@ -328,9 +386,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
-                chunked_attention_mask if self.use_attention_mask else None,
-                query_position,
                 block_tables,
+                query_position,
+                chunked_attention_mask if self.use_attention_mask else None,
+                position_ids_chunk if position_ids is not None else None,
                 position_embed_chunk if position_embed is not None else None,
                 out=out_buffers,
             )
@@ -340,13 +399,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
 
 
 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
     generate_idx: torch.Tensor = None
+    padded_cache_lengths: int = None
 
 
 class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
@@ -416,20 +476,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=self.rbln_config.max_seq_len,
             use_attention_mask=self.rbln_config.use_attention_mask,
             attn_impl=self.rbln_config.attn_impl,
+            use_position_ids=self.rbln_config.use_position_ids,
         )
-        self.decoder = RBLNRuntimeModel(
-            runtime=self.model[1],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
-            phase="decode",
-            batch_size=self.rbln_config.batch_size,
-            dec_attn_mask=dec_attn_mask,
-            block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-        )
+        self.decoders = {}
+        for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+            self.decoders[batch_size] = RBLNRuntimeModel(
+                runtime=self.model[i + 1],
+                main_input_name=main_input_name,
+                embed_tokens=self.embed_tokens,
+                phase="decode",
+                batch_size=batch_size,
+                dec_attn_mask=dec_attn_mask,
+                block_tables=block_tables,
+                free_block_pool=free_block_pool,
+                kvcache_block_size=self.rbln_config.kvcache_block_size,
+                use_attention_mask=self.rbln_config.use_attention_mask,
+                attn_impl=self.rbln_config.attn_impl,
+                use_position_ids=self.rbln_config.use_position_ids,
+            )
+
+        # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+        self.decoder = self.decoders[self.rbln_config.batch_size]
 
     @classmethod
     def save_torch_artifacts(
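With this change the model holds one decoder runtime per entry in `decoder_batch_sizes`, keyed by batch size, while the prefill runtime stays at index 0. A hypothetical loading call is sketched below; `decoder_batch_sizes` is the config field introduced in this release, but the model class and the exact `from_pretrained` keyword layout are illustrative and may differ:

    from optimum.rbln import RBLNLlamaForCausalLM

    # Illustrative: compile one prefill graph plus two decoder graphs
    # (batch 4 for steady-state throughput, batch 1 for stragglers).
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        export=True,
        rbln_config={"batch_size": 4, "decoder_batch_sizes": [4, 1]},
    )
    # model.decoders is keyed by batch size; model.decoder aliases the entry
    # whose batch size equals rbln_config.batch_size.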
@@ -471,8 +538,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import prepare_model_for_quantization
-
         kwargs = cls.update_kwargs(kwargs)
 
         if config is None:
@@ -489,8 +554,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)
 
-        prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))
-
+        model = prepare_model_for_quantization(
+            model,
+            model_id,
+            kwargs.get("num_hidden_layers"),
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+        )
         return model
 
     def __getattr__(self, __name: str) -> Any:
@@ -517,11 +590,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_pytorch_model(
         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
     ) -> "PreTrainedModel":
-        if (
-            rbln_config is not None
-            and "format" in rbln_config.quantization
-            and rbln_config.quantization["format"] == "rbln"
-        ):
+        if rbln_config and rbln_config.quantization:
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)
@@ -537,6 +606,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "kvcache_block_size": rbln_config.kvcache_block_size,
             "use_rotary_emb": cls._use_rotary_emb,
             "use_attention_mask": rbln_config.use_attention_mask,
+            "use_position_ids": rbln_config.use_position_ids,
+            "use_inputs_embeds": rbln_config.use_inputs_embeds,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
@@ -547,7 +618,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         rbln_compile_configs = rbln_config.compile_cfgs
         prefill_compile_config = rbln_compile_configs[0]
-        dec_compile_config = rbln_compile_configs[1]
 
         context = CompileContext(use_weight_sharing=True)
 
@@ -562,33 +632,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 static_tensors[name] = tensor
                 context.mark_static_address(tensor)
 
-        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-
-        @QuantizationManager.with_quantization_env
-        def compile_model(*args, **kwargs):
+        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
             try:
+                if quantization:
+                    quantization.maybe_set_quantization_env()
                 original_linear = torch.nn.functional.linear
                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                wrapped_model.phase = "prefill"
-                compiled_prefill = RBLNModel.compile(
-                    wrapped_model,
-                    prefill_compile_config,
-                    example_inputs=prefill_example_inputs,
-                    compile_context=context,
-                )
-
-                wrapped_model.phase = "decode"
-                compiled_decoder = RBLNModel.compile(
+                compiled_model = RBLNModel.compile(
                     wrapped_model,
-                    dec_compile_config,
-                    example_inputs=dec_example_inputs,
-                    compile_context=context,
+                    compile_config,
+                    example_inputs=example_inputs,
+                    compile_context=compile_context,
                 )
-                return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+                return compiled_model
             finally:
                 torch.nn.functional.linear = original_linear
+                if quantization:
+                    quantization.maybe_reset_quantization_env()
 
-        compiled_models = compile_model(quantize_config=rbln_config.quantization)
+        wrapped_model.phase = "prefill"
+        compiled_prefill = compile_model(
+            wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
+        )
+
+        wrapped_model.phase = "decode"
+        compiled_models = {"prefill": compiled_prefill}
+        for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
+            dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+            compiled_decoder = compile_model(
+                wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
+            )
+            compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
 
         # check if the memory is enough to have additional blocks
         required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
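The decorator-based `@QuantizationManager.with_quantization_env` is replaced here by explicit calls on the quantization object carried in `rbln_config.quantization`. Stripped of the RBLN-specific pieces, the pattern is a plain try/finally guard; a minimal sketch, assuming only the two methods shown in the diff and a hypothetical `compile_fn`:

    def compile_with_quantization_env(compile_fn, quantization=None):
        # Set the quantization environment for the duration of one compilation,
        # and always restore it, even if compilation raises.
        try:
            if quantization:
                quantization.maybe_set_quantization_env()
            return compile_fn()
        finally:
            if quantization:
                quantization.maybe_reset_quantization_env()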
@@ -613,8 +687,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         alloc_memory_by_key: Dict[str, int] = {
             key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
         }
-        for key, memory_per_node in compiled_models["decoder"].get_alloc_per_node_by_key().items():
-            alloc_memory_by_key[key] += sum(memory_per_node)
+        for batch_size in rbln_config.decoder_batch_sizes:
+            for key, memory_per_node in (
+                compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
+            ):
+                alloc_memory_by_key[key] += sum(memory_per_node)
         alloc_memory_by_key.pop("PortRecur")  # kv-cache
         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
 
@@ -650,6 +727,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         n_model_params: Optional[int] = None,
         kernel_size: Optional[int] = None,
         buffer: Optional[int] = None,
+        num_runtimes: int = 2,
     ) -> int:
         """
         We are finding max_n_blocks(x) that satisfies the following equation:
@@ -721,7 +799,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         if buffer is None:
             # TODO: Accurate buffer estimation
-            buffer_per_core = 2**29  # 500MB per npu
+            buffer_per_runtime_per_core = 2**28  # 256MB per runtime
+            buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
             buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer
 
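A quick check of the revised reservation: the buffer now scales with the number of runtimes instead of being a single per-core constant. With one prefill runtime, two decoder batch sizes, and tensor parallelism over 4 cores (values illustrative):

    buffer_per_runtime_per_core = 2**28          # 256MiB, as in the new code
    num_runtimes = 1 + 2                         # 1 prefill + len(decoder_batch_sizes)
    tensor_parallel_size = 4

    buffer = buffer_per_runtime_per_core * num_runtimes * tensor_parallel_size
    print(buffer / 2**30)                        # 3.0 GiB subtracted from DRAM before sizing the KV cache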
@@ -739,6 +818,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         query_length: int,
         use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
         max_seq_len: int,
         kvcache_block_size: int,
         kvcache_num_blocks: int,
@@ -761,26 +841,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]
 
-        if use_attention_mask:
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        if query_length > 1:
             input_info.extend(
                 [
-                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                    ("query_position", [], "int16"),
                 ]
             )
-
-        if query_length > 1:
+        if use_attention_mask:
             input_info.extend(
                 [
-                    ("query_position", [], "int16"),
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
                 ]
             )
-
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+        if use_position_ids:
+            input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
         input_info.extend(
             [
@@ -839,6 +920,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             kvcache_block_size=rbln_config.kvcache_block_size,
             nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
             n_model_params=sum(p.numel() for p in model.parameters()),
+            num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
         )
 
         max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
@@ -873,19 +955,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             query_length=rbln_config.prefill_chunk_size,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
             use_attention_mask=rbln_config.use_attention_mask,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-        )
-        dec_input_info = cls.get_input_info(
-            batch_size=rbln_config.batch_size,
-            query_length=1,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
+            use_position_ids=rbln_config.use_position_ids,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -896,9 +966,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config.set_compile_cfgs([prefill_compile_config, dec_compile_config])
+        dec_compile_configs = []
+        for batch_size in rbln_config.decoder_batch_sizes:
+            dec_input_info = cls.get_input_info(
+                batch_size=batch_size,
+                query_length=1,
+                use_inputs_embeds=rbln_config.use_inputs_embeds,
+                use_attention_mask=rbln_config.use_attention_mask,
+                use_position_ids=rbln_config.use_position_ids,
+                max_seq_len=rbln_config.max_seq_len,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+                num_key_value_heads=num_key_value_heads,
+                num_hidden_layers=num_hidden_layers,
+                hidden_size=hidden_size,
+                head_dim=head_dim,
+            )
+            dec_compile_configs.append(
+                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+            )
+        rbln_config.set_compile_cfgs([prefill_compile_config, *dec_compile_configs])
 
         return rbln_config
 
@@ -908,8 +996,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_config.device_map for model_name in ["prefill", "decoder"]):
-            cls._raise_missing_compiled_file_error(["prefill", "decoder"])
+        expected_model_names = [
+            "prefill",
+            *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
+        ]
+        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+            cls._raise_missing_compiled_file_error(expected_model_names)
 
         return [
             rebel.Runtime(
@@ -918,12 +1010,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
             ),
-            rebel.Runtime(
-                compiled_models[1],
-                tensor_type="pt",
-                device=rbln_config.device_map["decoder"],
-                activate_profiler=rbln_config.activate_profiler,
-            ),
+            *[
+                rebel.Runtime(
+                    compiled_models[i + 1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                    activate_profiler=rbln_config.activate_profiler,
+                )
+                for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+            ],
         ]
 
     def get_decoder(self):
@@ -941,6 +1036,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         model_inputs = {}
@@ -948,13 +1044,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         if is_prefill_phase:
             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            padded_cache_lengths = torch.zeros_like(generate_idx)
             cache_position = None
+            position_ids = None
         else:
             if inputs_embeds is not None:
-                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+                # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
+                inputs_embeds = None
 
             input_ids = input_ids[:, -1:]
-            cache_position = generate_idx
+            position_ids = generate_idx
+            cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
             generate_idx = generate_idx + 1
             model_inputs.update({"input_ids": input_ids})
 
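The split between `position_ids` and `cache_position` only matters once a prefill implementation reports non-zero padding (the base `_prepare_prefill_inputs` above still returns 0). Hypothetically, if a 300-token prompt had been written into the cache with 84 padded slots, the first decode step would compute:

    import torch

    generate_idx = torch.tensor([[300]])          # number of real tokens seen so far
    padded_cache_lengths = torch.tensor([[84]])   # hypothetical padding reported by prefill

    position_ids = generate_idx                          # position-embedding index: 300
    cache_position = generate_idx + padded_cache_lengths # KV-cache write slot: 384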
@@ -973,6 +1073,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "generate_idx": generate_idx,
+                "position_ids": position_ids,
+                "padded_cache_lengths": padded_cache_lengths,
             }
         )
 
@@ -986,6 +1088,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ) -> Dict[str, Any]:
         # update generate_idx
         model_kwargs["generate_idx"] = outputs.generate_idx
+        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
 
         return model_kwargs
 
@@ -996,6 +1099,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         """
@@ -1009,28 +1115,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
-
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-                logit = self.prefill_decoder(
+                output = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                 )
-                logits.append(logit)
-
+                padded_cache_lengths[b_idx] += output.padded_cache_lengths
+                logits.append(output.logits)
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
-            logits = self.decoder(
+            inputs = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = inputs.shape[0]
+            if batch_size not in self.decoders:
+                raise ValueError(
+                    f"No decoder runtime available for batch size {batch_size}. "
+                    f"Available batch sizes are: {list(self.decoders.keys())}. "
+                    f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
+                )
+            logits = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
-            )
+                position_ids=position_ids if self.rbln_config.use_position_ids else None,
+            ).logits
 
-        return RBLNDecoderOnlyOutput(
-            logits=logits,
-            generate_idx=generate_idx,
-        )
+        if not return_dict:
+            return logits, generate_idx, padded_cache_lengths
+        else:
+            return RBLNDecoderOnlyOutput(
+                logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+            )
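The decode path now dispatches on the incoming batch size instead of assuming a single decoder runtime. The lookup-and-fail behaviour can be reproduced without RBLN hardware; the sketch below mirrors the identifiers in the diff but substitutes plain strings for compiled runtimes:

    decoders = {4: "decoder_batch_4", 1: "decoder_batch_1"}   # stands in for self.decoders

    def pick_decoder(batch_size: int) -> str:
        # Same check and message shape as the new forward(); only the return value differs.
        if batch_size not in decoders:
            raise ValueError(
                f"No decoder runtime available for batch size {batch_size}. "
                f"Available batch sizes are: {list(decoders.keys())}."
            )
        return decoders[batch_size]

    print(pick_decoder(1))       # decoder_batch_1
    # pick_decoder(3)            # would raise ValueError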