optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264) hide show
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,589 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import glob
16
+ import os
17
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union
18
+
19
+ import torch
20
+ from huggingface_hub import hf_hub_download, list_repo_files
21
+ from safetensors.torch import load_file
22
+ from torch.nn import Linear, Parameter
23
+ from torch.nn import functional as F
24
+ from transformers import AutoConfig
25
+ from transformers.modeling_utils import get_state_dict_dtype, no_init_weights
26
+
27
+ from ...configuration_utils import RBLNSerializableConfigProtocol
28
+ from ...utils.logging import get_logger
29
+
30
+
31
+ if TYPE_CHECKING:
32
+ from transformers.models.auto.modeling_auto import _BaseAutoModelClass
33
+
34
+ logger = get_logger()
35
+
36
+
37
# Constants
# Names of projection Linear submodules whose weights carry quantization
# (attention q/k/v/o and MLP gate/up/down projections).
# NOTE(review): consumers of this set live later in this module — presumably
# matched against module names during layer replacement; confirm.
QUANTIZED_WEIGHTS = {
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
}
47
+
48
# Common alias sets seen in community checkpoints.
# Maps each canonical scale name to the alternative spellings observed in the
# wild; `_matches_any_alias` compares a checkpoint key's last dotted component
# against these lists.
VARIANT_ALIASES: Dict[str, List[str]] = {
    "weight_scale": ["weight_scale", "scales", "w_scale", "scale"],
    "input_scale": ["input_scale", "act_scale", "activation_scale", "a_scale"],
    "kv_scale": ["kv_scale", "kv_scales"],
    "k_scale": ["k_scale", "k_scales"],
    "v_scale": ["v_scale", "v_scales"],
}
56
+
57
+
58
class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
    """Quantization scheme for RBLN compilation: a container format plus the
    dtypes used for weights, activations, and kv-caches.

    Unspecified dtypes default to ``fp16`` (unquantized); a config where both
    weights and activations are fp16 is rejected as meaningless.
    """

    SUPPORTED_FORMATS = ["rbln"]
    SUPPORTED_WEIGHTS = ["int4", "int8", "fp8", "fp16"]
    SUPPORTED_ACTIVATIONS = ["int8", "fp8", "fp16"]
    SUPPORTED_KVCACHES = ["fp8", "fp16"]
    # Environment variable name used to signal 4-bit weight quantization.
    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

    def __init__(
        self,
        format: Optional[str] = None,
        weights: Optional[str] = None,
        activations: Optional[str] = None,
        kv_caches: Optional[str] = None,
        *,
        precision: Optional[str] = None,
    ):
        # A falsy format falls back to the default "rbln".
        self.format = format or "rbln"
        if self.format not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")

        if precision is not None:
            # Legacy combined spelling; translate onto weights/activations.
            logger.warning("The `precision` argument is deprecated. Use `weights` and `activations` instead.")
            if weights is not None or activations is not None:
                raise ValueError("`precision` and `weights` or `activations` cannot be set at the same time.")
            if precision == "w4a16":
                weights, activations = "int4", "fp16"
            else:
                raise ValueError(f"Invalid precision: {precision}")

        # Any dtype left unset defaults to fp16 (i.e. not quantized).
        self.weights = weights or "fp16"
        self.activations = activations or "fp16"
        self.kv_caches = kv_caches or "fp16"
        self._validate()

    def _validate(self) -> None:
        # Every field must come from its supported set, and at least one of
        # weights/activations must actually be quantized.
        if self.format not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")
        if self.weights not in self.SUPPORTED_WEIGHTS:
            raise ValueError(f"Invalid weights: {self.weights}, supported weights are: {self.SUPPORTED_WEIGHTS}")
        if self.activations not in self.SUPPORTED_ACTIVATIONS:
            raise ValueError(
                f"Invalid activations: {self.activations}, supported activations are: {self.SUPPORTED_ACTIVATIONS}"
            )
        if self.kv_caches not in self.SUPPORTED_KVCACHES:
            raise ValueError(
                f"Invalid kv_caches: {self.kv_caches}, supported kv_caches are: {self.SUPPORTED_KVCACHES}"
            )
        if self.weights == "fp16" and self.activations == "fp16":
            raise ValueError("weights and activations of QuantizationConfig cannot be both fp16. It is meaningless.")

    def _prepare_for_serialization(self) -> Dict[str, Any]:
        # Flatten the config into plain, JSON-friendly key/value pairs.
        return {name: getattr(self, name) for name in ("format", "weights", "activations", "kv_caches")}

    def maybe_set_quantization_env(self):
        # Advertise 4-bit weight packing through the environment when applicable.
        if self.weights == "int4":
            os.environ[self.RBLN_QUANT_BITS_ENV] = "4"

    def maybe_reset_quantization_env(self):
        # Remove the env var if present; no-op otherwise.
        os.environ.pop(self.RBLN_QUANT_BITS_ENV, None)

    @property
    def nbits_per_param(self) -> int:
        """Bit width of a single quantized weight parameter (4 or 8)."""
        if self.weights in ("int4", "fp4"):
            return 4
        if self.weights in ("int8", "fp8"):
            return 8
        raise ValueError(f"Invalid weights: {self.weights}")
134
+
135
+
136
class QuantizedLayerFactory:
    """Chooses the quantized replacement for a torch ``Linear`` layer based on
    the configured weight dtype (integer vs fp8)."""

    def __init__(self, quantization_config: RBLNQuantizationConfig):
        self.quantization_config = quantization_config

    def create_linear(self, layer: Linear) -> Linear:
        # int4 and int8 share the integer qlinear path; fp8 has its own builder.
        weight_dtype = self.quantization_config.weights
        if weight_dtype == "fp8":
            return self.create_fp8linear(layer)
        if weight_dtype in ("int4", "int8"):
            return self.create_qlinear(layer)
        raise ValueError(f"Invalid quantization weights: {weight_dtype}")

    def create_qlinear(self, layer: Linear) -> Linear:
        # Delegates to the module-level integer-quantized Linear builder.
        return create_qlinear(layer, self.quantization_config)

    def create_fp8linear(self, layer: Linear) -> Linear:
        # Delegates to the module-level fp8 Linear builder.
        return create_fp8linear(layer, self.quantization_config)
153
+
154
+
155
def get_quantized_model(
    hf_auto_model_class: Type["_BaseAutoModelClass"],
    model_id: str,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
    **kwargs,
):
    """
    Get a quantized model from a model class and model id.

    Builds the model skeleton from its config without initializing weights,
    swaps target Linear layers for quantized variants, then loads the
    checkpoint's safetensors weights into the result.

    Args:
        hf_auto_model_class: Transformers auto-model class used to instantiate
            the model from its config.
        model_id: Local directory or Hugging Face Hub repo id.
        use_auth_token: Token (or ``True``) forwarded to Hub APIs.
        revision: Optional git revision for Hub downloads.
        cache_dir: Optional download cache directory.
        force_download: Re-download files even if cached.
        local_files_only: Use only locally cached files.
        rbln_quantization: Quantization scheme applied to the model's layers.
        **kwargs: Extra options forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The model with quantized layers and checkpoint weights loaded.
    """
    # torch_dtype should not be passed to AutoConfig.from_pretrained
    # since it doesn't support 'auto'
    torch_dtype = kwargs.pop("torch_dtype", None)
    if torch_dtype is not None:
        logger.warning(
            "torch_dtype is not supported for quantized models. "
            "It will be ignored and the dtype of the model will be determined by the weights."
        )
        torch_dtype = None

    # get paths of safetensors files in the model repo
    safetensor_files = load_weight_files(
        model_id,
        use_auth_token=use_auth_token,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
    )

    # load safetensors files into memory
    safetensors = [load_file(safetensor_file) for safetensor_file in safetensor_files]

    # get the dtype of the model from the first safetensor file
    # NOTE(review): assumes every shard shares one dtype — confirm for mixed
    # checkpoints, since only safetensors[0] is inspected.
    torch_dtype = get_state_dict_dtype(safetensors[0])

    config = AutoConfig.from_pretrained(
        model_id,
        use_auth_token=use_auth_token,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
        **kwargs,
    )

    # Skip random weight initialization; real weights are loaded below.
    with no_init_weights():
        model = hf_auto_model_class.from_config(config, torch_dtype=torch_dtype)

    # Quantize the model
    update_layers_to_quantize(model, rbln_quantization)

    # Load weights into the model
    load_weights_from_files(model, safetensors, rbln_quantization)

    return model
215
+
216
+
217
def load_weight_files(
    model_id: str,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    force_download: bool = False,
    local_files_only: bool = False,
) -> List[str]:
    """
    Discover and download safetensors files for the given model id.

    Args:
        model_id: Local directory or Hugging Face Hub repo id.
        use_auth_token: Token (or ``True``) forwarded to Hub APIs.
        revision: Optional git revision of the Hub repo.
        cache_dir: Optional download cache directory.
        force_download: Re-download files even if cached.
        local_files_only: Restrict ``hf_hub_download`` to local cache.

    Returns:
        Sorted list of paths to ``*.safetensors`` files.

    Raises:
        FileNotFoundError: If no safetensors files are found.
    """

    if os.path.isdir(model_id):
        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
    else:
        try:
            # List all files in the repository, then fetch each safetensors shard
            repo_files = list_repo_files(model_id, revision=revision, token=use_auth_token)
            safetensor_files = []

            for file in repo_files:
                if file.endswith(".safetensors"):
                    # Download the safetensors file
                    downloaded_file = hf_hub_download(
                        repo_id=model_id,
                        filename=file,
                        revision=revision,
                        token=use_auth_token,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        local_files_only=local_files_only,
                    )
                    safetensor_files.append(downloaded_file)
        except Exception as e:
            logger.error(f"Failed to download safetensors files from Hugging Face Hub: {e}")
            # Bare `raise` re-raises with the original traceback intact
            # (the previous `raise e` restarted the traceback here).
            raise

    if not safetensor_files:
        raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")

    # Sort for deterministic ordering: glob/Hub listing order is arbitrary,
    # and downstream code infers the model dtype from the *first* file.
    return sorted(safetensor_files)
259
+
260
+
261
def update_layers_to_quantize(
    module: torch.nn.Module,
    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
) -> None:
    """
    Updates specified linear layers to quantized (qlinear) layers in the given module.

    Walks every submodule, replaces each replacement target with the layer
    produced by ``QuantizedLayerFactory``, and logs the replaced names.
    """

    factory = QuantizedLayerFactory(rbln_quantization)
    replaced_names = []

    for full_name, child in module.named_modules():
        if not is_target_for_qlinear_replacement(full_name, child):
            continue
        # Re-attach the quantized layer under the same attribute name.
        parent_module, attr_name = get_parent_and_child(module, full_name)
        setattr(parent_module, attr_name, factory.create_linear(child))
        replaced_names.append(full_name)

    if replaced_names:
        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(replaced_names)}}}")
280
+
281
+
282
+ def _last_segment(key: str) -> str:
283
+ parts = key.split(".")
284
+ return parts[-1]
285
+
286
+
287
+ def _replace_last_with(key: str, new_tail: str) -> str:
288
+ parts = key.split(".")
289
+ return ".".join(parts[:-1] + new_tail.split("."))
290
+
291
+
292
def _matches_any_alias(key: str, kind: str) -> bool:
    """True when *key*'s last dotted component is a known alias for *kind*
    (per ``VARIANT_ALIASES``); unknown kinds match nothing."""
    aliases = VARIANT_ALIASES.get(kind, [])
    return _last_segment(key) in aliases
295
+
296
+
297
+ def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor:
298
+ if t.ndim == 0:
299
+ return t
300
+ return t.reshape(-1).amax()
301
+
302
+
303
def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
    """Normalize a weight-scale tensor to contiguous shape ``(out_features, 1)``.

    Handles the scale layouts seen in community checkpoints: 0-dim scalar,
    1-D (single value or per-channel), and 2-D (column, row, or
    out-features-first). Layouts that cannot be mapped per-channel are
    conservatively reduced to a single max value and broadcast to every
    output channel.
    """
    s = scale
    if s.ndim == 0:
        # scalar -> expand to [out_features, 1]
        return s.reshape(1, 1).expand(out_features, 1).contiguous()
    if s.ndim == 1:
        if s.numel() == 1:
            # single-element vector behaves like a scalar
            return s.reshape(1, 1).expand(out_features, 1).contiguous()
        if s.numel() == out_features:
            # per-out-channel vector -> column shape
            return s.reshape(out_features, 1).contiguous()
        # fallback: reduce to scalar then expand
        v = _reduce_to_scalar(s)
        return v.reshape(1, 1).expand(out_features, 1).contiguous()
    if s.ndim == 2:
        if s.shape == (out_features, 1):
            # already canonical
            return s.contiguous()
        if s.shape == (1, out_features):
            # row layout -> transpose into column layout
            return s.transpose(0, 1).contiguous()
        # fallback: reduce to [out_features] on non-out dims if possible
        if s.shape[0] == out_features:
            v = s
            # NOTE(review): s.ndim == 2 in this branch, so this loop never
            # executes — looks like dead code kept from a higher-rank path.
            while v.ndim > 2:
                v = v.amax(dim=-1)
            if v.shape[-1] != 1:
                v = v.amax(dim=-1, keepdim=True)
            return v.contiguous()
        # otherwise reduce to scalar then expand
        v = _reduce_to_scalar(s)
        return v.reshape(1, 1).expand(out_features, 1).contiguous()
    # high-rank: reduce to scalar then expand
    v = _reduce_to_scalar(s)
    return v.reshape(1, 1).expand(out_features, 1).contiguous()
335
+
336
+
337
+ def _kv_split_items(base_key: str, tensor: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
338
+ # base_key is the original key whose last token was 'kv_scale'
339
+ # We produce keys with 'k_proj.k_scale' and 'v_proj.v_scale'
340
+ if tensor.ndim == 1 and tensor.numel() >= 2:
341
+ tk, tv = tensor[0], tensor[1]
342
+ elif tensor.ndim == 2 and tensor.shape[0] >= 2 and tensor.shape[1] == 1:
343
+ tk, tv = tensor[0, 0], tensor[1, 0]
344
+ else:
345
+ tk = tv = tensor
346
+ k_key = _replace_last_with(base_key, "k_proj.k_scale")
347
+ v_key = _replace_last_with(base_key, "v_proj.v_scale")
348
+ return [(k_key, tk), (v_key, tv)]
349
+
350
+
351
def canonicalize_checkpoint_items(
    model: torch.nn.Module,
    items: Iterable[Tuple[str, torch.Tensor]],
    rbln_quantization: Optional[RBLNQuantizationConfig],
) -> List[Tuple[str, torch.Tensor]]:
    """Normalize checkpoint (key, tensor) pairs to the internal quantization schema.

    Scale entries named with any known alias are renamed to their canonical
    keys ('weight_scale', 'input_scale', 'k_scale'/'v_scale') and their tensors
    cast to float32 in the expected shape; fused 'kv_scale' entries are split
    into separate k/v items.  All other pairs pass through unchanged.

    Args:
        model: Module whose parameters are used to infer weight shapes.
        items: Iterable of raw (key, tensor) pairs from a checkpoint.
        rbln_quantization: Quantization config; may be None, in which case
            weight scales are only cast to float32 without reshaping.

    Returns:
        List of normalized (key, tensor) pairs.
    """
    params = dict(model.named_parameters(recurse=True))
    results: List[Tuple[str, torch.Tensor]] = []

    for key, value in items:
        t = value
        # Normalize weight scale variants
        if _matches_any_alias(key, "weight_scale"):
            # rename last token to the canonical weight scale key
            target_key = _replace_last_with(key, "weight_scale")

            # Determine associated weight param to infer shape
            weight_key = _replace_last_with(target_key, "weight")
            out_features = None
            if weight_key in params:
                wshape = params[weight_key].shape
                if len(wshape) == 2:
                    out_features = int(wshape[0])

            # Guard the Optional config: the previous version dereferenced
            # rbln_quantization unconditionally and raised AttributeError
            # when it was None.
            weights_dtype = rbln_quantization.weights if rbln_quantization is not None else None
            if weights_dtype in ("int4", "int8") and out_features is not None:
                t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features)
            elif weights_dtype == "fp8":
                # Use a conservative scalar scale to ensure broadcastability
                t = _reduce_to_scalar(t.to(torch.float32))
            else:
                t = t.to(torch.float32)

            results.append((target_key, t))
            continue

        # Normalize input/activation scale variants
        if _matches_any_alias(key, "input_scale"):
            target_key = _replace_last_with(key, "input_scale")
            t = _reduce_to_scalar(t.to(torch.float32))
            results.append((target_key, t))
            continue

        # KV scale handling: for quark-like formats, expand to separate k/v items
        if _matches_any_alias(key, "kv_scale"):
            for k2, v2 in _kv_split_items(key, t.to(torch.float32)):
                results.append((k2, v2))
            continue

        if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"):
            results.append((key, t.to(torch.float32)))
            continue

        # Default: passthrough
        results.append((key, t))

    return results
408
+
409
+
410
def load_weights_from_files(
    model: torch.nn.Module,
    safetensors: List[Dict[str, torch.Tensor]],
    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
):
    """
    Load safetensor file data directly into the model from provided safetensor files.

    Keys are first canonicalized (scale renaming/coercion), then copied into the
    matching parameters or buffers; unmatched keys are collected and reported.
    When *rbln_quantization* requests fp8 weights/activations/kv-caches, the
    corresponding scales must be present in the checkpoint or a ValueError is
    raised.
    """

    model_params = dict(model.named_parameters(recurse=True))
    model_buffers = dict(model.named_buffers(recurse=True))

    unloaded_keys = []
    loaded_input_scale = False
    loaded_kv_scale = False
    loaded_weight_scale = False

    for safetensor in safetensors:
        # Normalize all (key, tensor) pairs to the internal schema
        normalized_items = canonicalize_checkpoint_items(
            model=model,
            items=safetensor.items(),
            rbln_quantization=rbln_quantization,
        )

        for key, value in normalized_items:
            # Track which types of scales were observed (post-normalization)
            if key.endswith("input_scale"):
                loaded_input_scale = True
            if key.endswith("weight_scale"):
                loaded_weight_scale = True
            if key.endswith("k_scale") or key.endswith("v_scale"):
                loaded_kv_scale = True

            # Copy into parameters or buffers, casting dtype when needed
            if key in model_params:
                if model_params[key].dtype != value.dtype:
                    value = value.to(model_params[key].dtype)
                model_params[key].data.copy_(value)
            elif key in model_buffers:
                if model_buffers[key].dtype != value.dtype:
                    value = value.to(model_buffers[key].dtype)
                model_buffers[key].data.copy_(value)
            else:
                unloaded_keys.append(key)

    if len(unloaded_keys) > 0:
        logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")

    # Scale-presence validation only makes sense with a quantization config;
    # the previous version dereferenced a possibly-None config below.
    if rbln_quantization is None:
        return

    if not loaded_input_scale and rbln_quantization.activations == "fp8":
        raise ValueError(
            "No input_scale found in the checkpoint. Did you use the correct quantization config? "
            "If you are using fp8 quantization, you need to use the correct quantization config."
        )
    if not loaded_weight_scale and rbln_quantization.weights == "fp8":
        raise ValueError(
            "No weight_scale found in the checkpoint. Did you use the correct quantization config? "
            "If you are using fp8 quantization, you need to use the correct quantization config."
        )
    if not loaded_kv_scale and rbln_quantization.kv_caches == "fp8":
        raise ValueError(
            "No kv_scale found in the checkpoint. Did you use the correct quantization config? "
            "If you are using fp8 quantization, you need to use the correct quantization config."
        )
    if loaded_kv_scale and rbln_quantization.kv_caches != "fp8":
        logger.warning(
            "kv_scale found in the checkpoint, but kv_caches of quantization config is not fp8. Ignoring kv_scale."
        )
478
+
479
+
480
def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
    """
    Checks if a layer is a target for qlinear replacement.
    """
    tail = layer_name.rsplit(".", 1)[-1]
    return isinstance(layer, torch.nn.Linear) and tail in QUANTIZED_WEIGHTS
485
+
486
+
487
def is_target_for_adding_kv_scales(layer_name: str) -> bool:
    """True when the (dotted) layer name refers to a self-attention module."""
    return layer_name.rsplit(".", 1)[-1] == "self_attn"
489
+
490
+
491
def get_parent_and_child(module: torch.nn.Module, full_name: str) -> tuple:
    """
    Splits the full layer name to retrieve the parent module and the child layer.
    """
    *parent_path, child_name = full_name.split(".")
    parent = module
    for attr in parent_path:
        parent = getattr(parent, attr)
    return parent, child_name
498
+
499
+
500
def access_attribute(obj: Any, attributes: list[str]) -> Any:
    """
    Recursively accesses a nested attribute from an object using a list of attribute names.
    """
    if not attributes:
        return obj
    return access_attribute(getattr(obj, attributes[0]), attributes[1:])
507
+
508
+
509
def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
    """
    Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
    """

    def dequantizing_forward(module, inputs: torch.Tensor) -> torch.Tensor:
        # Inputs must already match the scale dtype; mixed precision is rejected.
        scale = module.weight_scale
        if inputs.dtype != scale.dtype:
            raise TypeError(f"Expected input dtype {scale.dtype}, but got {inputs.dtype}")

        # Dequantize the int8 weight on the fly, then run a normal linear pass.
        dequantized = module.weight.type(inputs.dtype)
        dequantized *= scale.view(-1, 1)
        return F.linear(inputs, dequantized, module.bias)

    # Store the weight as int8 alongside a per-output-channel float32 scale.
    layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False)
    layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False)
    layer.forward = lambda inputs: dequantizing_forward(layer, inputs)

    return layer
529
+
530
+
531
+ def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
532
+ """
533
+ Converts a standard linear layer to a fp8 linear layer with a custom forward pass.
534
+ """
535
+
536
+ def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
537
+ finfo = torch.finfo(torch.float8_e4m3fn)
538
+ qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
539
+ return qweight
540
+
541
+ def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
542
+ A = A.type(out_dtype)
543
+ B = B.type(out_dtype)
544
+
545
+ if A_scale is not None:
546
+ A *= A_scale
547
+ if B_scale is not None:
548
+ B *= B_scale.to(out_dtype)
549
+
550
+ output = torch.nn.functional.linear(A, B, bias=bias)
551
+ return output
552
+
553
+ def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
554
+ if self.input_scale:
555
+ input = static_per_tensor_quantize(x, self.input_scale)
556
+ else:
557
+ input = x
558
+
559
+ if self.weight_scale:
560
+ # broadcast weight_scale to vector
561
+ weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
562
+ else:
563
+ weight_scale = None
564
+ output = fp8_gemm(
565
+ A=input,
566
+ A_scale=self.input_scale,
567
+ B=self.weight,
568
+ B_scale=weight_scale,
569
+ bias=self.bias,
570
+ out_dtype=x.dtype,
571
+ )
572
+
573
+ return output
574
+
575
+ layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
576
+ layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
577
+
578
+ if rbln_quantization.activations == "fp8":
579
+ layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
580
+ else:
581
+ layer.input_scale = None
582
+
583
+ if rbln_quantization.kv_caches == "fp8":
584
+ layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
585
+ layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
586
+
587
+ layer.forward = lambda inputs: fp8linear_forward(layer, inputs)
588
+
589
+ return layer
@@ -0,0 +1,79 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
18
+
19
+ from torch.nn import Module
20
+
21
+ from ...modeling import RBLNModel
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ import rebel
26
+
27
+
28
class LoopProcessor(Module, ABC):
    """Wraps a model/runtime and runs it item-by-item over a batch.

    Subclasses implement how to derive the batch size, slice per-item inputs,
    and aggregate per-item outputs; a subclass may additionally override
    ``_forward_batch`` to process the whole batch in a single call, which is
    then preferred unless the caller passes ``force_loop=True``.
    """

    def __init__(self, model: Union["RBLNModel", "rebel.Runtime"]):
        super().__init__()
        self.model = model

    def __repr__(self) -> str:
        return repr(self.model)

    def _is_batch_implemented(self) -> bool:
        # True when a subclass overrode _forward_batch.
        return self._forward_batch.__func__ is not LoopProcessor._forward_batch

    def forward(self, *args, force_loop: bool = False, **kwargs) -> Any:
        # Prefer the batch path when available unless the caller forces looping.
        if not force_loop and self._is_batch_implemented():
            return self._forward_batch(*args, **kwargs)
        else:
            return self._forward_loop(*args, **kwargs)

    def _forward_loop(self, *args, **kwargs) -> Any:
        batch_size = self._get_batch_size(*args, **kwargs)

        if not isinstance(batch_size, int) or batch_size == 0:
            # Pass kwargs here too so _process_outputs sees the same call
            # signature on the empty path as on the non-empty path (the
            # previous version dropped them on this branch).
            return self._process_outputs([], **kwargs)

        common_inputs = self._prepare_inputs_before_loop(*args, **kwargs)

        outputs = []
        for i in range(batch_size):
            item_args, item_kwargs = self._prepare_inputs_for_iteration(i, common_inputs, *args, **kwargs)
            item_output = self.model(*item_args, **item_kwargs)
            outputs.append(item_output)

        return self._process_outputs(outputs, **kwargs)

    def _forward_batch(self, *args, **kwargs) -> Any:
        raise NotImplementedError("The batch processing logic (_forward_batch) is not implemented in this class.")

    @abstractmethod
    def _get_batch_size(self, *args, **kwargs) -> int:
        # Number of items to iterate over; non-int or 0 short-circuits the loop.
        pass

    @abstractmethod
    def _prepare_inputs_for_iteration(
        self, index: int, common_inputs: Dict[str, Any], *args, **kwargs
    ) -> Tuple[List[Any], Dict[str, Any]]:
        # Build the (args, kwargs) for the model call of item `index`.
        pass

    def _prepare_inputs_before_loop(self, *args, **kwargs) -> Dict[str, Any]:
        # Optional hook computed once before the loop; default returns None.
        pass

    @abstractmethod
    def _process_outputs(self, outputs: List[Any], **kwargs) -> Any:
        # Aggregate the per-item outputs into the final result.
        pass
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .import_utils import check_version_compats, is_rbln_available
16
+ from .runtime_utils import RBLNPytorchRuntime