optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from diffusers import StableDiffusion3Img2ImgPipeline
16
+
17
+ from ...configurations import RBLNStableDiffusion3Img2ImgPipelineConfig
18
+ from ...modeling_diffusers import RBLNDiffusionMixin
19
+
20
+
21
class RBLNStableDiffusion3Img2ImgPipeline(RBLNDiffusionMixin, StableDiffusion3Img2ImgPipeline):
    """
    Stable Diffusion 3 image-to-image pipeline targeting RBLN NPUs.

    Combines ``RBLNDiffusionMixin`` with the diffusers ``StableDiffusion3Img2ImgPipeline`` so that the pipeline's
    models are compiled to run on RBLN NPUs when transforming an input image under text guidance.
    """

    # Configuration class describing the compile-time options of this pipeline.
    _rbln_config_class = RBLNStableDiffusion3Img2ImgPipelineConfig
    # Upstream diffusers pipeline that this class wraps.
    original_class = StableDiffusion3Img2ImgPipeline
    # Pipeline components handled by RBLNDiffusionMixin. The order is kept exactly as in the original;
    # NOTE(review): it may determine the order in which the mixin processes them — confirm before reordering.
    _submodules = [
        "transformer",
        "text_encoder_3",
        "text_encoder",
        "text_encoder_2",
        "vae",
    ]
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from diffusers import StableDiffusion3InpaintPipeline
16
+
17
+ from ...configurations import RBLNStableDiffusion3InpaintPipelineConfig
18
+ from ...modeling_diffusers import RBLNDiffusionMixin
19
+
20
+
21
class RBLNStableDiffusion3InpaintPipeline(RBLNDiffusionMixin, StableDiffusion3InpaintPipeline):
    """
    Stable Diffusion 3 inpainting pipeline targeting RBLN NPUs.

    Combines ``RBLNDiffusionMixin`` with the diffusers ``StableDiffusion3InpaintPipeline`` so that the pipeline's
    models are compiled to run on RBLN NPUs when filling masked regions of an image under text guidance.
    """

    # Configuration class describing the compile-time options of this pipeline.
    _rbln_config_class = RBLNStableDiffusion3InpaintPipelineConfig
    # Upstream diffusers pipeline that this class wraps.
    original_class = StableDiffusion3InpaintPipeline
    # Pipeline components handled by RBLNDiffusionMixin. The order is kept exactly as in the original;
    # NOTE(review): it may determine the order in which the mixin processes them — confirm before reordering.
    _submodules = [
        "transformer",
        "text_encoder_3",
        "text_encoder",
        "text_encoder_2",
        "vae",
    ]
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .pipeline_stable_diffusion_xl import RBLNStableDiffusionXLPipeline
16
+ from .pipeline_stable_diffusion_xl_img2img import RBLNStableDiffusionXLImg2ImgPipeline
17
+ from .pipeline_stable_diffusion_xl_inpaint import RBLNStableDiffusionXLInpaintPipeline
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from diffusers import StableDiffusionXLPipeline
16
+
17
+ from ...configurations import RBLNStableDiffusionXLPipelineConfig
18
+ from ...modeling_diffusers import RBLNDiffusionMixin
19
+
20
+
21
class RBLNStableDiffusionXLPipeline(RBLNDiffusionMixin, StableDiffusionXLPipeline):
    """
    Stable Diffusion XL text-to-image pipeline targeting RBLN NPUs.

    Combines ``RBLNDiffusionMixin`` with the diffusers ``StableDiffusionXLPipeline`` so that the pipeline's
    models are compiled to run on RBLN NPUs when generating images from text prompts.
    """

    # Configuration class describing the compile-time options of this pipeline.
    _rbln_config_class = RBLNStableDiffusionXLPipelineConfig
    # Upstream diffusers pipeline that this class wraps.
    original_class = StableDiffusionXLPipeline
    # Pipeline components handled by RBLNDiffusionMixin. The order is kept exactly as in the original;
    # NOTE(review): it may determine the order in which the mixin processes them — confirm before reordering.
    _submodules = [
        "text_encoder",
        "text_encoder_2",
        "unet",
        "vae",
    ]
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from diffusers import StableDiffusionXLImg2ImgPipeline
16
+
17
+ from ...configurations import RBLNStableDiffusionXLImg2ImgPipelineConfig
18
+ from ...modeling_diffusers import RBLNDiffusionMixin
19
+
20
+
21
class RBLNStableDiffusionXLImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLImg2ImgPipeline):
    """
    Stable Diffusion XL image-to-image pipeline targeting RBLN NPUs.

    Combines ``RBLNDiffusionMixin`` with the diffusers ``StableDiffusionXLImg2ImgPipeline`` so that the pipeline's
    models are compiled to run on RBLN NPUs when transforming an input image under text guidance.
    """

    # Configuration class describing the compile-time options of this pipeline.
    _rbln_config_class = RBLNStableDiffusionXLImg2ImgPipelineConfig
    # Upstream diffusers pipeline that this class wraps.
    original_class = StableDiffusionXLImg2ImgPipeline
    # Pipeline components handled by RBLNDiffusionMixin. The order is kept exactly as in the original;
    # NOTE(review): it may determine the order in which the mixin processes them — confirm before reordering.
    _submodules = [
        "text_encoder",
        "text_encoder_2",
        "unet",
        "vae",
    ]
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from diffusers import StableDiffusionXLInpaintPipeline
16
+
17
+ from ...configurations import RBLNStableDiffusionXLInpaintPipelineConfig
18
+ from ...modeling_diffusers import RBLNDiffusionMixin
19
+
20
+
21
class RBLNStableDiffusionXLInpaintPipeline(RBLNDiffusionMixin, StableDiffusionXLInpaintPipeline):
    """
    Stable Diffusion XL inpainting pipeline targeting RBLN NPUs.

    Combines ``RBLNDiffusionMixin`` with the diffusers ``StableDiffusionXLInpaintPipeline`` so that the pipeline's
    models are compiled to run on RBLN NPUs when filling masked regions of an image under text guidance.
    """

    # Configuration class describing the compile-time options of this pipeline.
    _rbln_config_class = RBLNStableDiffusionXLInpaintPipelineConfig
    # Upstream diffusers pipeline that this class wraps.
    original_class = StableDiffusionXLInpaintPipeline
    # Pipeline components handled by RBLNDiffusionMixin. The order is kept exactly as in the original;
    # NOTE(review): it may determine the order in which the mixin processes them — confirm before reordering.
    _submodules = [
        "text_encoder",
        "text_encoder_2",
        "unet",
        "vae",
    ]
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .pipeline_stable_video_diffusion import RBLNStableVideoDiffusionPipeline
@@ -0,0 +1,46 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from diffusers import StableVideoDiffusionPipeline
17
+
18
+ from ....utils.logging import get_logger
19
+ from ...configurations import RBLNStableVideoDiffusionPipelineConfig
20
+ from ...modeling_diffusers import RBLNDiffusionMixin
21
+
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
class RBLNStableVideoDiffusionPipeline(RBLNDiffusionMixin, StableVideoDiffusionPipeline):
    """
    RBLN-accelerated implementation of Stable Video Diffusion pipeline for image-to-video generation.

    This pipeline compiles Stable Video Diffusion models to run efficiently on RBLN NPUs, enabling high-performance
    inference for generating videos from images with optimized memory usage and throughput.
    """

    original_class = StableVideoDiffusionPipeline
    _rbln_config_class = RBLNStableVideoDiffusionPipelineConfig
    _submodules = ["image_encoder", "unet", "vae"]

    def handle_additional_kwargs(self, **kwargs):
        """
        Pin call-time arguments to the values the models were compiled with.

        ``num_frames`` is taken from the compiled UNet's ``rbln_config`` and ``decode_chunk_size`` from the
        compiled VAE's ``rbln_config``. Each compiled value that is not ``None`` replaces whatever the caller
        passed. When the caller explicitly passed a different (non-``None``) value, a warning is logged so
        that the override is not silent.

        Args:
            **kwargs: Keyword arguments passed to the pipeline call.

        Returns:
            dict: ``kwargs`` with ``num_frames`` / ``decode_chunk_size`` set to their compiled values when
            those values are available.
        """
        compiled_values = {
            "num_frames": self.unet.rbln_config.num_frames,
            "decode_chunk_size": self.vae.rbln_config.decode_chunk_size,
        }
        for name, compiled_value in compiled_values.items():
            if compiled_value is None:
                # Nothing fixed at compile time for this argument; leave the caller's value untouched.
                continue
            requested = kwargs.get(name)
            if requested is not None and requested != compiled_value:
                # The compiled graph cannot honor a different value, so tell the user instead of
                # silently changing their request.
                logger.warning(
                    f"`{name}={requested}` was requested, but the model was compiled with "
                    f"`{name}={compiled_value}`; using the compiled value."
                )
            kwargs[name] = compiled_value
        return kwargs
@@ -0,0 +1,364 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from tempfile import TemporaryDirectory
17
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, get_args, get_origin, get_type_hints
18
+
19
+ import rebel
20
+ import torch
21
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
22
+ from transformers import PretrainedConfig
23
+ from transformers.modeling_outputs import BaseModelOutput
24
+
25
+ from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
26
+ from .modeling_base import RBLNBaseModel
27
+ from .utils.logging import get_logger
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from transformers import PreTrainedModel
32
+
33
+
34
+ logger = get_logger(__name__)
35
+
36
+
37
+ class RBLNModel(RBLNBaseModel):
38
+ @classmethod
39
+ def update_kwargs(cls, kwargs):
40
+ # Update user-given kwargs to get proper pytorch model.
41
+
42
+ return kwargs
43
+
44
+ @classmethod
45
+ def save_torch_artifacts(
46
+ cls,
47
+ model: "PreTrainedModel",
48
+ save_dir_path: Path,
49
+ subfolder: str,
50
+ rbln_config: RBLNModelConfig,
51
+ ):
52
+ # If you are unavoidably running on a CPU rather than an RBLN device,
53
+ # store the torch tensor, weight, etc. in this function.
54
+ pass
55
+
56
+ @classmethod
57
+ def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
58
+ # Wrap the model if needed.
59
+ return model
60
+
61
+ @classmethod
62
+ def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
63
+ if rbln_config._allow_no_compile_cfgs:
64
+ return {}
65
+
66
+ model = cls._wrap_model_if_needed(model, rbln_config)
67
+ rbln_compile_config = rbln_config.compile_cfgs[0]
68
+ compiled_model = cls.compile(
69
+ model,
70
+ rbln_compile_config=rbln_compile_config,
71
+ create_runtimes=rbln_config.create_runtimes,
72
+ device=rbln_config.device,
73
+ )
74
+ return compiled_model
75
+
76
+ @classmethod
77
+ def _update_rbln_config(
78
+ cls,
79
+ preprocessors: Optional[Any],
80
+ model: Optional["PreTrainedModel"] = None,
81
+ model_config: Optional["PretrainedConfig"] = None,
82
+ rbln_config: Optional[RBLNModelConfig] = None,
83
+ ) -> RBLNModelConfig:
84
+ # Default implementation: return config as-is
85
+ # Subclasses should override to set compile_cfgs if needed
86
+ return rbln_config
87
+
88
+ @classmethod
89
+ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
90
+ return model
91
+
92
+ @classmethod
93
+ def from_model(
94
+ cls,
95
+ model: "PreTrainedModel",
96
+ config: Optional[PretrainedConfig] = None,
97
+ rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
98
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
99
+ subfolder: str = "",
100
+ **kwargs: Any,
101
+ ) -> "RBLNModel":
102
+ """
103
+ Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
104
+ This method performs the actual model conversion and compilation process.
105
+
106
+ Args:
107
+ model (PreTrainedModel): The PyTorch model to be compiled.
108
+ The object must be an instance of the HuggingFace transformers PreTrainedModel class.
109
+ config (Optional[PretrainedConfig]): The configuration object associated with the model.
110
+ rbln_config (Optional[Union[RBLNModelConfig, Dict]]): Configuration for RBLN model compilation and runtime.
111
+ This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
112
+ For detailed configuration options, see the specific model's configuration class documentation.
113
+ kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
114
+
115
+ The method performs the following steps:
116
+
117
+ 1. Compiles the PyTorch model into an optimized RBLN graph
118
+ 2. Configures the model for the specified NPU device
119
+ 3. Creates the necessary runtime objects if requested
120
+ 4. Saves the compiled model and configurations
121
+
122
+ Returns:
123
+ (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
124
+ """
125
+
126
+ model = cls._reconstruct_model_if_needed(model)
127
+ preprocessors = kwargs.pop("preprocessors", [])
128
+ rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
129
+
130
+ # Directory to save compile artifacts(.rbln) and original configs
131
+ if model_save_dir is None:
132
+ save_dir = TemporaryDirectory()
133
+ save_dir_path = Path(save_dir.name)
134
+ else:
135
+ save_dir = model_save_dir
136
+ if isinstance(save_dir, TemporaryDirectory):
137
+ save_dir_path = Path(model_save_dir.name)
138
+ else:
139
+ save_dir_path = Path(model_save_dir)
140
+ save_dir_path.mkdir(exist_ok=True)
141
+
142
+ # Save configs
143
+ if config is None:
144
+ config = model.config
145
+
146
+ if hasattr(model, "can_generate") and model.can_generate():
147
+ import json
148
+
149
+ generation_config = model.generation_config
150
+ generation_config_path = save_dir_path / subfolder / "generation_config.json"
151
+
152
+ generation_config.save_pretrained(generation_config_path.parent)
153
+ local_config = json.loads(generation_config_path.read_text(encoding="utf-8"))
154
+ local_config["transformers_version"] = generation_config.transformers_version
155
+ generation_config_path.write_text(json.dumps(local_config, indent=2) + "\n", encoding="utf-8")
156
+
157
+ if not isinstance(config, PretrainedConfig): # diffusers config
158
+ config = PretrainedConfig(**config)
159
+
160
+ # Save preprocessor
161
+ for preprocessor in preprocessors:
162
+ preprocessor.save_pretrained(save_dir_path / subfolder)
163
+
164
+ # Load submodules
165
+ if len(cls._rbln_submodules) > 0:
166
+ rbln_submodules = cls._load_submodules(
167
+ model=model,
168
+ model_save_dir=save_dir,
169
+ rbln_config=rbln_config,
170
+ preprocessors=preprocessors,
171
+ **kwargs,
172
+ )
173
+ else:
174
+ rbln_submodules = []
175
+
176
+ # Get compilation arguments (e.g. input_info)
177
+ rbln_config: RBLNModelConfig = cls.update_rbln_config(
178
+ preprocessors=preprocessors, model=model, model_config=config, rbln_config=rbln_config
179
+ )
180
+
181
+ # torchscript should be True for jit to work
182
+ torchscript_backup = config.torchscript
183
+ config.torchscript = True
184
+
185
+ compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
186
+ model, rbln_config=rbln_config
187
+ )
188
+
189
+ # Save compiled models (.rbln)
190
+ (save_dir_path / subfolder).mkdir(exist_ok=True)
191
+ if not isinstance(compiled_model, dict):
192
+ compiled_models = {DEFAULT_COMPILED_MODEL_NAME: compiled_model}
193
+ else:
194
+ compiled_models = compiled_model
195
+ for compiled_model_name, cm in compiled_models.items():
196
+ cm.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
197
+ rbln_config.save(save_dir_path / subfolder)
198
+
199
+ config.torchscript = torchscript_backup
200
+ config.save_pretrained(save_dir_path / subfolder)
201
+
202
+ # Save torch artifacts (e.g. embedding matrix if needed.)
203
+ cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)
204
+
205
+ # Instantiate
206
+ return cls._from_pretrained(
207
+ model_id=save_dir_path,
208
+ config=config,
209
+ model_save_dir=save_dir,
210
+ subfolder=subfolder,
211
+ rbln_config=rbln_config,
212
+ rbln_compiled_models=compiled_models,
213
+ rbln_submodules=rbln_submodules,
214
+ **kwargs,
215
+ )
216
+
217
+ @classmethod
218
+ def get_pytorch_model(
219
+ cls,
220
+ model_id: str,
221
+ use_auth_token: Optional[Union[bool, str]] = None,
222
+ revision: Optional[str] = None,
223
+ force_download: bool = False,
224
+ cache_dir: Optional[str] = HUGGINGFACE_HUB_CACHE,
225
+ subfolder: str = "",
226
+ local_files_only: bool = False,
227
+ trust_remote_code: bool = False,
228
+ # Some rbln-config should be applied before loading torch module (i.e. quantized llm)
229
+ rbln_config: Optional[RBLNModelConfig] = None,
230
+ **kwargs,
231
+ ) -> "PreTrainedModel":
232
+ kwargs = cls.update_kwargs(kwargs)
233
+
234
+ return cls.get_hf_class().from_pretrained(
235
+ model_id,
236
+ subfolder=subfolder,
237
+ revision=revision,
238
+ cache_dir=cache_dir,
239
+ use_auth_token=use_auth_token,
240
+ local_files_only=local_files_only,
241
+ force_download=force_download,
242
+ trust_remote_code=trust_remote_code,
243
+ **kwargs,
244
+ )
245
+
246
+ @classmethod
247
+ def _create_runtimes(
248
+ cls,
249
+ compiled_models: List[rebel.RBLNCompiledModel],
250
+ rbln_config: RBLNModelConfig,
251
+ ) -> List[rebel.Runtime]:
252
+ if len(rbln_config.compile_cfgs) == 0:
253
+ return []
254
+
255
+ if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
256
+ cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
257
+
258
+ return [
259
+ rebel.Runtime(
260
+ compiled_model,
261
+ tensor_type="pt",
262
+ device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
263
+ activate_profiler=rbln_config.activate_profiler,
264
+ timeout=rbln_config.timeout,
265
+ )
266
+ for compiled_model in compiled_models
267
+ ]
268
+
269
+ def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
270
+ """
271
+ Defines the forward pass of `RBLNModel`. The interface mirrors HuggingFace conventions so it can act as a drop-in
272
+ replacement in many cases.
273
+
274
+ This method executes the compiled RBLN model on RBLN NPU devices while remaining fully compatible with Hugging Face
275
+ Transformers and Diffusers APIs. In practice, `RBLNModel` can replace models built on `torch.nn.Module` — including
276
+ `transformers.PreTrainedModel` implementations and Diffusers components based on `diffusers.ModelMixin` — enabling
277
+ seamless integration into existing workflows.
278
+
279
+ Args:
280
+ args: Variable length argument list containing model inputs. The format matches the original
281
+ HuggingFace model's forward method signature (e.g., input_ids, attention_mask for
282
+ transformers models, or sample, timestep for diffusers models).
283
+ return_dict:
284
+ Whether to return outputs as a dictionary-like object or as a tuple. When `None`:
285
+ - For transformers models: Uses `self.config.use_return_dict` (typically `True`)
286
+ - For diffusers models: Defaults to `True`
287
+ kwargs: Arbitrary keyword arguments containing additional model inputs and parameters,
288
+ matching the original HuggingFace model's interface.
289
+
290
+ Returns:
291
+ Model outputs in the same format as the original HuggingFace model.
292
+
293
+ If `return_dict=True`, Returns a dictionary-like object (e.g., BaseModelOutput,
294
+ CausalLMOutput) with named fields such as `logits`, `hidden_states`, etc.
295
+ If `return_dict=False`, Returns a tuple containing the raw model outputs.
296
+
297
+ Note:
298
+ - This method maintains the exact same interface as the original HuggingFace model's forward method
299
+ - The compiled model runs on RBLN NPU hardware for accelerated inference
300
+ - All HuggingFace model features (generation, attention patterns, etc.) are preserved
301
+ - Can be used directly in HuggingFace pipelines, transformers.Trainer, and other workflows
302
+ """
303
+ if self.hf_library_name == "transformers":
304
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
305
+ else:
306
+ return_dict = True if return_dict is None else return_dict
307
+
308
+ # Get output from the model
309
+ output = self.model[0](*args, **kwargs)
310
+
311
+ # Format output according to task requirements
312
+ return self._prepare_output(output, return_dict)
313
+
314
+ @classmethod
315
+ def get_hf_output_class(cls):
316
+ # Dynamically gets the output class from the corresponding HuggingFace model class.
317
+ if "_output_class" in cls.__dict__ and cls._output_class is not None:
318
+ return cls._output_class
319
+
320
+ hf_class = cls.get_hf_class()
321
+ if hf_class is None:
322
+ raise ValueError(f"No HuggingFace model class found for {cls.__name__}")
323
+
324
+ hints = get_type_hints(hf_class.forward) if hasattr(hf_class, "forward") else {}
325
+ ret = hints.get("return")
326
+
327
+ if ret is not None:
328
+ candidates = get_args(ret) if get_origin(ret) is Union else (ret,)
329
+
330
+ for t in candidates:
331
+ if t is type(None): # Skip NoneType in Union
332
+ continue
333
+ mod = getattr(t, "__module__", "")
334
+ if "transformers" in mod or "diffusers" in mod:
335
+ cls._output_class = t
336
+ return t
337
+
338
+ # Fallback to BaseModelOutput
339
+ cls._output_class = BaseModelOutput
340
+ return BaseModelOutput
341
+
342
+ def _prepare_output(self, output, return_dict):
343
+ # Prepare model output based on return_dict flag.
344
+ # This method can be overridden by subclasses to provide task-specific output handling.
345
+ tuple_output = (output,) if not isinstance(output, (tuple, list)) else tuple(output)
346
+ if not return_dict:
347
+ return tuple_output
348
+ else:
349
+ output_class = self.get_hf_output_class()
350
+ if hasattr(output_class, "loss"):
351
+ tuple_output = (None,) + tuple_output
352
+
353
+ # Truncate if we have too many outputs, otherwise use as is
354
+ if hasattr(output_class, "__annotations__"):
355
+ num_fields = len(output_class.__annotations__)
356
+ if len(tuple_output) > num_fields:
357
+ tuple_output = tuple_output[:num_fields]
358
+ logger.warning(
359
+ f"Truncating output to {num_fields} fields for {output_class.__name__}. "
360
+ f"Expected {num_fields} fields, but got {len(tuple_output)} fields."
361
+ "This is unexpected. Please report this issue to the developers."
362
+ )
363
+
364
+ return output_class(*tuple_output)