optimum-rbln 0.9.3.post1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
The hunk below corresponds to optimum/rbln/configuration_utils.py, the only entry in the file list above with +968 lines.

@@ -0,0 +1,968 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ import inspect
+ import json
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runtime_checkable
+
+ import numpy as np
+ import torch
+ from packaging.version import Version
+
+ from .__version__ import __version__
+ from .utils.deprecation import warn_deprecated_npu
+ from .utils.logging import get_logger
+ from .utils.runtime_utils import ContextRblnConfig
+
+
+ logger = get_logger(__name__)
+
+
+ DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
+ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
+
+
+ @runtime_checkable
+ class RBLNSerializableConfigProtocol(Protocol):
+     def _prepare_for_serialization(self) -> Dict[str, Any]: ...
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
+
+
+ @dataclass
+ class RBLNCompileConfig:
+     """
+     Configuration for RBLN compilation.
+
+     Attributes:
+         compiled_model_name (str): Name of the compiled model.
+         input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
+         npu (Optional[str]): NPU configuration.
+         tensor_parallel_size (Optional[int]): Size for tensor parallelism.
+     """
+
+     compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
+     input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
+     npu: Optional[str] = None
+     tensor_parallel_size: Optional[int] = None
+
+     @staticmethod
+     def normalize_dtype(dtype: Union[str, torch.dtype, np.dtype]) -> str:
+         """
+         Convert framework-specific dtype to string representation.
+         i.e. torch.float32 -> "float32"
+
+         Args:
+             dtype: The input dtype (can be string, torch dtype, or numpy dtype).
+
+         Returns:
+             The normalized string representation of the dtype.
+         """
+         if isinstance(dtype, str):
+             return dtype
+         else:
+             dtype: str = repr(dtype).split(".")[-1]
+             if dtype.endswith("'>"): # numpy
+                 dtype = dtype[:-2]
+             return dtype
+
+     @property
+     def is_multiple_input_info(self) -> bool:
+         def is_valid_input_info(input_info):
+             if not isinstance(input_info, list):
+                 return False
+             return all(
+                 isinstance(item, (tuple, list))
+                 and len(item) == 3
+                 and isinstance(item[0], str) # name
+                 and isinstance(item[1], (tuple, list)) # shape
+                 and all(isinstance(x, int) for x in item[1])
+                 and isinstance(item[2], str) # dtype
+                 for item in input_info
+             )
+
+         if isinstance(self.input_info, list):
+             return all(is_valid_input_info(info) for info in self.input_info)
+         return False
+
+     def __post_init__(self):
+         def normalize_input_info(input_info):
+             return [(i[0], i[1], RBLNCompileConfig.normalize_dtype(i[2]) or "float32") for i in input_info]
+
+         if self.is_multiple_input_info:
+             self.input_info = [normalize_input_info(info) for info in self.input_info]
+         else:
+             self.input_info = normalize_input_info(self.input_info)
+
+     def update(self, kwargs: Dict[str, Any]):
+         self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
+         self.input_info = kwargs.get("input_info", self.input_info)
+         self.npu = kwargs.get("npu", self.npu)
+         self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
+         return self
+
+     def get_dummy_inputs(
+         self, fill=0, static_tensors: Dict[str, torch.Tensor] = {}, meta_tensor_names: List[str] = []
+     ):
+         dummy = []
+         for name, shape, dtype in self.input_info:
+             if name in static_tensors:
+                 tensor = static_tensors[name]
+                 if shape != list(tensor.shape):
+                     raise RuntimeError(f"Different shape for dummy inputs. ({shape} != {list(tensor.shape)})")
+                 if getattr(torch, dtype) != tensor.dtype:
+                     raise RuntimeError(f"Different dtype for dummy inputs ({dtype} != {tensor.dtype})")
+                 dummy.append(tensor)
+             else:
+                 if name in meta_tensor_names:
+                     device = "meta"
+                 else:
+                     device = "cpu"
+
+                 dummy.append(
+                     torch.fill(torch.empty(*shape, dtype=getattr(torch, dtype), device=torch.device(device)), fill)
+                     if len(shape) > 0
+                     else torch.tensor(fill, dtype=getattr(torch, dtype), device=torch.device(device))
+                 )
+         return tuple(dummy)
+
+     def asdict(self):
+         return asdict(self)
+
+
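Editorial note (not part of the wheel contents): a minimal sketch of how the `RBLNCompileConfig` defined above might be constructed and exercised. The tensor names, shapes, and the import path `optimum.rbln.configuration_utils` are assumptions inferred from the file list, not taken from the package documentation.

```python
import torch

from optimum.rbln.configuration_utils import RBLNCompileConfig  # assumed import path

# Hypothetical input_info: (name, shape, dtype) triples; __post_init__ normalizes the dtypes.
compile_cfg = RBLNCompileConfig(
    compiled_model_name="encoder",
    input_info=[
        ("input_ids", [1, 512], "int64"),
        ("attention_mask", [1, 512], torch.float32),  # normalize_dtype() turns this into "float32"
    ],
)

# get_dummy_inputs() materializes fill-valued tensors matching input_info.
dummy_inputs = compile_cfg.get_dummy_inputs(fill=0)
print([tuple(t.shape) for t in dummy_inputs])
```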
+ RUNTIME_KEYWORDS = ["create_runtimes", "device", "device_map", "activate_profiler", "timeout"]
+ CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}
+
+
+ def get_rbln_config_class(rbln_config_class_name: str) -> Type["RBLNModelConfig"]:
+     cls = getattr(importlib.import_module("optimum.rbln"), rbln_config_class_name, None)
+     if cls is None:
+         if rbln_config_class_name in CONFIG_MAPPING:
+             cls = CONFIG_MAPPING[rbln_config_class_name]
+         else:
+             raise ValueError(f"Configuration for {rbln_config_class_name} not found.")
+     return cls
+
+
+ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
+     path = Path(path)
+     if path.is_dir():
+         path = path / "rbln_config.json"
+
+     with open(path, "r") as jsonf:
+         config_file = json.load(jsonf)
+
+     if "_meta" in config_file:
+         is_legacy_rbln_config = True
+
+         if is_legacy_rbln_config:
+             raise RuntimeError(
+                 f"`{path}` is an old version. Please recompile the model to get the latest config file."
+             )
+
+     cls_name = config_file["cls_name"]
+     cls = get_rbln_config_class(cls_name)
+     return cls, config_file
+
+
+ class RBLNAutoConfig:
+     """
+     Resolver and factory for RBLN model configurations.
+
+     This class selects the concrete `RBLNModelConfig` subclass, validates the
+     provided data, and returns a frozen configuration object that serves as the
+     single source of truth during export and load. It does not define the schema
+     or control model behavior.
+     """
+
+     def __new__(cls, **kwargs):
+         cls_name = kwargs.get("cls_name")
+         if cls_name is None:
+             raise ValueError("`cls_name` is required.")
+         cls = get_rbln_config_class(cls_name)
+         return cls(**kwargs)
+
+     @staticmethod
+     def load_from_dict(config_dict: Dict[str, Any]) -> "RBLNModelConfig":
+         """
+         Build a `RBLNModelConfig` from a plain dictionary.
+
+         The dictionary must contain `cls_name`, which identifies the concrete
+         configuration class to instantiate. All other keys are forwarded to the
+         target class initializer. This method does not mutate `config_dict`.
+
+         Args:
+             config_dict: Mapping typically created by `json.load` or `yaml.safe_load`.
+                 For example, the parsed contents of `rbln_config.json`.
+
+         Returns:
+             RBLNModelConfig: A configuration instance. The specific subclass is
+                 selected by `config_dict["cls_name"]`.
+
+         Raises:
+             ValueError: If `cls_name` is missing.
+             Exception: Any error raised by the target config class during init.
+
+         Examples:
+             >>> data = {
+             ...     "cls_name": "RBLNLlamaForCausalLMConfig",
+             ...     "create_runtimes": False,
+             ...     "tensor_parallel_size": 4
+             ... }
+             >>> cfg = RBLNAutoConfig.load_from_dict(data)
+         """
+         cls_name = config_dict.get("cls_name")
+         if cls_name is None:
+             raise ValueError("`cls_name` is required.")
+         cls = get_rbln_config_class(cls_name)
+         return cls(**config_dict)
+
+     @staticmethod
+     def register(config: Type["RBLNModelConfig"], exist_ok=False):
+         """
+         Register a new configuration for this class.
+
+         Args:
+             config (RBLNModelConfig): The config to register.
+             exist_ok (bool): Whether to allow registering an already registered model.
+         """
+         if not issubclass(config, RBLNModelConfig):
+             raise ValueError("`config` must be a subclass of RBLNModelConfig.")
+
+         native_cls = getattr(importlib.import_module("optimum.rbln"), config.__name__, None)
+         if config.__name__ in CONFIG_MAPPING or native_cls is not None:
+             if not exist_ok:
+                 raise ValueError(f"Configuration for {config.__name__} already registered.")
+
+         CONFIG_MAPPING[config.__name__] = config
+
+     @staticmethod
+     def load(
+         path: str,
+         passed_rbln_config: Optional["RBLNModelConfig"] = None,
+         kwargs: Optional[Dict[str, Any]] = {},
+         return_unused_kwargs: bool = False,
+     ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
+         """
+         Load RBLNModelConfig from a path.
+         Class name is automatically inferred from the `rbln_config.json` file.
+
+         Args:
+             path (str): Path to the RBLNModelConfig.
+             passed_rbln_config (Optional["RBLNModelConfig"]): RBLNModelConfig to pass its runtime options.
+
+         Returns:
+             RBLNModelConfig: The loaded RBLNModelConfig.
+         """
+         cls, config_file = load_config(path)
+
+         rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
+         rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
+         rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
+
+         rbln_kwargs = {
+             key[5:]: kwargs.pop(key)
+             for key in rbln_keys
+             if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
+         }
+
+         # Process submodule's rbln_config
+         for submodule in cls.submodules:
+             if submodule not in config_file:
+                 raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
+             submodule_config = config_file[submodule]
+             submodule_config.update(rbln_submodule_kwargs.pop(submodule, {}))
+             config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
+
+         if passed_rbln_config is not None:
+             config_file.update(passed_rbln_config._runtime_options)
+             # TODO(jongho): Reject if the passed_rbln_config has different attributes from the config_file
+
+         config_file.update(rbln_runtime_kwargs)
+
+         rbln_config = cls(**config_file)
+
+         if len(rbln_kwargs) > 0:
+             for key, value in rbln_kwargs.items():
+                 if getattr(rbln_config, key) != value:
+                     raise ValueError(
+                         f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
+                         f"Since the value is already set to {getattr(rbln_config, key)}"
+                     )
+
+         if return_unused_kwargs:
+             return cls(**config_file), kwargs
+         else:
+             return cls(**config_file)
+
+
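Editorial note: a hedged sketch of the `RBLNAutoConfig.load()` flow shown above. The directory path and the `rbln_device` override are hypothetical; the import path is assumed from the file list.

```python
from optimum.rbln.configuration_utils import RBLNAutoConfig  # assumed import path

# Hypothetical directory produced by RBLNModelConfig.save(); it must contain an
# rbln_config.json whose "cls_name" selects the concrete config class.
rbln_config = RBLNAutoConfig.load(
    "./compiled_model",
    kwargs={"rbln_device": 0},  # "rbln_"-prefixed runtime keywords are split out and applied
)
print(type(rbln_config).__name__, rbln_config.device)
```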
+ class RBLNModelConfig(RBLNSerializableConfigProtocol):
+     """Base configuration class for RBLN models that handles compilation settings, runtime options, and submodules.
+
+     This class provides functionality for:
+
+     1. Managing compilation configurations for RBLN devices
+     2. Configuring runtime behavior such as device placement
+     3. Handling nested configuration objects for complex model architectures
+     4. Serializing and deserializing configurations
+
+     Examples:
+         Using with RBLNModel.from_pretrained():
+         ```python
+         from optimum.rbln import RBLNResNetForImageClassification
+
+         # Method 1: Using rbln_ prefixed arguments (recommended for simple cases)
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True, # Compile the model
+             rbln_image_size=224,
+             rbln_batch_size=16,
+             rbln_create_runtimes=True,
+             rbln_device=0
+         )
+
+         # Method 2: Using a config dictionary
+         rbln_config_dict = {
+             "image_size": 224,
+             "batch_size": 16,
+             "create_runtimes": True
+         }
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True,
+             rbln_config=rbln_config_dict
+         )
+
+         # Method 3: Using a RBLNModelConfig instance
+         from optimum.rbln import RBLNResNetForImageClassificationConfig
+
+         config = RBLNResNetForImageClassificationConfig(
+             image_size=224,
+             batch_size=16,
+             create_runtimes=True
+         )
+
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True,
+             rbln_config=config
+         )
+
+         # Method 4: Combining a config object with override parameters
+         # (rbln_ prefixed parameters take precedence over rbln_config values)
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True,
+             rbln_config=config,
+             rbln_image_size=320, # This overrides the value in config
+             rbln_device=1 # This sets a new value
+         )
+         ```
+
+
+         Save and load configuration:
+         ```python
+         # Save to disk
+         config.save("/path/to/model")
+
+         # Using AutoConfig
+         loaded_config = RBLNAutoConfig.load("/path/to/model")
+         ```
+
+
+         Converting between configuration formats:
+         ```python
+         # Converting a dictionary to a config instance
+         config_dict = {
+             "image_size": 224,
+             "batch_size": 8,
+             "create_runtimes": True
+         }
+         config = RBLNResNetForImageClassificationConfig(**config_dict)
+         ```
+
+         Configuration for language models:
+         ```python
+         from optimum.rbln import RBLNLlamaForCausalLMConfig, RBLNCompileConfig
+
+         # Configure a LLaMA for RBLN
+         config = RBLNLlamaForCausalLMConfig(
+             max_seq_len=4096,
+             device=[0, 1, 2, 3],
+             tensor_parallel_size=4 # For multi-NPU parallel inference
+         )
+         ```
+
+         Working with models that have submodules:
+         ```python
+         from optimum.rbln import RBLNLlavaNextForConditionalGeneration
+
+         # Configuring a model with submodules
+         # LlavaNext has a vision_tower and a language_model submodule
+         model = RBLNLlavaNextForConditionalGeneration.from_pretrained(
+             "llava-hf/llava-v1.6-mistral-7b-hf",
+             export=True,
+             rbln_config={
+                 # Main model's (projector, which is not a submodule) configuration
+                 "create_runtimes": True,
+                 "device": 0,
+
+                 # Submodule configurations as nested dictionaries
+                 "vision_tower": {
+                     "image_size": 336,
+                 },
+                 "language_model": {
+                     "tensor_parallel_size": 4, # Distribute across 4 NPUs
+                     "max_seq_len": 8192,
+                     "use_inputs_embeds": True,
+                     "batch_size": 1,
+                 },
+             },
+         )
+         ```
+
+         Advanced multi-device deployment with tensor parallelism:
+         ```python
+         from optimum.rbln import RBLNLlamaForCausalLMConfig
+
+         # Setup a complex multi-device configuration for large language models
+         llm_config = RBLNLlamaForCausalLMConfig(
+             # Split model across 8 NPUs
+             tensor_parallel_size=8,
+
+             # Runtime options
+             device=[8, 9, 10, 11, 12, 13, 14, 15],
+             create_runtimes=True,
+             activate_profiler=True, # Enable profiling for performance analysis
+
+             # Model-specific parameters for the LLM
+             max_seq_len=131072,
+             batch_size=4,
+             attn_impl="flash_attn",
+         )
+         ```
+
+         Compilation without runtime creation (create_runtimes=False):
+         ```python
+         from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+
+         # Compile a model on a machine without NPU or for later use
+         config = RBLNLlamaForCausalLMConfig(
+             create_runtimes=False, # Compile only, don't create runtime
+             npu="RBLN-CA25", # Specify target NPU for compilation
+             max_seq_len=4096,
+             tensor_parallel_size=4,
+             batch_size=1
+         )
+
+         # Export the model - will compile but not create runtimes
+         model = RBLNLlamaForCausalLM.from_pretrained(
+             "meta-llama/Llama-2-7b-hf",
+             export=True,
+             rbln_config=config
+         )
+
+         # Save the compiled model for later use on NPU
+         model.save_pretrained("./compiled_llama_model")
+
+         # Later, on a machine with the target NPU
+         inference_model = RBLNLlamaForCausalLM.from_pretrained(
+             "./compiled_llama_model",
+             rbln_create_runtimes=True, # Now create runtimes (Optional)
+         )
+         ```
+
+         Two-stage workflow with separate compilation and runtime:
+         ```python
+         from optimum.rbln import RBLNResNetForImageClassification
+
+         # Stage 1: Model engineer compiles model (can be on any machine)
+         def compile_model():
+             model = RBLNResNetForImageClassification.from_pretrained(
+                 "microsoft/resnet-50",
+                 export=True,
+                 rbln_create_runtimes=False,
+                 rbln_npu="RBLN-CA25",
+                 rbln_image_size=224
+             )
+             model.save_pretrained("./compiled_model")
+             print("Model compiled and saved, ready for deployment")
+
+         # Stage 2: Deployment engineer loads model on NPU
+         def deploy_model():
+             model = RBLNResNetForImageClassification.from_pretrained(
+                 "./compiled_model",
+                 rbln_create_runtimes=True,
+             )
+             print("Model loaded and ready for inference")
+             return model
+         ```
+     """
+
+     non_save_attributes = [
+         "_frozen",
+         "_runtime_options",
+         "torch_dtype",
+         "npu",
+         "tensor_parallel_size",
+         "create_runtimes",
+         "device",
+         "device_map",
+         "activate_profiler",
+         "timeout",
+     ]
+     submodules: List[str] = []
+     subclass_non_save_attributes = []
+     _allow_no_compile_cfgs = False
+
+     def initialize_submodule_config(
+         self,
+         submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
+         force_kwargs: bool = False,
+         **kwargs: Any,
+     ) -> "RBLNModelConfig":
+         if submodule_config is None:
+             submodule_config = {}
+
+         if isinstance(submodule_config, RBLNModelConfig):
+             return submodule_config
+
+         if isinstance(submodule_config, dict):
+             from_predecessor = self._runtime_options.copy()
+             from_predecessor.update(
+                 {
+                     "npu": self.npu,
+                     "tensor_parallel_size": self.tensor_parallel_size,
+                     "optimum_rbln_version": self.optimum_rbln_version,
+                 }
+             )
+             from_predecessor.update(kwargs)
+
+             init_kwargs = from_predecessor
+             init_kwargs.update(submodule_config)
+
+             if force_kwargs:
+                 for key, value in kwargs.items():
+                     if key in init_kwargs:
+                         if init_kwargs[key] != value:
+                             raise ValueError(
+                                 f"Parameter conflict for '{key}': submodule_config has {init_kwargs[key]}, "
+                                 f"but kwargs has {value}. Using kwargs value: {value}"
+                             )
+                     init_kwargs[key] = value
+
+             if "cls_name" in init_kwargs:
+                 config_cls = get_rbln_config_class(init_kwargs["cls_name"])
+             else:
+                 return init_kwargs
+
+             submodule_config = config_cls(**init_kwargs)
+
+         if not isinstance(submodule_config, RBLNModelConfig):
+             raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")
+
+         return submodule_config
+
+     def filter_parameters(self, config_cls: Type["RBLNModelConfig"], parameters: Dict[str, Any]) -> Dict[str, Any]:
+         import importlib
+
+         model_cls_name = config_cls.__name__.replace("Config", "")
+         modeling_module_name = config_cls.__module__.replace("configuration_", "modeling_")
+
+         model_cls = None
+         try:
+             modeling_module = importlib.import_module(modeling_module_name)
+             if hasattr(modeling_module, model_cls_name):
+                 model_cls = getattr(modeling_module, model_cls_name)
+         except ImportError:
+             logger.debug(f"Could not import modeling module: {modeling_module_name}")
+
+         filtered_out_params = set()
+
+         if model_cls is not None:
+             if not getattr(model_cls, "_tp_support", False):
+                 filtered_out_params.add("tensor_parallel_size")
+
+         filtered_params = {}
+         for key, value in parameters.items():
+             if key in filtered_out_params:
+                 logger.debug(
+                     f"Parameter '{key}' filtered out for {config_cls.__name__} (not supported by model flags)."
+                 )
+             else:
+                 filtered_params[key] = value
+
+         return filtered_params
+
+     def __setattr__(self, key, value):
+         if (
+             key != "_attributes_map"
+             and key not in self.non_save_attributes
+             and key not in self.subclass_non_save_attributes
+         ):
+             self._attributes_map[key] = value
+
+         if hasattr(self, "_frozen") and self._frozen:
+             if not hasattr(self, key) or getattr(self, key) != value:
+                 raise RuntimeError(
+                     f"`{self.__class__.__name__}` is frozen. Cannot update or set attribute after freezing."
+                 )
+
+         # If the submodule is a dict, Instantiate the submodule config class
+         if key in self.submodules and isinstance(value, dict) and (cls_name := value.get("cls_name")):
+             rbln_config_cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
+             value = rbln_config_cls(**value)
+
+         # Forbid setting keyword-only arguments
+         # keyword-only arguments should be translated to other attributes, not set directly
+         _keyword_only_args = set()
+         init_signature = inspect.signature(self.__class__.__init__)
+         for param_name, param in init_signature.parameters.items():
+             if param.kind == inspect.Parameter.KEYWORD_ONLY:
+                 _keyword_only_args.add(param_name)
+
+         if key in _keyword_only_args:
+             raise AttributeError(
+                 f"Cannot set attribute '{key}'. This is an internal error. Please report it to the developers."
+             )
+
+         super().__setattr__(key, value)
+
+     def __init__(
+         self,
+         cls_name: Optional[str] = None,
+         create_runtimes: Optional[bool] = None,
+         device: Optional[Union[int, List[int]]] = None,
+         device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
+         activate_profiler: Optional[bool] = None,
+         npu: Optional[str] = None,
+         tensor_parallel_size: Optional[int] = None,
+         timeout: Optional[int] = None,
+         optimum_rbln_version: Optional[str] = None,
+         _torch_dtype: Optional[str] = None,
+         _compile_cfgs: List[RBLNCompileConfig] = [],
+         *,
+         optimize_host_memory: Optional[bool] = None,
+         **kwargs: Any,
+     ):
+         """
+         Initialize a RBLN model configuration with runtime options and compile configurations.
+
+         Args:
+             cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
+             create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True.
+             device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
+             device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
+             activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
+             npu (Optional[str]): The NPU device name to use for compilation.
+             tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
+             timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
+             optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
+             _torch_dtype (Optional[str]): The data type to use for the model.
+             _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
+             kwargs: Additional keyword arguments.
+
+         Raises:
+             ValueError: If unexpected keyword arguments are provided.
+
+
+         """
+         self._attributes_map = {}
+         self._frozen = False
+
+         self.cls_name = cls_name
+         if self.cls_name is None:
+             self.cls_name = self.__class__.__name__
+
+         self._runtime_options = {}
+         self._runtime_options["create_runtimes"] = create_runtimes
+         self._runtime_options["device"] = device
+         self._runtime_options["device_map"] = device_map
+         self._runtime_options["activate_profiler"] = activate_profiler
+         self._runtime_options["timeout"] = timeout
+
+         if optimize_host_memory is not None:
+             logger.warning("`optimize_host_memory` is deprecated and will be removed in future versions.")
+
+         # Automatically pass npu, tensor_parallel_size to compile_cfgs
+         self.npu = npu
+         self.tensor_parallel_size = tensor_parallel_size
+
+         self._torch_dtype = _torch_dtype or "float32"
+         self.optimum_rbln_version = optimum_rbln_version
+         if self.optimum_rbln_version is None:
+             self.optimum_rbln_version = __version__
+
+         self._compile_cfgs: List[RBLNCompileConfig] = _compile_cfgs
+
+         if not isinstance(self._compile_cfgs, list):
+             raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
+         if len(self._compile_cfgs) > 0 and not isinstance(self._compile_cfgs[0], RBLNCompileConfig):
+             self.set_compile_cfgs([RBLNCompileConfig(**cfg) for cfg in self._compile_cfgs])
+
+         if len(kwargs) > 0:
+             if optimum_rbln_version is not None: # loaded from file
+                 if Version(__version__) < Version(optimum_rbln_version):
+                     diff = "newer"
+                 elif Version(__version__) > Version(optimum_rbln_version):
+                     diff = "older"
+                 else:
+                     diff = None
+                 if diff is not None:
+                     raise ValueError(
+                         f"Unexpected arguments: {kwargs.keys()}\n"
+                         f"Maybe you are trying to load a model compiled with {diff} version of optimum-rbln. "
+                         "It is recommended to use the same version to compile and load the model.\n"
+                         f"Current version: {__version__}, Loaded version: {optimum_rbln_version}"
+                     )
+
+             raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
+
+     @property
+     def torch_dtype(self):
+         return getattr(torch, self._torch_dtype)
+
+     @torch_dtype.setter
+     def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
+         if isinstance(torch_dtype, torch.dtype):
+             torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
+
+         self._torch_dtype = torch_dtype
+
+     @property
+     def rbln_model_cls_name(self) -> str:
+         return self.__class__.__name__[:-6]
+
+     @property
+     def rbln_model_cls(self) -> Type:
+         rbln_model_cls = getattr(importlib.import_module("optimum.rbln"), self.rbln_model_cls_name, None)
+         if rbln_model_cls is None:
+             raise ValueError(
+                 f"RBLN model class {self.rbln_model_cls_name} not found. This is an internal error. "
+                 "Please report it to the developers."
+             )
+         return rbln_model_cls
+
+     def _prepare_for_serialization(self) -> Dict[str, Any]:
+         # Prepare the attributes map for serialization by converting nested RBLNModelConfig
+         # objects to their serializable form.
+         serializable_map = {}
+         for key, value in self._attributes_map.items():
+             if isinstance(value, RBLNSerializableConfigProtocol):
+                 # Convert nested RBLNModelConfig to its serializable form
+                 serializable_map[key] = value._prepare_for_serialization()
+             elif key == "_compile_cfgs":
+                 serializable_map[key] = [cfg.asdict() for cfg in value]
+             else:
+                 serializable_map[key] = value
+         return serializable_map
+
+     def __repr__(self):
+         repr_dict = self._prepare_for_serialization()
+         return json.dumps(repr_dict, indent=2)
+
+     @property
+     def compile_cfgs(self):
+         return self._compile_cfgs
+
+     @compile_cfgs.setter
+     def compile_cfgs(self, compile_cfgs: List[RBLNCompileConfig]):
+         raise RuntimeError("`compile_cfgs` cannot be set directly. Please use `set_compile_cfgs` instead.")
+
+     def set_compile_cfgs(self, compile_cfgs: List[RBLNCompileConfig]):
+         if not isinstance(compile_cfgs, list):
+             raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
+         if len(compile_cfgs) == 0:
+             raise ValueError("`compile_cfgs` must contain at least one `RBLNCompileConfig`.")
+         if not isinstance(compile_cfgs[0], RBLNCompileConfig):
+             raise ValueError("`compile_cfgs` must contain only `RBLNCompileConfig`.")
+
+         self._compile_cfgs = compile_cfgs
+         for compile_cfg in self._compile_cfgs:
+             compile_cfg.npu = self.npu
+             compile_cfg.tensor_parallel_size = self.tensor_parallel_size
+
+         target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
+         warn_deprecated_npu(target_npu)
+
+     def freeze(self):
+         if self._frozen:
+             raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
+
+         if (
+             not isinstance(self._compile_cfgs, list)
+             or len(self._compile_cfgs) == 0
+             or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
+         ):
+             if not self._allow_no_compile_cfgs:
+                 raise RuntimeError("`compile_cfgs` must contain at least one `RBLNCompileConfig` before freezing.")
+
+         for submodule_name in self.submodules:
+             submodule_config = getattr(self, submodule_name, None)
+             if not isinstance(submodule_config, RBLNModelConfig):
+                 raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
+
+             if not submodule_config.is_frozen():
+                 raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
+
+         self._frozen = True
+
+     def is_frozen(self):
+         return self._frozen
+
+     def save(self, path: str):
+         if not self._frozen:
+             raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
+
+         # save as json file without runtime attributes
+         path = Path(path)
+         if path.is_dir():
+             path = path / "rbln_config.json"
+
+         with open(path, "w") as jsonf:
+             serializable_data = self._prepare_for_serialization()
+             json.dump(serializable_data, jsonf, indent=2)
+
+     @classmethod
+     def load(cls, path: str, **kwargs: Any) -> "RBLNModelConfig":
+         """
+         Load a RBLNModelConfig from a path.
+
+         Args:
+             path (str): Path to the RBLNModelConfig file or directory containing the config file.
+             kwargs: Additional keyword arguments to override configuration values.
+                 Keys starting with 'rbln_' will have the prefix removed and be used
+                 to update the configuration.
+
+         Returns:
+             RBLNModelConfig: The loaded configuration instance.
+
+         Note:
+             This method loads the configuration from the specified path and applies any
+             provided overrides. If the loaded configuration class doesn't match the expected
+             class, a warning will be logged.
+         """
+         cls_reserved, config_file = load_config(path)
+
+         if cls_reserved != cls:
+             logger.warning(f"Expected {cls.__name__}, but got {cls_reserved.__name__}.")
+
+         rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
+         rbln_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys}
+         config_file.update(rbln_kwargs)
+
+         return cls(**config_file)
+
+     @classmethod
+     def initialize_from_kwargs(
+         cls: Type["RBLNModelConfig"],
+         rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
+         **kwargs: Any,
+     ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
+         # Initialize RBLNModelConfig from kwargs.
+         kwargs_keys = list(kwargs.keys())
+         rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
+
+         if isinstance(rbln_config, dict):
+             rbln_config.update(rbln_kwargs)
+             rbln_config = cls(**rbln_config)
+
+         elif rbln_config is None:
+             rbln_config = cls(**rbln_kwargs)
+
+         elif isinstance(rbln_config, RBLNModelConfig):
+             for key, value in rbln_kwargs.items():
+                 setattr(rbln_config, key, value)
+
+         return rbln_config, kwargs
+
+     def get_default_values_for_original_cls(self, func_name: str, keys: List[str]) -> Dict[str, Any]:
+         # Get default values for original class attributes from RBLNModelConfig.
+         model_cls = self.rbln_model_cls.get_hf_class()
+         func = getattr(model_cls, func_name)
+         func_signature = inspect.signature(func)
+         default_values = {}
+         for key in keys:
+             if key in func_signature.parameters:
+                 default_values[key] = func_signature.parameters[key].default
+             else:
+                 raise ValueError(f"Default value for `{key}` is not set for the model class.")
+         return default_values
+
+     @property
+     def create_runtimes(self):
+         context = ContextRblnConfig.get_current_context()["create_runtimes"]
+         if context is not None:
+             return context
+         elif self._runtime_options["create_runtimes"] is None:
+             return True
+         return self._runtime_options["create_runtimes"]
+
+     @create_runtimes.setter
+     def create_runtimes(self, create_runtimes: bool):
+         self._runtime_options["create_runtimes"] = create_runtimes
+
+     @property
+     def device(self):
+         context = ContextRblnConfig.get_current_context()["device"]
+         if context is not None:
+             return context
+         return self._runtime_options["device"]
+
+     @device.setter
+     def device(self, device: Union[int, List[int]]):
+         self._runtime_options["device"] = device
+
+     @property
+     def device_map(self):
+         context = ContextRblnConfig.get_current_context()["device_map"]
+         if context:
+             return context
+         elif self._runtime_options["device_map"] is None:
+             rbln_device_map = {}
+             device_val = self.device
+             for cfg in self.compile_cfgs:
+                 rbln_device_map[cfg.compiled_model_name] = device_val
+             return rbln_device_map
+         return self._runtime_options["device_map"]
+
+     @device_map.setter
+     def device_map(self, device_map: Dict[str, Union[int, List[int]]]):
+         self._runtime_options["device_map"] = device_map
+
+     @property
+     def activate_profiler(self):
+         context = ContextRblnConfig.get_current_context()["activate_profiler"]
+         if context is not None:
+             return context
+         return self._runtime_options["activate_profiler"]
+
+     @activate_profiler.setter
+     def activate_profiler(self, activate_profiler: bool):
+         self._runtime_options["activate_profiler"] = activate_profiler
+
+     @property
+     def timeout(self):
+         context = ContextRblnConfig.get_current_context()["timeout"]
+         if context is not None:
+             return context
+         return self._runtime_options["timeout"]
+
+     @timeout.setter
+     def timeout(self, timeout: int):
+         self._runtime_options["timeout"] = timeout
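Editorial note: the module above also acts as an extensible registry, since custom configurations can be added at runtime through `RBLNAutoConfig.register()`. A minimal sketch follows; the subclass and its `image_size` field are hypothetical and not part of the package, and the import path is assumed from the file list.

```python
from optimum.rbln.configuration_utils import RBLNAutoConfig, RBLNModelConfig  # assumed import path


class RBLNMyCustomModelConfig(RBLNModelConfig):
    """Hypothetical config subclass for a custom RBLN model."""

    def __init__(self, image_size: int = 224, **kwargs):
        super().__init__(**kwargs)
        self.image_size = image_size  # recorded in _attributes_map via __setattr__


# Make the class resolvable by name for load_from_dict() / load().
RBLNAutoConfig.register(RBLNMyCustomModelConfig)

cfg = RBLNAutoConfig.load_from_dict(
    {"cls_name": "RBLNMyCustomModelConfig", "image_size": 320, "create_runtimes": False}
)
print(cfg.image_size, cfg.create_runtimes)
```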