optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic (details are available on the registry page).

Files changed (64)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +48 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +35 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  27. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  28. optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
  29. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  30. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  31. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  32. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  33. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  34. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  35. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  36. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  37. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  38. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  39. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  40. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  41. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  42. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  43. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  44. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  45. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  46. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  47. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
  49. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  51. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  52. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  53. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  54. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  55. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  56. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  57. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  58. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  59. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  60. optimum/rbln/utils/depreacate_utils.py +16 -0
  61. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
  62. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
  63. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_gemma import RBLNGemmaForCausalLMConfig
- from .modeling_gemma import RBLNGemmaForCausalLM
+ from .configuration_gemma import RBLNGemmaForCausalLMConfig, RBLNGemmaModelConfig
+ from .modeling_gemma import RBLNGemmaForCausalLM, RBLNGemmaModel
optimum/rbln/transformers/models/gemma/configuration_gemma.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


  class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
  )
  ```
  """
+
+
+ class RBLNGemmaModelConfig(RBLNDecoderOnlyModelConfig):
+ """
+ Configuration class for RBLN Gemma models.
+
+ This class is an alias of RBLNDecoderOnlyModelConfig.
+ """
optimum/rbln/transformers/models/gemma/modeling_gemma.py
@@ -13,7 +13,7 @@
  # limitations under the License.

  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from .gemma_architecture import GemmaWrapper


@@ -81,3 +81,15 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  """

  _decoder_wrapper_cls = GemmaWrapper
+
+
+ class RBLNGemmaModel(RBLNDecoderOnlyModel):
+ """
+ The Gemma Model transformer without a language modeling head.
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+ A class to convert and run pre-trained transformers based GemmaModel model on RBLN devices.
+ It implements the methods to convert a pre-trained transformers GemmaModel model into a RBLN transformer model by:
+ """
+
+ _decoder_wrapper_cls = GemmaWrapper
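
Note: the new headless classes follow the same from_pretrained/export pattern as the causal-LM classes. A minimal usage sketch, assuming the top-level re-exports added in optimum/rbln/__init__.py; the checkpoint ID and config values are illustrative, mirroring the RBLNGPT2ModelConfig docstring example later in this diff:

```python
# Hedged sketch, not taken from this diff: checkpoint ID and config values are
# illustrative; the call pattern mirrors the RBLNGPT2ModelConfig docstring below.
from optimum.rbln import RBLNGemmaModel, RBLNGemmaModelConfig

config = RBLNGemmaModelConfig(
    batch_size=1,
    max_seq_len=2048,        # illustrative value
    tensor_parallel_size=4,  # illustrative value
)

model = RBLNGemmaModel.from_pretrained(
    "google/gemma-2b",  # illustrative checkpoint
    export=True,        # compile the transformers checkpoint for RBLN NPUs
    rbln_config=config,
)
```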
optimum/rbln/transformers/models/gemma3/configuration_gemma3.py
@@ -13,8 +13,6 @@
  # limitations under the License.
  from typing import Any, Dict, Optional

- import rebel
-
  from ....configuration_utils import RBLNModelConfig
  from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
  from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
@@ -39,9 +37,13 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
  )
  self.image_prefill_chunk_size = image_prefill_chunk_size

- npu = self.npu or rebel.get_npu_name()
- if npu == "RBLN-CA02":
- raise NotImplementedError("Gemma3 is currently not supported on RBLN-CA02")
+ @property
+ def use_image_prefill(self):
+ return self.image_prefill_chunk_size is not None
+
+ @property
+ def decoder_runtime_idx(self):
+ return 2 if self.use_image_prefill else 1


  class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
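
The removed RBLN-CA02 guard is replaced by opt-in gating: whether an image_prefill runtime exists is now derived from whether image_prefill_chunk_size was configured. A standalone sketch of just the two new properties (not the real class, which inherits the full causal-LM config):

```python
# Minimal sketch of the gating logic added above; property names mirror the
# diff, but the class itself is a stand-in for illustration only.
class Gemma3GatingSketch:
    def __init__(self, image_prefill_chunk_size=None):
        self.image_prefill_chunk_size = image_prefill_chunk_size

    @property
    def use_image_prefill(self) -> bool:
        # An image-prefill runtime is compiled only when a chunk size was set.
        return self.image_prefill_chunk_size is not None

    @property
    def decoder_runtime_idx(self) -> int:
        # Runtimes are laid out as [prefill, (image_prefill,) decoder_batch_*...],
        # so decoders start at index 2 only when image prefill is present.
        return 2 if self.use_image_prefill else 1


assert Gemma3GatingSketch().decoder_runtime_idx == 1
assert Gemma3GatingSketch(image_prefill_chunk_size=256).decoder_runtime_idx == 2
```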
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -31,7 +31,11 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed

  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
+ from ..decoderonly.modeling_decoderonly import (
+ RBLNDecoderOnlyForCausalLMOutput,
+ RBLNDecoderOnlyModelForCausalLM,
+ RBLNRuntimeModel,
+ )
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
  from .gemma3_architecture import Gemma3ForCausalLMWrapper

@@ -41,7 +45,7 @@ if TYPE_CHECKING:


  @dataclass
- class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
+ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
  attention_mask: Optional[torch.Tensor] = None


@@ -197,7 +201,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

  def _update_model_kwargs_for_generation(
  self,
- outputs: RBLNDecoderOnlyOutput,
+ outputs: RBLNDecoderOnlyForCausalLMOutput,
  model_kwargs: Dict[str, Any],
  **kwargs,
  ) -> Dict[str, Any]:
@@ -266,7 +270,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
  position_ids: Optional[torch.Tensor] = None,
  token_type_ids: Optional[torch.Tensor] = None,
  **lm_kwargs: Dict[str, Any],
- ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+ ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
  # prefill
  if cache_position is None:
  logits = []
@@ -304,7 +308,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
  position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
  ).logits

- return RBLNDecoderOnlyOutput(
+ return RBLNDecoderOnlyForCausalLMOutput(
  logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
  )

@@ -333,11 +337,12 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)

  # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
- padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
- inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
- cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
- position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
- token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
+ if self.rbln_config.use_image_prefill:
+ padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
+ inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+ cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
+ position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)

  return (
  inputs,
@@ -385,7 +390,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  step = 0
  while step < query_length:
  # Check if the prefill chunk is an image prefill
- is_image_prefill = torch.all(
+ is_image_prefill = self.rbln_config.use_image_prefill and torch.all(
  token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
  )
  prefill_chunk_size = (
@@ -393,8 +398,10 @@
  )

  # Check if the prefill chunk is a text prefill which have image_tokens in it.
- is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
- token_type_ids[:, step : step + prefill_chunk_size] == 1
+ is_text_prefill_with_image_tokens = (
+ self.rbln_config.use_image_prefill
+ and not is_image_prefill
+ and torch.any(token_type_ids[:, step : step + prefill_chunk_size] == 1)
  )

  # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
@@ -414,7 +421,7 @@
  num_processed_tokens = prefill_chunk_size
  if is_text_prefill_with_image_tokens:
  first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
- num_processed_tokens = first_image_token_idx
+ num_processed_tokens = first_image_token_idx.item()
  if is_last_chunk:
  num_processed_tokens = query_length - step
@@ -509,7 +516,7 @@

  logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)

- return RBLNDecoderOnlyOutput(logits=logits)
+ return RBLNDecoderOnlyForCausalLMOutput(logits=logits)


  class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
@@ -544,9 +551,10 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  dtype=torch.int16,
  ).fill_(-1)
  free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
  self.prefill_decoder = RBLNGemma3RuntimeModel(
  runtime=self.model[0],
- image_prefill=self.model[1],
+ image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
  main_input_name=main_input_name,
  embed_tokens=self.embed_tokens,
  phase="prefill",
@@ -561,7 +569,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  self.decoders = {}
  for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
  self.decoders[batch_size] = RBLNGemma3RuntimeModel(
- runtime=self.model[i + 2],
+ runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
  main_input_name=main_input_name,
  embed_tokens=self.embed_tokens,
  phase="decode",
@@ -624,20 +632,21 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
  raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")

- # Update image prefill compile config
- img_prefill_input_info = cls.get_input_info(
- batch_size=1,
- query_length=rbln_config.image_prefill_chunk_size,
- rbln_config=rbln_config,
- model_config=model_config,
- )
- image_prefill_compile_config = RBLNCompileConfig(
- compiled_model_name="image_prefill", input_info=img_prefill_input_info
- )
- # Insert image_prefill compile config at index 1
- compile_cfgs = rbln_config.compile_cfgs
- compile_cfgs.insert(1, image_prefill_compile_config)
- rbln_config.set_compile_cfgs(compile_cfgs)
+ if rbln_config.use_image_prefill:
+ # Update image prefill compile config
+ img_prefill_input_info = cls.get_input_info(
+ batch_size=1,
+ query_length=rbln_config.image_prefill_chunk_size,
+ rbln_config=rbln_config,
+ model_config=model_config,
+ )
+ image_prefill_compile_config = RBLNCompileConfig(
+ compiled_model_name="image_prefill", input_info=img_prefill_input_info
+ )
+ # Insert image_prefill compile config at index 1
+ compile_cfgs = rbln_config.compile_cfgs
+ compile_cfgs.insert(1, image_prefill_compile_config)
+ rbln_config.set_compile_cfgs(compile_cfgs)

  return rbln_config

@@ -690,23 +699,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  context,
  rbln_config.quantization,
  )
+ compiled_models = {"prefill": compiled_prefill}

- image_prefill_compile_config = rbln_compile_configs[1]
- image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
- fill=0, static_tensors=static_tensors
- )
- wrapped_model.phase = "image_prefill"
- compiled_image_prefill = compile_model(
- wrapped_model,
- image_prefill_compile_config,
- image_prefill_example_inputs,
- context,
- rbln_config.quantization,
- )
+ if rbln_config.use_image_prefill:
+ image_prefill_compile_config = rbln_compile_configs[1]
+ image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+ fill=0, static_tensors=static_tensors
+ )
+ wrapped_model.phase = "image_prefill"
+ compiled_image_prefill = compile_model(
+ wrapped_model,
+ image_prefill_compile_config,
+ image_prefill_example_inputs,
+ context,
+ rbln_config.quantization,
+ )
+ compiled_models["image_prefill"] = compiled_image_prefill

- compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
  wrapped_model.phase = "decode"
- for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[2:]):
+ for batch_size, dec_compile_config in zip(
+ rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+ ):
  dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
  compiled_decoder = compile_model(
  wrapped_model,
@@ -727,35 +740,45 @@
  ) -> List[rebel.Runtime]:
  expected_model_names = [
  "prefill",
- "image_prefill",
  *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
  ]
+ if rbln_config.use_image_prefill:
+ expected_model_names.insert(1, "image_prefill")
+
  if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
  cls._raise_missing_compiled_file_error(expected_model_names)

- return [
+ ret_val = [
  rebel.Runtime(
  compiled_models[0],
  tensor_type="pt",
  device=rbln_config.device_map["prefill"],
  activate_profiler=rbln_config.activate_profiler,
  timeout=rbln_config.timeout,
- ),
- rebel.Runtime(
- compiled_models[1],
- tensor_type="pt",
- device=rbln_config.device_map["image_prefill"],
- activate_profiler=rbln_config.activate_profiler,
- timeout=rbln_config.timeout,
- ),
- *[
+ )
+ ]
+ if rbln_config.use_image_prefill:
+ ret_val.append(
+ rebel.Runtime(
+ compiled_models[1],
+ tensor_type="pt",
+ device=rbln_config.device_map["image_prefill"],
+ activate_profiler=rbln_config.activate_profiler,
+ timeout=rbln_config.timeout,
+ ),
+ )
+
+ ret_val.extend(
+ [
  rebel.Runtime(
- compiled_models[i + 2],
+ compiled_models[i + rbln_config.decoder_runtime_idx],
  tensor_type="pt",
  device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
  activate_profiler=rbln_config.activate_profiler,
  timeout=rbln_config.timeout,
  )
  for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
- ],
- ]
+ ]
+ )
+
+ return ret_val
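
Taken together, the hunks above make the compiled-model set conditional. A small derived sketch (plain Python, not library code) of the names _create_runtimes now expects to find in rbln_config.device_map:

```python
# Derived from the hunks above: the compiled-model layout with and without
# image prefill, for illustration only.
def expected_model_names(use_image_prefill: bool, decoder_batch_sizes: list) -> list:
    names = ["prefill", *[f"decoder_batch_{b}" for b in decoder_batch_sizes]]
    if use_image_prefill:
        # "image_prefill" sits at index 1, which is why decoder_runtime_idx is 2.
        names.insert(1, "image_prefill")
    return names


assert expected_model_names(False, [1]) == ["prefill", "decoder_batch_1"]
assert expected_model_names(True, [1, 4]) == [
    "prefill", "image_prefill", "decoder_batch_1", "decoder_batch_4",
]
```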
optimum/rbln/transformers/models/gpt2/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig
- from .modeling_gpt2 import RBLNGPT2LMHeadModel
+ from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig, RBLNGPT2ModelConfig
+ from .modeling_gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2Model
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py
@@ -12,11 +12,39 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


  class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
  """
- Configuration class for GPT-2 causal language model.
- Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+ Configuration class for RBLN GPT2 models.
+
+ This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+ """
+
+
+ class RBLNGPT2ModelConfig(RBLNDecoderOnlyModelConfig):
+ """
+ Configuration class for RBLN GPT2 models.
+
+ This class is an alias of RBLNDecoderOnlyModelConfig.
+
+ Example usage:
+ ```python
+ from optimum.rbln import RBLNGPT2Model, RBLNGPT2ModelConfig
+
+ # Create a configuration object
+ config = RBLNGPT2ModelConfig(
+ batch_size=1,
+ max_seq_len=1024,
+ tensor_parallel_size=4
+ )
+
+ # Use the configuration with from_pretrained
+ model = RBLNGPT2Model.from_pretrained(
+ "openai/gpt2",
+ export=True,
+ rbln_config=config
+ )
+ ```
  """
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -13,11 +13,10 @@
  # limitations under the License.

  import math
- from typing import TYPE_CHECKING, Tuple
+ from typing import TYPE_CHECKING, Tuple, Union

  import torch
  import torch.nn as nn
- from transformers import PreTrainedModel

  from ..decoderonly.decoderonly_architecture import (
  DecoderOnlyAttention,
@@ -28,7 +27,7 @@ from ..decoderonly.decoderonly_architecture import (


  if TYPE_CHECKING:
- from transformers import GPT2LMHeadModel
+ from transformers import GPT2LMHeadModel, GPT2Model


  class GPT2Wrapper(DecoderOnlyWrapper):
@@ -44,11 +43,11 @@ class GPT2Wrapper(DecoderOnlyWrapper):
  def get_attn_layer(self, layer: nn.Module):
  return layer.attn

- def get_model_layer(self, causal_lm: "GPT2LMHeadModel"):
- return causal_lm.transformer
+ def get_model_layer(self, model: Union["GPT2LMHeadModel", "GPT2Model"]):
+ return model.transformer if self.is_causal_lm else model

- def get_decoder_layers(self, causal_lm: PreTrainedModel):
- return causal_lm.transformer.h
+ def get_decoder_layers(self, model: Union["GPT2LMHeadModel", "GPT2Model"]):
+ return model.transformer.h if self.is_causal_lm else model.h


  class GPT2Model(DecoderOnlyModel):
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -13,7 +13,7 @@
  # limitations under the License.

  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from .gpt2_architecture import GPT2Wrapper


@@ -36,3 +36,18 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

  _decoder_wrapper_cls = GPT2Wrapper
  _use_rotary_emb = False
+
+
+ class RBLNGPT2Model(RBLNDecoderOnlyModel):
+ """
+ The GPT2 Model transformer without a language modeling head.
+
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model.
+
+ A class to convert and run pre-trained transformers based GPT2Model model on RBLN devices.
+ It implements the methods to convert a pre-trained transformers GPT2Model model into a RBLN transformer model by:
+ """
+
+ _decoder_wrapper_cls = GPT2Wrapper
+ _use_rotary_emb = False
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -35,7 +35,7 @@ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
  from ....utils.runtime_utils import RBLNPytorchRuntime
  from ..decoderonly.modeling_decoderonly import (
- RBLNDecoderOnlyOutput,
+ RBLNDecoderOnlyForCausalLMOutput,
  )


@@ -494,7 +494,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
  if not return_dict:
  return logits, generate_idx
  else:
- return RBLNDecoderOnlyOutput(
+ return RBLNDecoderOnlyForCausalLMOutput(
  logits=logits,
  generate_idx=generate_idx,
  )
optimum/rbln/transformers/models/llama/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_llama import RBLNLlamaForCausalLMConfig
- from .modeling_llama import RBLNLlamaForCausalLM
+ from .configuration_llama import RBLNLlamaForCausalLMConfig, RBLNLlamaModelConfig
+ from .modeling_llama import RBLNLlamaForCausalLM, RBLNLlamaModel
optimum/rbln/transformers/models/llama/configuration_llama.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


  class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
  )
  ```
  """
+
+
+ class RBLNLlamaModelConfig(RBLNDecoderOnlyModelConfig):
+ """
+ Configuration class for RBLN Llama models.
+
+ This class is an alias of RBLNDecoderOnlyModelConfig.
+ """
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -13,7 +13,7 @@
  # limitations under the License.

  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from .llama_architecture import LlamaWrapper


@@ -81,3 +81,15 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  """

  _decoder_wrapper_cls = LlamaWrapper
+
+
+ class RBLNLlamaModel(RBLNDecoderOnlyModel):
+ """
+ The Llama Model transformer with a language modeling head (linear layer) on top.
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+ A class to convert and run pre-trained transformers based LlamaModel model on RBLN devices.
+ It implements the methods to convert a pre-trained transformers LlamaModel model into a RBLN transformer model by:
+ """
+
+ _decoder_wrapper_cls = LlamaWrapper
optimum/rbln/transformers/models/llava/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_llava import RBLNLlavaForConditionalGenerationConfig
+ from .modeling_llava import RBLNLlavaForConditionalGeneration
optimum/rbln/transformers/models/llava/configuration_llava.py
@@ -0,0 +1,54 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional
+
+ from ....configuration_utils import RBLNModelConfig
+
+
+ class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLNLlavaForConditionalGenerationConfig.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized LLaVA models for multimodal conditional generation tasks
+ that combine vision and language processing capabilities.
+ """
+
+ submodules = ["vision_tower", "language_model"]
+
+ def __init__(
+ self,
+ batch_size: Optional[int] = None,
+ vision_tower: Optional[RBLNModelConfig] = None,
+ language_model: Optional[RBLNModelConfig] = None,
+ **kwargs: Dict[str, Any],
+ ):
+ """
+ Args:
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+ vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+ language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+ Raises:
+ ValueError: If batch_size is not a positive integer.
+ """
+ super().__init__(**kwargs)
+ self.batch_size = batch_size or 1
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+ self.vision_tower = vision_tower
+ self.language_model = language_model
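
A minimal usage sketch for the new LLaVA support, assuming the top-level re-exports added in optimum/rbln/__init__.py; the checkpoint ID is illustrative, and the submodule configs are left at their None defaults as the signature above permits:

```python
# Hedged sketch, not taken from this diff: only batch_size, vision_tower, and
# language_model appear in the config above; everything else is an assumption.
from optimum.rbln import (
    RBLNLlavaForConditionalGeneration,
    RBLNLlavaForConditionalGenerationConfig,
)

config = RBLNLlavaForConditionalGenerationConfig(
    batch_size=1,
    # vision_tower and language_model accept RBLNModelConfig instances;
    # both default to None per the __init__ signature above.
)

model = RBLNLlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # illustrative checkpoint
    export=True,
    rbln_config=config,
)
```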