optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. optimum/rbln/__init__.py +20 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +9 -4
  4. optimum/rbln/modeling.py +7 -5
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/sliding_window_attn.py +111 -0
  9. optimum/rbln/transformers/__init__.py +22 -3
  10. optimum/rbln/transformers/models/__init__.py +23 -0
  11. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  12. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
  13. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
  14. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
  15. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
  16. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +160 -88
  17. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
  18. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  19. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  20. optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
  21. optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
  22. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  23. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
  24. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
  25. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
  26. optimum/rbln/transformers/models/siglip/__init__.py +20 -0
  27. optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
  28. optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
  29. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
  30. optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
  31. optimum/rbln/utils/submodule.py +13 -1
  32. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
  33. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +35 -24
  34. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
  35. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
@@ -30,7 +30,7 @@ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import QuantizationManager
+ from ...utils.rbln_quantization import prepare_model_for_quantization
  from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
  from .decoderonly_architecture import (
  DecoderOnlyWrapper,
@@ -59,6 +59,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  kvcache_block_size: int,
  use_attention_mask: bool,
  attn_impl: str,
+ use_position_ids: bool,
  **kwargs: Any,
  ) -> None:
  super().__init__(runtime, **kwargs)
@@ -72,6 +73,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  self.dec_attn_mask = dec_attn_mask
  self.block_tables = block_tables
  self.free_block_pool = free_block_pool
+ self.use_position_ids = use_position_ids

  self.kvcache_block_size = kvcache_block_size
  self.empty_block = -1
@@ -164,6 +166,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  batch_idx: Optional[int] = None,
  block_tables: Optional[torch.Tensor] = None,
  position_embed: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
  ):
  if input_ids is None and inputs_embeds is None:
  raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -189,10 +192,16 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  is_external_block_tables,
  attention_mask=attention_mask,
  position_embed=position_embed,
+ position_ids=position_ids,
  )
  else:
  return self.prefill_forward(
- inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+ inputs,
+ cache_position,
+ attention_mask,
+ batch_idx,
+ block_tables,
+ position_embed=position_embed,
  )

  def decode_forward(
@@ -203,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  is_external_block_tables: bool = None,
  attention_mask: Optional[torch.Tensor] = None,
  position_embed: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
  ) -> torch.FloatTensor:
  batch_size = inputs.shape[0]
  if batch_size != self.batch_size:
@@ -232,35 +242,30 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  if self.batch_size < block_tables.shape[0]:
  block_tables = block_tables[: self.batch_size]

- if self.batch_size < attention_mask.shape[0]:
+ if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
  attention_mask = attention_mask[: self.batch_size]

  logits = super().forward(
  inputs,
  cache_position,
- attention_mask if self.use_attention_mask else None,
  block_tables,
  position_embed,
+ attention_mask if self.use_attention_mask else None,
+ position_ids if self.use_position_ids else None,
  )

- return logits
+ return RBLNDecoderOnlyOutput(logits=logits)

- def prefill_forward(
+ def _prepare_prefill_inputs(
  self,
  inputs: torch.Tensor,
- cache_position: torch.Tensor = None,
+ cache_position: torch.Tensor,
  attention_mask: Optional[torch.Tensor] = None,
- batch_idx: int = None,
- block_tables: torch.Tensor = None,
- is_external_block_tables: bool = None,
  position_embed: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
+ ):
  """
- Performs chunked prefill for efficient KV-cache updates and memory optimization.
- Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
- and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+ Prepare inputs for prefill phase.
  """
-
  # Handle continuous batching in a compiled graph by extracting valid inputs
  # If an attention mask is provided, select only the valid (non-masked) inputs
  inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -276,8 +281,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  )

  # Initialize attention mask for chunked processing
- if self.use_attention_mask:
- chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+ chunked_attention_mask = (
+ torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+ if self.use_attention_mask
+ else None
+ )

  # Buffer for storing output logits
  out_buffers = [
@@ -288,36 +296,80 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  )
  ]

- # Process input in chunks of size `prefill_chunk_size`
- for step in range(0, query_length, self.prefill_chunk_size):
- # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
- if (step + self.prefill_chunk_size) > query_length:
- padding_size = step + self.prefill_chunk_size - query_length
- # inputs_embeds
- if inputs.dim() == 3:
- inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
- # inputs_ids
- else:
- inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+ # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+ if query_length % self.prefill_chunk_size != 0:
+ padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+ # inputs_embeds
+ if inputs.dim() == 3:
+ inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+ # inputs_ids
+ else:
+ inputs = torch.nn.functional.pad(inputs, (0, padding_size))

- cache_position = torch.cat(
- [
- cache_position,
- torch.arange(
- query_length,
- step + self.prefill_chunk_size,
- dtype=torch.int32,
- ).unsqueeze(0),
- ],
- dim=-1,
- )
+ cache_position = torch.cat(
+ [
+ cache_position,
+ torch.arange(
+ query_length,
+ query_length + padding_size,
+ dtype=torch.int32,
+ ).unsqueeze(0),
+ ],
+ dim=-1,
+ )
+
+ if position_embed is not None:
+ position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
+ # Overwrite position_ids and padded_cache_lengths
+ position_ids = None
+ padded_cache_lengths = 0
+
+ return (
+ inputs,
+ cache_position,
+ chunked_attention_mask,
+ out_buffers,
+ position_ids,
+ position_embed,
+ padded_cache_lengths,
+ query_length,
+ )

- if position_embed is not None:
- position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+ def prefill_forward(
+ self,
+ inputs: torch.Tensor,
+ cache_position: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ batch_idx: int = None,
+ block_tables: torch.Tensor = None,
+ is_external_block_tables: bool = None,
+ position_embed: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ Performs chunked prefill for efficient KV-cache updates and memory optimization.
+ Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+ and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+ """
+ (
+ inputs,
+ cache_position,
+ chunked_attention_mask,
+ out_buffers,
+ position_ids,
+ position_embed,
+ padded_cache_lengths,
+ query_length,
+ ) = self._prepare_prefill_inputs(inputs, cache_position, attention_mask, position_embed)

+ # Process input in chunks of size `prefill_chunk_size`
+ for step in range(0, query_length, self.prefill_chunk_size):
  # Extract the current chunk of inputs and cache positions
  input_chunk = inputs[:, step : step + self.prefill_chunk_size]
  cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+ position_ids_chunk = (
+ position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+ )
  if position_embed is not None:
  position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

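Note: padding is now computed once in `_prepare_prefill_inputs` instead of inside the chunk loop. A minimal sketch of that arithmetic, with illustrative values not taken from the package:

```python
import torch

prefill_chunk_size = 128
query_length = 300  # hypothetical prompt length

# Round the prompt up to a whole number of chunks (mirrors the new logic above).
padding_size = 0
if query_length % prefill_chunk_size != 0:
    padding_size = prefill_chunk_size - query_length % prefill_chunk_size  # 84

cache_position = torch.arange(0, query_length, dtype=torch.int32).unsqueeze(0)
cache_position = torch.cat(
    [cache_position, torch.arange(query_length, query_length + padding_size, dtype=torch.int32).unsqueeze(0)],
    dim=-1,
)
assert cache_position.shape[-1] % prefill_chunk_size == 0  # 384 positions = 3 full chunks
```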
@@ -334,9 +386,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  logits = super().forward(
  input_chunk,
  cache_pos_chunk,
- chunked_attention_mask if self.use_attention_mask else None,
- query_position,
  block_tables,
+ query_position,
+ chunked_attention_mask if self.use_attention_mask else None,
+ position_ids_chunk if position_ids is not None else None,
  position_embed_chunk if position_embed is not None else None,
  out=out_buffers,
  )
@@ -346,13 +399,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  self.dec_attn_mask[batch_idx].fill_(0)
  self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

- return logits
+ return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)


  @dataclass
  class RBLNDecoderOnlyOutput(ModelOutput):
  logits: torch.FloatTensor = None
  generate_idx: torch.Tensor = None
+ padded_cache_lengths: int = None


  class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
@@ -422,6 +476,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  max_seq_len=self.rbln_config.max_seq_len,
  use_attention_mask=self.rbln_config.use_attention_mask,
  attn_impl=self.rbln_config.attn_impl,
+ use_position_ids=self.rbln_config.use_position_ids,
  )
  self.decoders = {}
  for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
@@ -437,6 +492,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  kvcache_block_size=self.rbln_config.kvcache_block_size,
  use_attention_mask=self.rbln_config.use_attention_mask,
  attn_impl=self.rbln_config.attn_impl,
+ use_position_ids=self.rbln_config.use_position_ids,
  )

  # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -482,8 +538,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  trust_remote_code: bool = False,
  **kwargs,
  ):
- from ...utils.rbln_quantization import prepare_model_for_quantization
-
  kwargs = cls.update_kwargs(kwargs)

  if config is None:
@@ -500,8 +554,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  with no_init_weights():
  model = AutoModelForCausalLM.from_config(config)

- prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))
-
+ model = prepare_model_for_quantization(
+ model,
+ model_id,
+ kwargs.get("num_hidden_layers"),
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ )
  return model

  def __getattr__(self, __name: str) -> Any:
@@ -528,11 +590,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  def get_pytorch_model(
  cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
  ) -> "PreTrainedModel":
- if (
- rbln_config is not None
- and "format" in rbln_config.quantization
- and rbln_config.quantization["format"] == "rbln"
- ):
+ if rbln_config and rbln_config.quantization:
  model = cls.get_quantized_model(*args, **kwargs)
  else:
  model = super().get_pytorch_model(*args, **kwargs)
@@ -548,6 +606,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  "kvcache_block_size": rbln_config.kvcache_block_size,
  "use_rotary_emb": cls._use_rotary_emb,
  "use_attention_mask": rbln_config.use_attention_mask,
+ "use_position_ids": rbln_config.use_position_ids,
+ "use_inputs_embeds": rbln_config.use_inputs_embeds,
  }
  return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

@@ -572,9 +632,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  static_tensors[name] = tensor
  context.mark_static_address(tensor)

- @QuantizationManager.with_quantization_env
- def compile_model(wrapped_model, compile_config, example_inputs, compile_context, **kwargs):
+ def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
  try:
+ if quantization:
+ quantization.maybe_set_quantization_env()
  original_linear = torch.nn.functional.linear
  torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
  compiled_model = RBLNModel.compile(
@@ -586,14 +647,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  return compiled_model
  finally:
  torch.nn.functional.linear = original_linear
+ if quantization:
+ quantization.maybe_reset_quantization_env()

  wrapped_model.phase = "prefill"
  compiled_prefill = compile_model(
- wrapped_model,
- prefill_compile_config,
- prefill_example_inputs,
- context,
- quantize_config=rbln_config.quantization,
+ wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
  )

  wrapped_model.phase = "decode"
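Note: the `QuantizationManager.with_quantization_env` decorator is replaced by explicit set/reset calls inside the `try`/`finally`, so the environment is restored even if compilation fails. A self-contained sketch of that pattern; the objects are stand-ins, and only the `maybe_set_quantization_env` / `maybe_reset_quantization_env` names mirror the diff:

```python
class FakeQuantization:
    def maybe_set_quantization_env(self):
        print("quantization env set")

    def maybe_reset_quantization_env(self):
        print("quantization env reset")


def compile_model(build_fn, quantization=None):
    try:
        if quantization:
            quantization.maybe_set_quantization_env()
        return build_fn()
    finally:
        # Always restore global state, even when build_fn raises.
        if quantization:
            quantization.maybe_reset_quantization_env()


print(compile_model(lambda: "compiled", FakeQuantization()))
```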
@@ -601,11 +660,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
  dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
  compiled_decoder = compile_model(
- wrapped_model,
- dec_compile_config,
- dec_example_inputs,
- context,
- quantize_config=rbln_config.quantization,
+ wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
  )
  compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

@@ -763,6 +818,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  query_length: int,
  use_inputs_embeds: bool,
  use_attention_mask: bool,
+ use_position_ids: bool,
  max_seq_len: int,
  kvcache_block_size: int,
  kvcache_num_blocks: int,
@@ -785,26 +841,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  ),
  ]

- if use_attention_mask:
+ max_block_cnt = max_seq_len // kvcache_block_size
+
+ if query_length > 1:
+ input_info.extend([("block_tables", [max_block_cnt], "int16")])
+ else:
+ input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+ if query_length > 1:
  input_info.extend(
  [
- ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+ ("query_position", [], "int16"),
  ]
  )
-
- if query_length > 1:
+ if use_attention_mask:
  input_info.extend(
  [
- ("query_position", [], "int16"),
+ ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
  ]
  )
-
- max_block_cnt = max_seq_len // kvcache_block_size
-
- if query_length > 1:
- input_info.extend([("block_tables", [max_block_cnt], "int16")])
- else:
- input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+ if use_position_ids:
+ input_info.append(("position_ids", [batch_size, query_length], "int32"))

  input_info.extend(
  [
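Note: the optional graph inputs are reordered (block tables first, then query position, attention mask, and the new `position_ids`). A small sketch that reproduces the logic above so the resulting signature per phase is easy to read; the argument values are illustrative:

```python
def optional_input_info(batch_size, query_length, max_seq_len, kvcache_block_size,
                        use_attention_mask, use_position_ids):
    info = []
    max_block_cnt = max_seq_len // kvcache_block_size
    if query_length > 1:  # prefill graph: a single block-table row plus query_position
        info.append(("block_tables", [max_block_cnt], "int16"))
        info.append(("query_position", [], "int16"))
    else:                 # decode graph: one block-table row per batch entry
        info.append(("block_tables", [batch_size, max_block_cnt], "int16"))
    if use_attention_mask:
        info.append(("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"))
    if use_position_ids:
        info.append(("position_ids", [batch_size, query_length], "int32"))
    return info

print(optional_input_info(1, 128, 4096, 1024, True, True))   # prefill
print(optional_input_info(4, 1, 4096, 1024, True, True))     # decode
```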
@@ -898,6 +955,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  query_length=rbln_config.prefill_chunk_size,
  use_inputs_embeds=rbln_config.use_inputs_embeds,
  use_attention_mask=rbln_config.use_attention_mask,
+ use_position_ids=rbln_config.use_position_ids,
  max_seq_len=rbln_config.max_seq_len,
  kvcache_block_size=rbln_config.kvcache_block_size,
  kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -916,6 +974,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  query_length=1,
  use_inputs_embeds=rbln_config.use_inputs_embeds,
  use_attention_mask=rbln_config.use_attention_mask,
+ use_position_ids=rbln_config.use_position_ids,
  max_seq_len=rbln_config.max_seq_len,
  kvcache_block_size=rbln_config.kvcache_block_size,
  kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -977,6 +1036,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  generate_idx: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.LongTensor] = None,
  inputs_embeds: Optional[torch.Tensor] = None,
+ padded_cache_lengths: Optional[torch.Tensor] = None,
  **kwargs,
  ):
  model_inputs = {}
@@ -984,13 +1044,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

  if is_prefill_phase:
  generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+ padded_cache_lengths = torch.zeros_like(generate_idx)
  cache_position = None
+ position_ids = None
  else:
  if inputs_embeds is not None:
- raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+ # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
+ inputs_embeds = None

  input_ids = input_ids[:, -1:]
- cache_position = generate_idx
+ position_ids = generate_idx
+ cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
  generate_idx = generate_idx + 1
  model_inputs.update({"input_ids": input_ids})

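Note: chunked prefill can leave padded slots in the KV cache, so the decode phase now distinguishes the physical write index (`cache_position`, offset by `padded_cache_lengths`) from the logical token position (`position_ids`). A small illustration with made-up tensors:

```python
import torch

generate_idx = torch.tensor([[300], [40]])          # real tokens seen so far, per sequence
padded_cache_lengths = torch.tensor([[84], [88]])   # prefill padding per sequence (e.g. chunk size 128)

position_ids = generate_idx                           # index used for position embeddings
cache_position = generate_idx + padded_cache_lengths  # KV-cache slot to write the new entry into
print(position_ids.squeeze(-1).tolist(), cache_position.squeeze(-1).tolist())  # [300, 40] [384, 128]
```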
@@ -1009,6 +1073,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  "attention_mask": attention_mask,
  "cache_position": cache_position,
  "generate_idx": generate_idx,
+ "position_ids": position_ids,
+ "padded_cache_lengths": padded_cache_lengths,
  }
  )

@@ -1022,6 +1088,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  ) -> Dict[str, Any]:
  # update generate_idx
  model_kwargs["generate_idx"] = outputs.generate_idx
+ model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths

  return model_kwargs

@@ -1032,6 +1099,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  cache_position: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.LongTensor] = None,
  generate_idx: Optional[torch.Tensor] = None,
+ padded_cache_lengths: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ return_dict: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
  """
@@ -1045,18 +1115,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  logits = []
  inputs = inputs_embeds if inputs_embeds is not None else input_ids
  batch_size = inputs.shape[0]
-
  for b_idx in range(batch_size):
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
- logit = self.prefill_decoder(
+ output = self.prefill_decoder(
  input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
  inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
  attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
  cache_position=cache_position,
  batch_idx=b_idx,
  )
- logits.append(logit)
-
+ padded_cache_lengths[b_idx] += output.padded_cache_lengths
+ logits.append(output.logits)
  logits = torch.cat(logits, dim=0)
  # Decoder
  else:
@@ -1072,9 +1141,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
- )
+ position_ids=position_ids if self.rbln_config.use_position_ids else None,
+ ).logits

- return RBLNDecoderOnlyOutput(
- logits=logits,
- generate_idx=generate_idx,
- )
+ if not return_dict:
+ return logits, generate_idx, padded_cache_lengths
+ else:
+ return RBLNDecoderOnlyOutput(
+ logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+ )
@@ -421,6 +421,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
  image_hidden_states: Optional[torch.FloatTensor] = None,
  cache_position: torch.Tensor = None,
  generate_idx: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
  **kwargs,
  ) -> Union[Tuple, Idefics3CausalLMOutputWithPast]:
  # Prefill
@@ -434,14 +435,14 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):

  for b_idx in range(batch_size):
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
- logit = self.text_model.prefill_decoder(
+ output = self.text_model.prefill_decoder(
  input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
  inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
  attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
  cache_position=cache_position,
  batch_idx=b_idx,
  )
- logits.append(logit)
+ logits.append(output.logits)

  logits = torch.cat(logits, dim=0)

@@ -451,9 +452,12 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
- )
+ ).logits

- return RBLNDecoderOnlyOutput(
- logits=logits,
- generate_idx=generate_idx,
- )
+ if not return_dict:
+ return logits, generate_idx
+ else:
+ return RBLNDecoderOnlyOutput(
+ logits=logits,
+ generate_idx=generate_idx,
+ )
@@ -372,7 +372,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  inputs_embeds = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(batch_size)]
  for batch_idx in range(batch_size):
  generate_idx[batch_idx] = inputs_embeds[batch_idx].shape[-2]
- logit = self.language_model.prefill_decoder(
+ output = self.language_model.prefill_decoder(
  inputs_embeds=inputs_embeds[batch_idx],
  batch_idx=batch_idx,
  cache_position=torch.arange(
@@ -382,14 +382,14 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  ).unsqueeze(0),
  )

- logits.append(logit)
+ logits.append(output.logits)
  logits = torch.cat(logits, dim=0)
  else:
- logits = self.language_model.decoder(
+ output = self.language_model.decoder(
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
  )
-
+ logits = output.logits
  return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)

  # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_opt import RBLNOPTForCausalLMConfig
+ from .modeling_opt import RBLNOPTForCausalLM
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+ pass
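Note: the new OPT files plug into the existing decoder-only stack, so export presumably follows the usual optimum-rbln pattern. A hypothetical usage sketch; the class names come from this diff, while the re-export from `optimum.rbln` (suggested by the `__init__.py` changes) and the exact keyword arguments are assumptions:

```python
from optimum.rbln import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig

# Assumed config fields; RBLNOPTForCausalLMConfig inherits everything from the
# decoder-only causal-LM config shown above.
rbln_config = RBLNOPTForCausalLMConfig(max_seq_len=2048, batch_size=1)

model = RBLNOPTForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    export=True,          # compile from the original Hugging Face checkpoint
    rbln_config=rbln_config,
)
model.save_pretrained("opt-1.3b-rbln")
```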