optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,637 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ import os
17
+ import shutil
18
+ from abc import ABC
19
+ from pathlib import Path
20
+ from tempfile import TemporaryDirectory
21
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
22
+
23
+ import rebel
24
+ import torch
25
+ from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
26
+ from transformers.utils.hub import PushToHubMixin
27
+
28
+ from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
29
+ from .utils.hub import pull_compiled_model_from_hub, validate_files
30
+ from .utils.logging import get_logger
31
+ from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
32
+ from .utils.save_utils import maybe_load_preprocessors
33
+ from .utils.submodule import SubModulesMixin
34
+
35
+
36
+ if TYPE_CHECKING:
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
38
+
39
+ logger = get_logger(__name__)
40
+
41
+
42
class PreTrainedModel(ABC):  # noqa: F811
    """Empty abstract placeholder used in place of ``transformers.PreTrainedModel`` at runtime.

    The real ``PreTrainedModel`` is imported only under ``TYPE_CHECKING`` in this module
    (hence ``noqa: F811`` for the redefinition). ``RBLNBaseModel`` inherits from this
    placeholder rather than from the PyTorch implementation.
    """
44
+
45
+
46
class RBLNBaseModelConfig(RBLNModelConfig):
    """Concrete RBLN configuration class that adds nothing on top of ``RBLNModelConfig``.

    NOTE(review): presumably the configuration counterpart of ``RBLNBaseModel``;
    confirm against ``get_rbln_config_class``.
    """
48
+
49
+
50
+ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
51
+ model_type = "rbln_model"
52
+ auto_model_class = AutoModel
53
+ config_class = AutoConfig
54
+ config_name = "config.json"
55
+ hf_library_name = "transformers"
56
+ _supports_non_fp32 = False
57
+
58
+ def __init__(
59
+ self,
60
+ models: List[rebel.Runtime],
61
+ config: "PretrainedConfig",
62
+ rbln_config: RBLNModelConfig,
63
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
64
+ subfolder: str = "",
65
+ rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
66
+ rbln_submodules: List["RBLNBaseModel"] = [],
67
+ **kwargs,
68
+ ):
69
+ self.model = models
70
+ self.config = config
71
+ self.rbln_config = rbln_config
72
+ if not rbln_config.is_frozen():
73
+ raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
74
+ self.compiled_models = rbln_compiled_models
75
+
76
+ # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
77
+ # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940
78
+ AutoConfig.register(self.model_type, AutoConfig)
79
+ if hasattr(self.auto_model_class, "register"):
80
+ self.auto_model_class.register(AutoConfig, self.__class__)
81
+
82
+ # copied from tranformers PreTrainedModel __init__
83
+ if self.can_generate():
84
+ gen_config_dir = model_save_dir.name if isinstance(model_save_dir, TemporaryDirectory) else model_save_dir
85
+ self.generation_config = GenerationConfig.from_pretrained(gen_config_dir, trust_remote_code=True)
86
+ else:
87
+ self.generation_config = None
88
+
89
+ if self.generation_config is not None:
90
+ self.generation_config.use_cache = True
91
+
92
+ self.device = torch.device("cpu")
93
+ self.training = False
94
+ self.dtype = rbln_config.torch_dtype
95
+
96
+ # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
97
+ # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
98
+ # would end-up removing the directory containing the underlying RBLN model.
99
+ self._model_save_dir_tempdirectory_instance = None
100
+ if isinstance(model_save_dir, TemporaryDirectory):
101
+ self._model_save_dir_tempdirectory_instance = model_save_dir
102
+ self.model_save_dir = Path(model_save_dir.name)
103
+ elif isinstance(model_save_dir, str):
104
+ self.model_save_dir = Path(model_save_dir)
105
+ else:
106
+ self.model_save_dir = model_save_dir
107
+ self.subfolder = subfolder
108
+
109
+ self.rbln_submodules = rbln_submodules
110
+ self.__post_init__(**kwargs)
111
+
112
+ @classmethod
113
+ def _load_compiled_model_dir(
114
+ cls,
115
+ model_id: Union[str, Path],
116
+ token: Optional[Union[bool, str]] = None,
117
+ revision: Optional[str] = None,
118
+ force_download: bool = False,
119
+ cache_dir: Optional[str] = None,
120
+ subfolder: str = "",
121
+ local_files_only: bool = False,
122
+ ) -> str:
123
+ # Load the directory containing the compiled model files.
124
+ model_path = Path(model_id)
125
+
126
+ if model_path.is_dir():
127
+ model_path = model_path / subfolder
128
+ rbln_files = list(model_path.glob("*.rbln"))
129
+ rbln_config_filenames = list(model_path.glob("rbln_config.json"))
130
+ validate_files(rbln_files, rbln_config_filenames, f"directory {model_path}")
131
+ else:
132
+ model_path = pull_compiled_model_from_hub(
133
+ model_id=model_id,
134
+ subfolder=subfolder,
135
+ token=token,
136
+ revision=revision,
137
+ cache_dir=cache_dir,
138
+ force_download=force_download,
139
+ local_files_only=local_files_only,
140
+ )
141
+
142
+ return str(model_path)
143
+
144
+ @classmethod
145
+ def _load_compiled_models(cls, model_path: str, expected_compiled_model_names: List[str]):
146
+ compiled_models = Path(model_path).glob("*.rbln")
147
+ expected_compiled_models = [
148
+ Path(model_path) / f"{compiled_model_name}.rbln" for compiled_model_name in expected_compiled_model_names
149
+ ]
150
+ unexpected_compiled_models = [cm for cm in compiled_models if cm not in expected_compiled_models]
151
+ if unexpected_compiled_models:
152
+ # TODO(jongho): fix after May release. raise error if unexpected compiled models are found
153
+ logger.warning(
154
+ f"Unexpected compiled models found: {[cm.name for cm in unexpected_compiled_models]}. "
155
+ f"Please check the model path: {model_path}"
156
+ )
157
+
158
+ rbln_compiled_models = {}
159
+ for compiled_model in expected_compiled_models:
160
+ if not compiled_model.exists():
161
+ raise FileNotFoundError(
162
+ f"Expected RBLN compiled model '{compiled_model.name}' not found at '{model_path}'. "
163
+ "Please ensure all models specified in `rbln_config` are present."
164
+ )
165
+ rbln_compiled_models[compiled_model.stem] = rebel.RBLNCompiledModel(compiled_model)
166
+ return rbln_compiled_models
167
+
168
+ @classmethod
169
+ def _from_pretrained(
170
+ cls,
171
+ model_id: Union[str, Path],
172
+ config: Optional["PretrainedConfig"] = None,
173
+ token: Optional[Union[bool, str]] = None,
174
+ revision: Optional[str] = None,
175
+ force_download: bool = False,
176
+ cache_dir: Optional[str] = None,
177
+ subfolder: str = "",
178
+ local_files_only: bool = False,
179
+ trust_remote_code: bool = False,
180
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
181
+ # passed from compile function
182
+ rbln_config: Optional[RBLNModelConfig] = None,
183
+ rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
184
+ rbln_submodules: List["RBLNBaseModel"] = [],
185
+ **kwargs,
186
+ ) -> "RBLNBaseModel":
187
+ if rbln_compiled_models is None:
188
+ model_path_subfolder = cls._load_compiled_model_dir(
189
+ model_id=model_id,
190
+ token=token,
191
+ revision=revision,
192
+ force_download=force_download,
193
+ cache_dir=cache_dir,
194
+ subfolder=subfolder,
195
+ local_files_only=local_files_only,
196
+ )
197
+
198
+ if isinstance(rbln_config, dict):
199
+ rbln_config_as_kwargs = {f"rbln_{key}": value for key, value in rbln_config.items()}
200
+ kwargs.update(rbln_config_as_kwargs)
201
+ rbln_config = None
202
+ elif isinstance(rbln_config, RBLNModelConfig) and rbln_config.rbln_model_cls_name != cls.__name__:
203
+ raise ValueError(
204
+ f"Cannot use the passed rbln_config. Its model class name ({rbln_config.rbln_model_cls_name}) "
205
+ f"does not match the expected model class name ({cls.__name__})."
206
+ )
207
+
208
+ rbln_config, kwargs = RBLNAutoConfig.load(
209
+ model_path_subfolder, passed_rbln_config=rbln_config, kwargs=kwargs, return_unused_kwargs=True
210
+ )
211
+
212
+ if rbln_config.rbln_model_cls_name != cls.__name__:
213
+ raise NameError(
214
+ f"Cannot load the model. The model was originally compiled using "
215
+ f"{rbln_config.rbln_model_cls_name}, but you are trying to load it with {cls.__name__}."
216
+ "Please use the same model class that was used during compilation."
217
+ )
218
+
219
+ if len(cls._rbln_submodules) > 0:
220
+ rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
221
+ else:
222
+ rbln_submodules = []
223
+
224
+ rbln_config.freeze()
225
+
226
+ if config is None:
227
+ if cls.hf_library_name == "transformers":
228
+ config = AutoConfig.from_pretrained(
229
+ model_path_subfolder,
230
+ cache_dir=cache_dir,
231
+ force_download=force_download,
232
+ revision=revision,
233
+ token=token,
234
+ trust_remote_code=trust_remote_code,
235
+ )
236
+ elif cls.hf_library_name == "diffusers":
237
+ # import here to prevent diffusers dependency
238
+ # TODO(jongho): Remove diffusers dependency if use transformers only.
239
+ from diffusers.configuration_utils import ConfigMixin
240
+
241
+ class DummyConfigMixin(ConfigMixin):
242
+ # Just to load config, We need to specify `config_name`
243
+ config_name = "config.json"
244
+
245
+ config = DummyConfigMixin.load_config(
246
+ model_id,
247
+ cache_dir=cache_dir,
248
+ force_download=force_download,
249
+ local_files_only=local_files_only,
250
+ revision=revision,
251
+ token=token,
252
+ subfolder=subfolder,
253
+ )
254
+ config = PretrainedConfig(**config)
255
+
256
+ compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
257
+ rbln_compiled_models = cls._load_compiled_models(model_path_subfolder, compiled_model_names)
258
+
259
+ if subfolder != "":
260
+ model_save_dir = Path(model_path_subfolder).absolute().parent
261
+ else:
262
+ model_save_dir = Path(model_path_subfolder).absolute()
263
+
264
+ return cls._from_compiled_models(
265
+ rbln_compiled_models=rbln_compiled_models,
266
+ rbln_config=rbln_config,
267
+ config=config,
268
+ model_save_dir=model_save_dir,
269
+ subfolder=subfolder,
270
+ rbln_submodules=rbln_submodules,
271
+ **kwargs,
272
+ )
273
+
274
@classmethod
def _from_compiled_models(
    cls,
    rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
    rbln_config: RBLNModelConfig,
    config: "PretrainedConfig",
    model_save_dir: Union[Path, str],
    subfolder: Union[Path, str],
    rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
    **kwargs,
):
    """Instantiate the model from already-compiled RBLN models.

    Args:
        rbln_compiled_models: Mapping from compiled model name to compiled model. It must
            contain an entry for every ``compiled_model_name`` in ``rbln_config.compile_cfgs``.
        rbln_config: RBLN configuration. Its ``compile_cfgs`` fix the order of the compiled
            models, and ``create_runtimes`` decides whether runtimes are created.
        config: HuggingFace model config, forwarded to the constructor.
        model_save_dir: Directory the model was loaded from. A ``str`` is converted to ``Path``.
        subfolder: Subfolder inside ``model_save_dir``, forwarded to the constructor.
        rbln_submodules: Already-built RBLN submodules. Defaults to a new empty list per call.
        **kwargs: Forwarded unchanged to the constructor.

    Returns:
        A new instance of ``cls``.

    Raises:
        KeyError: If a compiled model named in ``rbln_config.compile_cfgs`` is missing from
            ``rbln_compiled_models``.
        rebel.core.exception.RBLNRuntimeError: If runtime creation fails. It is re-raised
            with troubleshooting hints and chained to the original error.
    """
    # Build a fresh list per call. A literal `[]` default is a single object shared by
    # every call, so every instance would keep a reference to the same list.
    if rbln_submodules is None:
        rbln_submodules = []

    if isinstance(model_save_dir, str):
        model_save_dir = Path(model_save_dir)

    # FIXME:: Should we convert it?
    # Reorder the compiled models so they follow `rbln_config.compile_cfgs`.
    compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
    ordered_compiled_models = [rbln_compiled_models[name] for name in compiled_model_names]

    # Create runtimes only if `create_runtimes` is enabled; otherwise use an
    # `UnavailableRuntime` object in their place.
    try:
        models = (
            cls._create_runtimes(ordered_compiled_models, rbln_config)
            if rbln_config.create_runtimes
            else UnavailableRuntime()
        )

    except rebel.core.exception.RBLNRuntimeError as e:
        error_msg = (
            f"\nFailed to create RBLN runtime: {str(e)}\n\n"
            f"If you only need to compile the model without loading it to NPU, you can use:\n"
            f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
            f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
            f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
            f"Make sure your NPU is properly installed and operational."
        )
        raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

    return cls(
        models,
        config,
        rbln_config,
        model_save_dir=model_save_dir,
        subfolder=subfolder,
        rbln_compiled_models=ordered_compiled_models,
        rbln_submodules=rbln_submodules,
        **kwargs,
    )
321
+
322
@classmethod
def _export(cls, model_id: Union[str, Path], **kwargs) -> "RBLNBaseModel":
    """Export path of ``from_pretrained``: load the PyTorch model and hand it to ``from_model``.

    Args:
        model_id: HuggingFace model id or local path of the original (uncompiled) model.
        **kwargs: Loading and compilation options.
            - ``model_save_dir`` is removed and passed to ``from_model`` explicitly.
            - ``subfolder`` (default ``""``) is read for preprocessor loading but stays in ``kwargs``.
            - RBLN options are extracted by ``prepare_rbln_config``.

    Returns:
        The instance built by ``from_model``.
    """
    target_subfolder = kwargs.get("subfolder", "")
    save_dir = kwargs.pop("model_save_dir", None)

    resolved_config, remaining = cls.prepare_rbln_config(**kwargs)

    torch_model: "PreTrainedModel" = cls.get_pytorch_model(
        model_id=model_id, rbln_config=resolved_config, **remaining
    )
    loaded_preprocessors = maybe_load_preprocessors(model_id, subfolder=target_subfolder)
    return cls.from_model(
        torch_model,
        preprocessors=loaded_preprocessors,
        model_save_dir=save_dir,
        rbln_config=resolved_config,
        **remaining,
    )
334
+
335
@classmethod
def prepare_rbln_config(
    cls, rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
    """Split RBLN options out of ``kwargs`` and build this class's RBLN config.

    Delegates to ``initialize_from_kwargs`` on the class returned by ``get_rbln_config_class``.

    Returns:
        A ``(rbln_config, remaining_kwargs)`` pair.
    """
    resolved, leftover = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
    return resolved, leftover
344
+
345
+ @classmethod
346
+ def _is_compiled(
347
+ cls,
348
+ model_id: Union[str, Path],
349
+ token: Optional[Union[bool, str]] = None,
350
+ revision: Optional[str] = None,
351
+ force_download: bool = False,
352
+ cache_dir: Optional[str] = None,
353
+ subfolder: str = "",
354
+ local_files_only: bool = False,
355
+ ) -> bool:
356
+ # Check if the model is already compiled.
357
+ try:
358
+ cls._load_compiled_model_dir(
359
+ model_id=model_id,
360
+ token=token,
361
+ revision=revision,
362
+ force_download=force_download,
363
+ cache_dir=cache_dir,
364
+ subfolder=subfolder,
365
+ local_files_only=local_files_only,
366
+ )
367
+ return True
368
+ except (FileNotFoundError, KeyError):
369
+ return False
370
+
371
@classmethod
def from_pretrained(
    cls: Type["RBLNBaseModel"],
    model_id: Union[str, Path],
    export: Optional[bool] = None,
    rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
    **kwargs: Any,
) -> "RBLNBaseModel":
    """
    Load a model for RBLN NPUs, compiling it first when needed.

    This is the HuggingFace-style ``from_pretrained`` entry point. ``model_id`` may name either:
    - an original HuggingFace checkpoint (hub id or local path), or
    - a model previously compiled with the RBLN Compiler.

    Args:
        model_id (Union[str, Path]): Hub id or local path of the model to load.
            ``Path`` objects are converted to POSIX strings.
        export (Optional[bool]): Whether to compile the model. When ``None``, the model is
            compiled only if no compiled model files are found for ``model_id``.
        rbln_config (Optional[Union[Dict, RBLNModelConfig]]): RBLN compilation and runtime
            options. Pass a dict or an instance of the model's configuration class
            (e.g. ``RBLNLlamaForCausalLMConfig`` for Llama models). See that class's
            documentation for the available options.
        kwargs: Extra keyword arguments. Arguments prefixed with ``rbln_`` go to the RBLN
            config; the rest go to the HuggingFace library. ``token``, ``revision``,
            ``force_download``, ``cache_dir``, ``subfolder`` and ``local_files_only`` are
            also used when probing for compiled files.

    Returns:
        (RBLNModel): An RBLN model instance ready for inference on RBLN NPU devices.
    """
    if isinstance(model_id, Path):
        model_id = model_id.as_posix()

    if export is None:
        already_compiled = cls._is_compiled(
            model_id=model_id,
            token=kwargs.get("token"),
            revision=kwargs.get("revision"),
            force_download=kwargs.get("force_download", False),
            cache_dir=kwargs.get("cache_dir"),
            subfolder=kwargs.get("subfolder", ""),
            local_files_only=kwargs.get("local_files_only", False),
        )
        export = not already_compiled

    if export:
        return cls._export(model_id=model_id, **kwargs, rbln_config=rbln_config)
    return cls._from_pretrained(model_id=model_id, **kwargs, rbln_config=rbln_config)
413
+
414
@classmethod
def compile(
    cls,
    model,
    rbln_compile_config: RBLNCompileConfig,
    create_runtimes: bool,
    device: Union[int, List[int]],
    **kwargs,
):
    """Compile ``model`` with ``rebel.compile_from_torch``.

    When ``create_runtimes`` is true, the tensor-parallel size, device and NPU settings are
    checked first with ``tp_and_devices_are_ok``. A truthy result from that check is treated
    as an error description and raised as ``ValueError``.

    Args:
        model: The module to compile.
        rbln_compile_config: Supplies ``input_info``, ``npu`` and ``tensor_parallel_size``.
        create_runtimes: Whether runtimes will be created, which triggers the device check.
        device: Device index or indices to validate against.
        **kwargs: Forwarded to ``rebel.compile_from_torch``.

    Returns:
        The compiled model returned by ``rebel.compile_from_torch``.
    """
    if create_runtimes:
        device_error = tp_and_devices_are_ok(
            tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
            device=device,
            npu=rbln_compile_config.npu,
        )
        if device_error:
            raise ValueError(device_error)

    return rebel.compile_from_torch(
        model,
        input_info=rbln_compile_config.input_info,
        npu=rbln_compile_config.npu,
        tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
        **kwargs,
    )
440
+
441
@classmethod
def update_rbln_config(
    cls,
    preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
    model: "PreTrainedModel",
    model_config: "PretrainedConfig",
    rbln_config: RBLNModelConfig,
) -> RBLNModelConfig:
    """Finalize ``rbln_config`` for ``model`` and freeze it.

    Steps:
    - Copy ``model.dtype`` into ``rbln_config.torch_dtype``.
    - Reject non-float32 dtypes unless the class sets ``_supports_non_fp32``.
    - Apply the class-specific ``_update_rbln_config`` hook.
    - Freeze the result.

    Raises:
        NotImplementedError: If the dtype is not float32 and the class only supports float32.
        NameError: If the resulting config's ``rbln_model_cls_name`` differs from this
            class's name.
    """
    rbln_config.torch_dtype = model.dtype
    fp32_only = not cls._supports_non_fp32
    if fp32_only and rbln_config.torch_dtype != torch.float32:
        raise NotImplementedError(
            f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
        )

    finalized = cls._update_rbln_config(
        preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
    )
    finalized.freeze()

    if finalized.rbln_model_cls_name == cls.__name__:
        return finalized
    raise NameError(
        f"Cannot get the rbln config. {cls.__name__} is not the same as {finalized.rbln_model_cls_name}. "
        "This is an internal error. Please report it to the developers."
    )
464
+
465
@classmethod
def get_hf_class(cls):
    """Return the HuggingFace class this RBLN class wraps, caching it on the class.

    How the name is resolved:
    - The first four characters (the ``RBLN`` prefix) are dropped from ``cls.__name__``.
      This assumes the name starts with ``RBLN``; the original strips four characters
      unconditionally.
    - The remaining name is looked up in the module named by ``cls.hf_library_name``.
      For example, ``RBLNLlamaForCausalLM`` resolves to ``LlamaForCausalLM``.

    Caching:
    - A missing attribute yields ``None``, and a cached ``None`` is retried on the next call.
    - The cache is per class. A subclass without its own ``_hf_class`` entry resolves its own name.
    """
    own_attrs = vars(cls)
    needs_lookup = "_hf_class" not in own_attrs or cls._hf_class is None
    if needs_lookup:
        original_name = cls.__name__[len("RBLN"):]
        hf_module = importlib.import_module(cls.hf_library_name)
        cls._hf_class = getattr(hf_module, original_name, None)
    return cls._hf_class
479
+
480
@classmethod
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
    """Return the RBLN config class for this model class, caching it on the class.

    The config class name is this class's name plus a ``Config`` suffix. It is resolved by the
    module-level ``get_rbln_config_class`` helper. A missing or ``None`` cache entry triggers a
    new lookup.
    """
    cache_missing = "_rbln_config_class" not in vars(cls) or cls._rbln_config_class is None
    if cache_missing:
        cls._rbln_config_class = get_rbln_config_class(f"{cls.__name__}Config")
    return cls._rbln_config_class
487
+
488
def can_generate(self):
    """Report whether this model supports generation. The base implementation always says no."""
    supports_generation = False
    return supports_generation
490
+
491
def to(self, *args, **kwargs):
    """No-op stand-in for ``torch.nn.Module.to``.

    All arguments are ignored and the model itself is returned, so ``model.to(...)`` chains
    keep working.
    """
    unchanged = self
    return unchanged
493
+
494
def parameters(self):
    """Yield a single placeholder tensor, mimicking ``torch.nn.Module.parameters()``.

    This exists for code that calls ``next(model.parameters())`` to infer a device or dtype.
    The one tensor yielded is ``[1.0]`` on CPU with ``self.dtype``.

    Warning:
        These are NOT the parameters used by the RBLN runtime. Code that iterates over all
        model parameters will not work as expected.
    """
    cpu = torch.device("cpu")
    placeholder = torch.tensor([1.0], dtype=self.dtype, device=cpu)
    yield placeholder
505
+
506
+ def __call__(self, *args, **kwargs):
507
+ return self.forward(*args, **kwargs)
508
+
509
+ def __repr__(self):
510
+ has_submodules = len(self.rbln_submodules) > 0
511
+ repr_str: str = f"<{self.__class__.__name__}>\n"
512
+ repr_str += f"- Total {len(self.model)} Runtimes"
513
+ repr_str += f" and {len(self.rbln_submodules)} Submodules\n" if has_submodules else "\n"
514
+ repr_str += "[Runtimes]\n"
515
+ repr_str += "\n".join([repr(model) for model in self.model])
516
+ repr_str += "\n"
517
+
518
+ if has_submodules > 0:
519
+ for i, submodule in enumerate(self.rbln_submodules):
520
+ repr_str += f"[Submodules {i} : {self._rbln_submodules[i]['name']}]\n"
521
+ repr_str += repr(submodule) + "\n"
522
+
523
+ return repr_str
524
+
525
+ def __post_init__(self, **kwargs):
526
+ pass
527
+
528
def save_pretrained(
    self,
    save_directory: Union[str, Path],
    push_to_hub: bool = False,
    **kwargs,
):
    """
    Saves a model and its configuration file to a directory, so that it can be re-loaded using the
    [`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.

    How the save works:
    - The model's current directory (``model_save_dir / subfolder``) is copied into a
      sibling ``<save_directory>.tmp`` directory.
    - If ``save_directory`` does not exist, the staging directory is renamed to it.
    - Otherwise the staged entries are merged into ``save_directory``, and entries with the
      same name are replaced.

    If ``save_directory`` is an existing file, an error is logged and nothing is saved.

    Args:
        save_directory (Union[str, Path]):
            Directory where to save the model file.
        push_to_hub (bool):
            Whether or not to push your model to the HuggingFace model hub after saving it.
        kwargs:
            Only used when ``push_to_hub`` is True. It must contain ``repo_id``; the remaining
            entries are forwarded to the parent class's ``push_to_hub``.

    Returns:
        The result of the parent ``push_to_hub`` when ``push_to_hub`` is True, otherwise ``None``.

    Raises:
        FileNotFoundError: If the model's own directory does not exist or is not a directory.
        FileExistsError: If ``save_directory`` resolves to the model's own directory.
        ValueError: If ``push_to_hub`` is True and no ``repo_id`` is given.
    """
    # A path that points at an existing file is rejected with a logged error, not an exception.
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    # Normalize paths to handle relative paths and symlinks
    # NOTE(review): `subfolder` is appended after `resolve()`, so symlinks inside the
    # subfolder part are not resolved for the equality check below.
    real_save_dir = Path(self.model_save_dir).resolve() / self.subfolder
    save_directory_path = Path(save_directory).resolve()

    if not os.path.exists(real_save_dir) or not os.path.isdir(real_save_dir):
        raise FileNotFoundError(
            f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
            f"Cannot save to the specified destination '{save_directory}'. "
            f"Please ensure the model directory exists and you have the necessary permissions to access it."
        )

    # The config is written into the model's *source* directory, so the copytree below
    # includes it.
    # NOTE(review): this runs before the same-directory check, so a save that is then
    # rejected still rewrites the config in the source directory. Confirm this is intended.
    if isinstance(self.config, PretrainedConfig):
        self.config.save_pretrained(real_save_dir)

    # Refuse to save onto the model's own source directory.
    if save_directory_path == real_save_dir:
        raise FileExistsError(
            f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
        )

    # Create a temporary directory with normalized path
    # Staging into a sibling directory means a failed copy never touches the target.
    tmp_dir = str(save_directory_path) + ".tmp"
    try:
        # Remove temporary directory if it exists from a previous failed attempt
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)

        # First copy everything to a temporary directory
        shutil.copytree(real_save_dir, tmp_dir)

        # If everything succeeded, move files to target directory
        if os.path.exists(save_directory_path):
            # Merge files from tmp_dir into existing directory
            def _merge_dir(src_root: str, dst_root: str):
                # Recursively move each entry of `src_root` into `dst_root`:
                # - files and symlinks replace whatever is at the destination;
                #   a real directory there is deleted first;
                # - directories are merged recursively; a file or symlink at the
                #   destination is removed first.
                for name in os.listdir(src_root):
                    src_item = os.path.join(src_root, name)
                    dst_item = os.path.join(dst_root, name)

                    if os.path.islink(src_item) or os.path.isfile(src_item):
                        os.makedirs(os.path.dirname(dst_item), exist_ok=True)
                        if os.path.isdir(dst_item) and not os.path.islink(dst_item):
                            shutil.rmtree(dst_item)
                        os.replace(src_item, dst_item)
                    elif os.path.isdir(src_item):
                        if os.path.islink(dst_item) or os.path.isfile(dst_item):
                            os.remove(dst_item)
                        os.makedirs(dst_item, exist_ok=True)
                        _merge_dir(src_item, dst_item)
                    else:
                        # Fallback for special file types
                        os.replace(src_item, dst_item)

            _merge_dir(tmp_dir, str(save_directory_path))

            # Remove the temporary directory tree after merge
            shutil.rmtree(tmp_dir)
        else:
            # If target doesn't exist, just rename tmp_dir to target
            os.rename(tmp_dir, save_directory_path)

    except Exception as e:
        # Clean up the temporary directory if anything fails.
        # A failure part-way through the merge can still leave `save_directory`
        # partially updated.
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        raise e  # Re-raise the exception after cleanup

    if push_to_hub:
        repo_id = kwargs.pop("repo_id", None)
        if repo_id is None:
            raise ValueError("`repo_id` must be provided to push the model to the HuggingFace model hub.")
        return super().push_to_hub(repo_id=repo_id, **kwargs)
619
+
620
+ @staticmethod
621
+ def _raise_missing_compiled_file_error(missing_files: List[str]):
622
+ # Raises a KeyError with a message indicating missing compiled model files.
623
+
624
+ if len(missing_files) == 1:
625
+ message = f"The rbln model folder is missing the required '{missing_files[0]}.rbln' file. "
626
+ else:
627
+ files_str = ", ".join([f"'{f}.rbln'" for f in missing_files])
628
+ message = (
629
+ "The rbln model folder is missing required files. "
630
+ f"Ensure that {files_str} files are present in the folder. "
631
+ )
632
+ message += (
633
+ "These files are necessary for loading the rbln model. "
634
+ "If these files are missing, please recompile the model using the latest optimum-rbln "
635
+ "and ensure the compilation completes successfully."
636
+ )
637
+ raise KeyError(message)
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .attn import *
16
+ from .flash_attn import *
17
+ from .kv_cache_update import *
18
+ from .linear import linear
19
+ from .sliding_window_attn import *