optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +53 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  54. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  55. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  56. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  57. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  58. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  59. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
  60. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  63. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  64. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  65. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  66. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  67. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  68. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  69. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  70. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  71. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
  72. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  73. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  74. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  75. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  76. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  77. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  78. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  79. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  80. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  81. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  82. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  83. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  84. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  85. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  86. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  87. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  91. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  92. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  93. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  94. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
  95. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
  96. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  97. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  98. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  99. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  100. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  101. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  102. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  103. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  104. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  105. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  106. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  108. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  109. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  110. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  111. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  112. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  113. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  114. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  115. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  116. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  117. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  118. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  119. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  123. optimum/rbln/utils/model_utils.py +20 -0
  124. optimum/rbln/utils/submodule.py +6 -8
  125. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
  126. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
  127. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  128. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  129. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
  130. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusionImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionImg2ImgPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion pipeline for image-to-image generation.
+
+    This pipeline compiles Stable Diffusion models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for transforming input images based on text prompts with controlled strength and guidance.
+    """
+
     original_class = StableDiffusionImg2ImgPipeline
     _rbln_config_class = RBLNStableDiffusionImg2ImgPipelineConfig
     _submodules = ["text_encoder", "unet", "vae"]
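The docstrings added across these pipeline classes all describe the same usage pattern. As a rough illustration only (the checkpoint id, file paths, and parameter values below are placeholder assumptions, not taken from this diff), loading and running one of these pipelines looks like:

```python
# Hedged sketch: compile a Stable Diffusion checkpoint for an RBLN NPU and run
# one prompt-guided image-to-image transformation. Ids and paths are placeholders.
from diffusers.utils import load_image
from optimum.rbln import RBLNStableDiffusionImg2ImgPipeline

pipe = RBLNStableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint id
    export=True,  # compile the PyTorch weights for the NPU
)
image = pipe(
    prompt="a watercolor landscape",
    image=load_image("input.png"),  # placeholder input image
    strength=0.7,
).images[0]
image.save("output.png")
```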
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusionInpaintPipeline(RBLNDiffusionMixin, StableDiffusionInpaintPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion pipeline for image inpainting.
+
+    This pipeline compiles Stable Diffusion models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for filling masked regions of images based on text prompts with seamless integration.
+    """
+
     original_class = StableDiffusionInpaintPipeline
     _rbln_config_class = RBLNStableDiffusionInpaintPipelineConfig
     _submodules = ["text_encoder", "unet", "vae"]
optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusion3Pipeline(RBLNDiffusionMixin, StableDiffusion3Pipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion 3 pipeline for advanced text-to-image generation.
+
+    This pipeline compiles Stable Diffusion 3 models to run efficiently on RBLN NPUs, enabling high-performance
+    inference with improved text understanding, enhanced image quality, and superior prompt adherence.
+    """
+
     original_class = StableDiffusion3Pipeline
     _rbln_config_class = RBLNStableDiffusion3PipelineConfig
     _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]
optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusion3Img2ImgPipeline(RBLNDiffusionMixin, StableDiffusion3Img2ImgPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion 3 pipeline for advanced image-to-image generation.
+
+    This pipeline compiles Stable Diffusion 3 models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for transforming input images with superior text understanding and enhanced visual quality.
+    """
+
     original_class = StableDiffusion3Img2ImgPipeline
     _rbln_config_class = RBLNStableDiffusion3Img2ImgPipelineConfig
     _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]
optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusion3InpaintPipeline(RBLNDiffusionMixin, StableDiffusion3InpaintPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion 3 pipeline for advanced image inpainting.
+
+    This pipeline compiles Stable Diffusion 3 models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for filling masked regions with superior text understanding and seamless content generation.
+    """
+
     original_class = StableDiffusion3InpaintPipeline
     _rbln_config_class = RBLNStableDiffusion3InpaintPipelineConfig
     _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusionXLPipeline(RBLNDiffusionMixin, StableDiffusionXLPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion XL pipeline for high-resolution text-to-image generation.
+
+    This pipeline compiles Stable Diffusion XL models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for generating high-quality images with enhanced detail and improved prompt adherence.
+    """
+
     original_class = StableDiffusionXLPipeline
     _rbln_config_class = RBLNStableDiffusionXLPipelineConfig
     _submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusionXLImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLImg2ImgPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion XL pipeline for high-resolution image-to-image generation.
+
+    This pipeline compiles Stable Diffusion XL models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for transforming input images with enhanced quality and detail preservation.
+    """
+
     original_class = StableDiffusionXLImg2ImgPipeline
     _rbln_config_class = RBLNStableDiffusionXLImg2ImgPipelineConfig
     _submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py CHANGED
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin


 class RBLNStableDiffusionXLInpaintPipeline(RBLNDiffusionMixin, StableDiffusionXLInpaintPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion XL pipeline for high-resolution image inpainting.
+
+    This pipeline compiles Stable Diffusion XL models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for filling masked regions with enhanced quality and seamless blending capabilities.
+    """
+
     original_class = StableDiffusionXLInpaintPipeline
     _rbln_config_class = RBLNStableDiffusionXLInpaintPipelineConfig
     _submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
optimum/rbln/modeling.py CHANGED
@@ -14,7 +14,7 @@

 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Dict, List, Optional, Union, get_args, get_origin, get_type_hints
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, get_args, get_origin, get_type_hints

 import rebel
 import torch
@@ -35,27 +35,12 @@ logger = get_logger(__name__)


 class RBLNModel(RBLNBaseModel):
-    """
-    A class that inherits from RBLNBaseModel for models consisting of a single `torch.nn.Module`.
-
-    This class supports all the functionality of RBLNBaseModel, including loading and saving models using
-    the `from_pretrained` and `save_pretrained` methods, compiling PyTorch models for execution on RBLN NPU
-    devices.
-
-    Example:
-    ```python
-    model = RBLNModel.from_pretrained("model_id", export=True, rbln_npu="npu_name")
-    outputs = model(**inputs)
-    ```
-    """
-
     _output_class = None

     @classmethod
     def update_kwargs(cls, kwargs):
-        """
-        Update user-given kwargs to get proper pytorch model.
-        """
+        # Update user-given kwargs to get proper pytorch model.
+
         return kwargs

     @classmethod
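The `update_kwargs` hook above is the extension point subclasses use to adjust the kwargs forwarded to the HuggingFace loader. A hypothetical override for illustration (the class name and the `attn_implementation` choice are assumptions, not taken from this diff):

```python
# Hypothetical subclass showing how update_kwargs can shape HF model loading.
class RBLNMyModel(RBLNModel):
    @classmethod
    def update_kwargs(cls, kwargs):
        # Force an attention implementation before the PyTorch model is built.
        kwargs.update({"attn_implementation": "eager"})
        return kwargs
```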
@@ -66,10 +51,9 @@ class RBLNModel(RBLNBaseModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+        pass

     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
@@ -88,11 +72,32 @@
         cls,
         model: "PreTrainedModel",
         config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNModelConfig] = None,
+        rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
-        **kwargs,
-    ):
+        **kwargs: Dict[str, Any],
+    ) -> "RBLNModel":
+        """
+        Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
+        This method performs the actual model conversion and compilation process.
+
+        Args:
+            model: The PyTorch model to be compiled. The object must be an instance of the HuggingFace transformers PreTrainedModel class.
+            rbln_config: Configuration for RBLN model compilation and runtime. This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
+                For detailed configuration options, see the specific model's configuration class documentation.
+
+            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+
+        The method performs the following steps:
+
+        1. Compiles the PyTorch model into an optimized RBLN graph
+        2. Configures the model for the specified NPU device
+        3. Creates the necessary runtime objects if requested
+        4. Saves the compiled model and configurations
+
+        Returns:
+            A RBLN model instance ready for inference on RBLN NPU devices.
+        """
         preprocessors = kwargs.pop("preprocessors", [])
         rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)

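For orientation, the `from_model` path documented above converts an already-instantiated HuggingFace model rather than loading by id. A hedged sketch (the checkpoint id, output directory, and the `max_seq_len` option are placeholder assumptions):

```python
# Sketch of the from_model flow: compile an in-memory PyTorch model for the NPU.
from transformers import AutoModelForCausalLM
from optimum.rbln import RBLNLlamaForCausalLM

torch_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
rbln_model = RBLNLlamaForCausalLM.from_model(
    torch_model,
    rbln_config={"max_seq_len": 4096},  # a plain dict now works per the widened signature
)
rbln_model.save_pretrained("./llama-2-7b-rbln")
```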
@@ -246,12 +251,7 @@

     @classmethod
     def get_hf_output_class(cls):
-        """
-        Dynamically gets the output class from the corresponding HuggingFace model class.
-
-        Returns:
-            type: The appropriate output class from transformers or diffusers
-        """
+        # Dynamically gets the output class from the corresponding HuggingFace model class.
         if cls._output_class:
             return cls._output_class

@@ -278,10 +278,8 @@
         return BaseModelOutput

     def _prepare_output(self, output, return_dict):
-        """
-        Prepare model output based on return_dict flag.
-        This method can be overridden by subclasses to provide task-specific output handling.
-        """
+        # Prepare model output based on return_dict flag.
+        # This method can be overridden by subclasses to provide task-specific output handling.
         tuple_output = (output,) if not isinstance(output, (tuple, list)) else tuple(output)
         if not return_dict:
             return tuple_output
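The `return_dict` convention `_prepare_output` implements mirrors transformers. Illustrative only; `model` and `inputs` are assumed to already exist:

```python
# return_dict=False -> a plain tuple; index positionally.
outputs = model(**inputs, return_dict=False)
first = outputs[0]

# return_dict=True -> a transformers ModelOutput; access fields by name
# (ModelOutput also supports positional indexing).
outputs = model(**inputs, return_dict=True)
first = outputs[0]
```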
optimum/rbln/modeling_base.py CHANGED
@@ -15,7 +15,7 @@
 import importlib
 import os
 import shutil
-from abc import ABC, abstractmethod
+from abc import ABC
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -24,7 +24,7 @@ import rebel
 import torch
 from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig

-from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig
+from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
 from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
 from .utils.logging import get_logger
 from .utils.runtime_utils import UnavailableRuntime
@@ -47,40 +47,6 @@ class RBLNBaseModelConfig(RBLNModelConfig):


 class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
-    """
-    An abstract base class for compiling, loading, and saving neural network models from the huggingface
-    transformers and diffusers libraries to run on RBLN NPU devices.
-
-    This class supports loading and saving models using the `from_pretrained` and `save_pretrained` methods,
-    similar to the huggingface libraries.
-
-    The `from_pretrained` method loads a model corresponding to the given `model_id` from a local repository
-    or the huggingface hub onto the NPU. If the model is a PyTorch model and `export=True` is passed as a
-    kwarg, it compiles the PyTorch model corresponding to the given `model_id` before loading. If `model_id`
-    is an already rbln-compiled model, it can be directly loaded onto the NPU with `export=False`.
-
-    `rbln_npu` is a kwarg required for compilation, specifying the name of the NPU to be used. If this
-    keyword is not specified, the NPU installed on the host machine is used. If no NPU is installed on the
-    host machine, an error occurs.
-
-    `rbln_device` specifies the device to be used at runtime. If not specified, device 0 is used.
-
-    `rbln_create_runtimes` indicates whether to create runtime objects. If False, the runtime does not load
-    the model onto the NPU. This option is particularly useful when you want to perform compilation only on a
-    host machine without an NPU.
-
-    `RBLNModel`, `RBLNModelFor*`, etc. are all child classes of RBLNBaseModel.
-
-    Models compiled in this way can be saved to a local repository using `save_pretrained` or uploaded to
-    the huggingface hub.
-
-    It also supports generation through `generate` (for transformers models that support generation).
-
-    RBLNBaseModel is a class for models consisting of an arbitrary number of `torch.nn.Module`s, and
-    therefore is an abstract class without explicit implementations of `forward` or `export` functions.
-    To inherit from this class, `forward`, `export`, etc. must be implemented.
-    """
-
     model_type = "rbln_model"
     auto_model_class = AutoModel
     config_class = AutoConfig
@@ -156,7 +122,7 @@
         subfolder: str = "",
         local_files_only: bool = False,
     ) -> str:
-        """Load the directory containing the compiled model files."""
+        # Load the directory containing the compiled model files.
         model_path = Path(model_id)

         if model_path.is_dir():
@@ -372,19 +338,40 @@
     def prepare_rbln_config(
         cls, rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
     ) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
-        """
-        Extract rbln-config from kwargs and convert it to RBLNModelConfig.
-        """
+        # Extract rbln-config from kwargs and convert it to RBLNModelConfig.
+
         config_cls = cls.get_rbln_config_class()
         rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
         return rbln_config, kwargs

     @classmethod
-    def from_pretrained(cls, model_id: Union[str, Path], export: bool = False, **kwargs) -> "RBLNBaseModel":
+    def from_pretrained(
+        cls: Type["RBLNBaseModel"],
+        model_id: Union[str, Path],
+        export: bool = False,
+        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+        **kwargs: Dict[str, Any],
+    ) -> "RBLNBaseModel":
+        """
+        The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
+        User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
+
+        Args:
+            model_id: The model id of the pre-trained model to be loaded. It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+            export: A boolean flag to indicate whether the model should be compiled.
+            rbln_config: Configuration for RBLN model compilation and runtime. This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
+                For detailed configuration options, see the specific model's configuration class documentation.

+            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+
+        Returns:
+            A RBLN model instance ready for inference on RBLN NPU devices.
+        """
+
         if isinstance(model_id, Path):
             model_id = model_id.as_posix()
         from_pretrained_method = cls._export if export else cls._from_pretrained
-        return from_pretrained_method(model_id=model_id, **kwargs)
+        return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)

     @classmethod
     def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None, **kwargs):
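The two loading modes the new docstring describes, sketched with placeholder ids (`max_seq_len` is one example of a config key; consult the specific model's configuration class for the real options):

```python
from optimum.rbln import RBLNLlamaForCausalLM

# 1) Compile a PyTorch checkpoint and load it onto the NPU.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder id
    export=True,
    rbln_config={"max_seq_len": 4096},
)
model.save_pretrained("./llama-2-7b-rbln")

# 2) Reload the already-compiled artifact directly; no compilation happens.
model = RBLNLlamaForCausalLM.from_pretrained("./llama-2-7b-rbln", export=False)
```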
@@ -411,15 +398,13 @@

     @classmethod
     def get_hf_class(cls):
-        """
-        Lazily loads and caches the corresponding HuggingFace model class.
-        Removes 'RBLN' prefix from the class name to get the original class name
-        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
-        the transformers/diffusers module.
+        # Lazily loads and caches the corresponding HuggingFace model class.
+        # Removes 'RBLN' prefix from the class name to get the original class name
+        # (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
+        # the transformers/diffusers module.

-        Returns:
-            type: The original HuggingFace model class
-        """
+        # Returns:
+        #     type: The original HuggingFace model class
         if cls._hf_class is None:
             hf_cls_name = cls.__name__[4:]
             library = importlib.import_module(cls.hf_library_name)
@@ -428,18 +413,10 @@

     @classmethod
     def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
-        """
-        Lazily loads and caches the corresponding RBLN model config class.
-        """
+        # Lazily loads and caches the corresponding RBLN model config class.
         if cls._rbln_config_class is None:
             rbln_config_class_name = cls.__name__ + "Config"
-            library = importlib.import_module("optimum.rbln")
-            cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
-            if cls._rbln_config_class is None:
-                raise ValueError(
-                    f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
-                    "Please report it to the developers."
-                )
+            cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
         return cls._rbln_config_class

     def can_generate(self):
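The refactored lookup leans on a pure naming convention, now centralized in the `get_rbln_config_class` helper imported from `configuration_utils`: append `Config` to the model class name and resolve it. Assuming that behavior, the relationship is observable as:

```python
from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

# The config class is found by name: "RBLNLlamaForCausalLM" + "Config".
assert RBLNLlamaForCausalLM.get_rbln_config_class() is RBLNLlamaForCausalLMConfig
```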
@@ -449,17 +426,15 @@
         return self

     def parameters(self):
-        """
-        Provides a dummy parameter generator for compatibility.
+        # A dummy parameter generator for compatibility.

-        This method mimics the interface of torch.nn.Module.parameters()
-        specifically for code that uses `next(model.parameters())` to infer
-        the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
+        # This method mimics the interface of torch.nn.Module.parameters()
+        # specifically for code that uses `next(model.parameters())` to infer
+        # the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.

-        Warning:
-            This does NOT yield the actual model parameters used by the RBLN runtime.
-            Code relying on iterating through all model parameters will not work as expected.
-        """
+        # Warning:
+        #     This does NOT yield the actual model parameters used by the RBLN runtime.
+        #     Code relying on iterating through all model parameters will not work as expected.
         yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))

     def __call__(self, *args, **kwargs):
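The comment block spells out why this stub exists: device/dtype probes keep working, full parameter iteration does not. Illustrative only, assuming a loaded `model`:

```python
# Works: utilities that only peek at the first parameter.
param = next(model.parameters())
print(param.device, param.dtype)  # cpu torch.float32 -- a dummy tensor, not real weights

# Misleading: iterating "all" parameters yields only the single dummy tensor.
total = sum(p.numel() for p in model.parameters())  # 1, not the real parameter count
```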
@@ -547,7 +522,7 @@

     @staticmethod
     def _raise_missing_compiled_file_error(missing_files: List[str]):
-        """Raises a KeyError with a message indicating missing compiled model files."""
+        # Raises a KeyError with a message indicating missing compiled model files.

         if len(missing_files) == 1:
             message = f"The rbln model folder is missing the required '{missing_files[0]}.rbln' file. "
@@ -563,40 +538,3 @@
             "and ensure the compilation completes successfully."
         )
         raise KeyError(message)
-
-    @classmethod
-    @abstractmethod
-    def _update_rbln_config(cls, **rbln_config_kwargs) -> RBLNModelConfig:
-        pass
-
-    @classmethod
-    @abstractmethod
-    def _create_runtimes(
-        cls,
-        compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_config: RBLNModelConfig,
-    ) -> List[rebel.Runtime]:
-        # compiled_models -> runtimes
-        pass
-
-    @classmethod
-    @abstractmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        pass
-
-    @classmethod
-    @abstractmethod
-    def from_model(
-        cls,
-        model: "PreTrainedModel",
-        config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNModelConfig] = None,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        subfolder: str = "",
-        **kwargs,
-    ):
-        pass
-
-    @abstractmethod
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        pass
optimum/rbln/transformers/__init__.py CHANGED
@@ -18,16 +18,9 @@ from transformers.utils import _LazyModule


 _import_structure = {
-    "configuration_alias": [
-        "RBLNASTForAudioClassificationConfig",
-        "RBLNDistilBertForQuestionAnsweringConfig",
-        "RBLNResNetForImageClassificationConfig",
-        "RBLNXLMRobertaForSequenceClassificationConfig",
-        "RBLNRobertaForSequenceClassificationConfig",
-        "RBLNRobertaForMaskedLMConfig",
-        "RBLNViTForImageClassificationConfig",
-    ],
     "models": [
+        "RBLNASTForAudioClassification",
+        "RBLNASTForAudioClassificationConfig",
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
         "RBLNAutoModelForCausalLM",
@@ -51,12 +44,12 @@ _import_structure = {
         "RBLNBertForQuestionAnsweringConfig",
         "RBLNBertModel",
         "RBLNBertModelConfig",
-        "RBLNBlip2VisionModelConfig",
-        "RBLNBlip2VisionModel",
-        "RBLNBlip2QFormerModel",
-        "RBLNBlip2QFormerModelConfig",
         "RBLNBlip2ForConditionalGeneration",
         "RBLNBlip2ForConditionalGenerationConfig",
+        "RBLNBlip2QFormerModel",
+        "RBLNBlip2QFormerModelConfig",
+        "RBLNBlip2VisionModel",
+        "RBLNBlip2VisionModelConfig",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",
@@ -67,40 +60,48 @@ _import_structure = {
         "RBLNCLIPVisionModelWithProjectionConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
+        "RBLNDistilBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
         "RBLNExaoneForCausalLM",
         "RBLNExaoneForCausalLMConfig",
-        "RBLNGemmaForCausalLM",
-        "RBLNGemmaForCausalLMConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
         "RBLNGemma3ForConditionalGenerationConfig",
+        "RBLNGemmaForCausalLM",
+        "RBLNGemmaForCausalLMConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
-        "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
+        "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3VisionTransformerConfig",
         "RBLNLlamaForCausalLM",
         "RBLNLlamaForCausalLMConfig",
-        "RBLNOPTForCausalLM",
-        "RBLNOPTForCausalLMConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralForCausalLM",
         "RBLNMistralForCausalLMConfig",
+        "RBLNOPTForCausalLM",
+        "RBLNOPTForCausalLMConfig",
         "RBLNPhiForCausalLM",
         "RBLNPhiForCausalLMConfig",
-        "RBLNQwen2ForCausalLM",
-        "RBLNQwen2ForCausalLMConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2ForCausalLM",
+        "RBLNQwen2ForCausalLMConfig",
+        "RBLNResNetForImageClassification",
+        "RBLNResNetForImageClassificationConfig",
+        "RBLNRobertaForMaskedLM",
+        "RBLNRobertaForMaskedLMConfig",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassificationConfig",
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
         "RBLNT5EncoderModel",
@@ -109,44 +110,23 @@ _import_structure = {
         "RBLNT5ForConditionalGenerationConfig",
         "RBLNTimeSeriesTransformerForPrediction",
         "RBLNTimeSeriesTransformerForPredictionConfig",
+        "RBLNViTForImageClassification",
+        "RBLNViTForImageClassificationConfig",
         "RBLNWav2Vec2ForCTC",
         "RBLNWav2Vec2ForCTCConfig",
         "RBLNWhisperForConditionalGeneration",
         "RBLNWhisperForConditionalGenerationConfig",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNXLMRobertaForSequenceClassificationConfig",
         "RBLNXLMRobertaModel",
         "RBLNXLMRobertaModelConfig",
     ],
-    "modeling_alias": [
-        "RBLNASTForAudioClassification",
-        "RBLNDistilBertForQuestionAnswering",
-        "RBLNResNetForImageClassification",
-        "RBLNXLMRobertaForSequenceClassification",
-        "RBLNRobertaForSequenceClassification",
-        "RBLNRobertaForMaskedLM",
-        "RBLNViTForImageClassification",
-    ],
 }

 if TYPE_CHECKING:
-    from .configuration_alias import (
-        RBLNASTForAudioClassificationConfig,
-        RBLNDistilBertForQuestionAnsweringConfig,
-        RBLNResNetForImageClassificationConfig,
-        RBLNRobertaForMaskedLMConfig,
-        RBLNRobertaForSequenceClassificationConfig,
-        RBLNViTForImageClassificationConfig,
-        RBLNXLMRobertaForSequenceClassificationConfig,
-    )
-    from .modeling_alias import (
-        RBLNASTForAudioClassification,
-        RBLNDistilBertForQuestionAnswering,
-        RBLNResNetForImageClassification,
-        RBLNRobertaForMaskedLM,
-        RBLNRobertaForSequenceClassification,
-        RBLNViTForImageClassification,
-        RBLNXLMRobertaForSequenceClassification,
-    )
     from .models import (
+        RBLNASTForAudioClassification,
+        RBLNASTForAudioClassificationConfig,
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
         RBLNAutoModelForCausalLM,
@@ -186,6 +166,8 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjectionConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
+        RBLNDistilBertForQuestionAnswering,
+        RBLNDistilBertForQuestionAnsweringConfig,
         RBLNDPTForDepthEstimation,
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
@@ -220,6 +202,12 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VLForConditionalGenerationConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
+        RBLNResNetForImageClassification,
+        RBLNResNetForImageClassificationConfig,
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForMaskedLMConfig,
+        RBLNRobertaForSequenceClassification,
+        RBLNRobertaForSequenceClassificationConfig,
         RBLNSiglipVisionModel,
         RBLNSiglipVisionModelConfig,
         RBLNT5EncoderModel,
@@ -228,10 +216,14 @@ if TYPE_CHECKING:
         RBLNT5ForConditionalGenerationConfig,
         RBLNTimeSeriesTransformerForPrediction,
         RBLNTimeSeriesTransformerForPredictionConfig,
+        RBLNViTForImageClassification,
+        RBLNViTForImageClassificationConfig,
         RBLNWav2Vec2ForCTC,
         RBLNWav2Vec2ForCTCConfig,
         RBLNWhisperForConditionalGeneration,
         RBLNWhisperForConditionalGenerationConfig,
+        RBLNXLMRobertaForSequenceClassification,
+        RBLNXLMRobertaForSequenceClassificationConfig,
         RBLNXLMRobertaModel,
         RBLNXLMRobertaModelConfig,
     )
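Net effect of these `__init__.py` hunks: the `configuration_alias`/`modeling_alias` modules are gone, but the public import surface is unchanged, since the alias names now resolve through `optimum.rbln.transformers.models`. A quick sanity check, assuming the package layout shown in this diff:

```python
# These top-level imports should keep working after the alias-module removal.
from optimum.rbln import (
    RBLNResNetForImageClassification,
    RBLNRobertaForMaskedLM,
    RBLNViTForImageClassification,
)
```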