optimum_rbln-0.9.3.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln has been flagged for review.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/modeling_outputs.py
@@ -0,0 +1,37 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ from transformers.modeling_outputs import ModelOutput
+
+
+ @dataclass
+ class RBLNDecoderOnlyOutput(ModelOutput):
+     logits: torch.FloatTensor = None
+     generate_idx: torch.Tensor = None
+     padded_cache_lengths: int = None
+
+
+ @dataclass
+ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
+     attention_mask: Optional[torch.Tensor] = None
+
+
+ @dataclass
+ class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
+     last_hidden_states: torch.FloatTensor = None
+     params: Tuple[torch.FloatTensor] = None
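These output containers are plain `ModelOutput` dataclasses, so they can be constructed and read like any other transformers output. A minimal sketch follows; the tensor shapes, values, and the interpretation of `generate_idx` are illustrative assumptions, not code taken from this wheel.

import torch

from optimum.rbln.transformers.modeling_outputs import RBLNDecoderOnlyOutput

# Illustrative tensors standing in for one decoding step (batch=1, seq=1, vocab=32000).
logits = torch.randn(1, 1, 32000)
generate_idx = torch.tensor([[5]])  # assumed: next position to fill for each sequence

output = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx, padded_cache_lengths=0)
print(output.logits.shape)       # torch.Size([1, 1, 32000])
print(output["generate_idx"])    # ModelOutput also supports dict-style access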
optimum/rbln/transformers/modeling_rope_utils.py
@@ -0,0 +1,314 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+ from transformers import PretrainedConfig
+
+
+ def _compute_default_rope_parameters(
+     config: Optional[PretrainedConfig] = None,
+     seq_len: Optional[int] = None,
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies according to the original RoPE implementation
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+     """
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = (
+         config.head_dim
+         if hasattr(config, "head_dim") and config.head_dim is not None
+         else config.hidden_size // config.num_attention_heads
+     )
+     dim = int(head_dim * partial_rotary_factor)
+
+     attention_factor = 1.0  # Unused in this type of RoPE
+
+     # Compute the inverse frequencies
+     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+     return inv_freq, attention_factor
+
+
+ def _compute_linear_scaling_rope_parameters(
+     config: Optional[PretrainedConfig] = None,
+     seq_len: Optional[int] = None,
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+     """
+
+     factor = config.rope_scaling["factor"]
+
+     # Gets the default RoPE parameters
+     inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+     # Then applies linear scaling to the frequencies.
+     # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+     # applying scaling to the inverse frequencies is equivalent.
+     inv_freq /= factor
+     return inv_freq, attention_factor
+
+
+ def _compute_dynamic_ntk_parameters(
+     config: Optional[PretrainedConfig] = None,
+     seq_len: Optional[int] = None,
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length, used to update the dynamic RoPE at inference time.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+     """
+
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+     dim = int(head_dim * partial_rotary_factor)
+     max_position_embeddings = config.max_position_embeddings
+     factor = config.rope_scaling["factor"]
+
+     attention_factor = 1.0  # Unused in this type of RoPE
+
+     # Process with chunk_size to reduce precision error
+     chunk_size = 4096
+     chunks = (seq_len + chunk_size - 1) // chunk_size
+
+     inv_freq_list = []
+     for i in range(chunks):
+         start = i * chunk_size
+         end = min((i + 1) * chunk_size, seq_len)
+
+         seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
+         seq_lens = torch.where(seq_lens > max_position_embeddings, seq_lens, max_position_embeddings)
+
+         # Compute the inverse frequencies for each chunk
+         scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+         inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+
+         inv_freq_list.append(inv_freq)
+
+     final_inv_freq = torch.cat(inv_freq_list, dim=0)
+
+     return final_inv_freq, attention_factor
+
+
+ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with NTK scaling. Please refer to the
+     [original paper](https://arxiv.org/abs/2309.00071)
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin.
+     """
+
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+     dim = int(head_dim * partial_rotary_factor)
+     max_position_embeddings = config.max_position_embeddings
+     factor = config.rope_scaling["factor"]
+
+     # Sets the attention factor as suggested in the paper
+     attention_factor = config.rope_scaling.get("attention_factor")
+     if attention_factor is None:
+         attention_factor = 0.1 * math.log(factor) + 1.0
+
+     # Optional config options
+     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+     beta_fast = config.rope_scaling.get("beta_fast") or 32
+     beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+     # Compute the inverse frequencies
+     def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+         """Inverse dimension formula to find the dimension based on the number of rotations"""
+         return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+     def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+         """Find dimension range bounds based on rotations"""
+         low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+         return max(low, 0), min(high, dim - 1)
+
+     def linear_ramp_factor(min, max, dim):
+         if min == max:
+             max += 0.001  # Prevent singularity
+
+         linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+         ramp_func = torch.clamp(linear_func, 0, 1)
+         return ramp_func
+
+     # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+     # to expand the possible context length. In other words, interpolation = apply scaling factor.
+     pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+     inv_freq_extrapolation = 1.0 / pos_freqs
+     inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+     low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+
+     # Get n-dimensional rotational scaling corrected for extrapolation
+     inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+     inv_freq = (
+         inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+         + inv_freq_extrapolation * inv_freq_extrapolation_factor
+     )
+
+     return inv_freq, attention_factor
+
+
+ def _compute_longrope_parameters(
+     config: PretrainedConfig, seq_len: Optional[int] = None
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+     [original implementation](https://github.com/microsoft/LongRoPE)
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin.
+     """
+
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+     dim = int(head_dim * partial_rotary_factor)
+     long_factor = config.rope_scaling["long_factor"]
+     short_factor = config.rope_scaling["short_factor"]
+     factor = config.rope_scaling.get("factor")
+     attention_factor = config.rope_scaling.get("attention_factor")
+
+     # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+     # values to compute the default attention scaling factor, instead of using `factor`.
+     if hasattr(config, "original_max_position_embeddings"):
+         max_position_embeddings = config.original_max_position_embeddings
+         expanded_max_position_embeddings = config.max_position_embeddings
+         factor = expanded_max_position_embeddings / max_position_embeddings
+     else:
+         max_position_embeddings = config.max_position_embeddings
+         expanded_max_position_embeddings = max_position_embeddings * factor
+
+     # Sets the attention factor as suggested in the paper
+     if attention_factor is None:
+         if factor <= 1.0:
+             attention_factor = 1.0
+         else:
+             attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+
+     # Compute the inverse frequencies -- scaled based on the target sequence length
+     if expanded_max_position_embeddings > max_position_embeddings:
+         ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+     else:
+         ext_factors = torch.tensor(short_factor, dtype=torch.float32)
+     inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+     inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+     return inv_freq, attention_factor
+
+
+ def _compute_llama3_parameters(
+     config: PretrainedConfig, seq_len: Optional[int] = None
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies for llama 3.1.
+
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin.
+     """
+     # Gets the default RoPE parameters
+     inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+     factor = config.rope_scaling["factor"]  # `8` in the original implementation
+     low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+     high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+     old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+     low_freq_wavelen = old_context_len / low_freq_factor
+     high_freq_wavelen = old_context_len / high_freq_factor
+
+     wavelen = 2 * math.pi / inv_freq
+     # wavelen < high_freq_wavelen: do nothing
+     # wavelen > low_freq_wavelen: divide by factor
+     inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+     # otherwise: interpolate between the two, using a smooth factor
+     smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+     smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+     is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+     inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+     return inv_freq_llama, attention_factor
+
+
+ # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+ # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+ # parameterizations, as long as the callable has the same signature.
+ ROPE_INIT_FUNCTIONS = {
+     "default": _compute_default_rope_parameters,
+     "linear": _compute_linear_scaling_rope_parameters,
+     "dynamic": _compute_dynamic_ntk_parameters,
+     "yarn": _compute_yarn_parameters,
+     "longrope": _compute_longrope_parameters,
+     "llama3": _compute_llama3_parameters,
+ }
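The closing dictionary is the dispatch table that maps a `rope_type` string to its parameter initializer. A rough usage sketch follows; the config attribute values and the cos/sin cache construction are conventional RoPE steps assumed for illustration, not code from this wheel.

import torch
from transformers import PretrainedConfig

from optimum.rbln.transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# Hypothetical Llama-style config; unknown kwargs are stored as attributes on PretrainedConfig.
config = PretrainedConfig(
    rope_theta=10000.0,
    hidden_size=4096,
    num_attention_heads=32,
    max_position_embeddings=8192,
    rope_scaling={"rope_type": "linear", "factor": 2.0},
)

# Pick the initializer by the "rope_type" field and get (inverse frequencies, attention scaling).
rope_init_fn = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]]
inv_freq, attention_scaling = rope_init_fn(config, seq_len=None)

# Conventional cos/sin cache built from the inverse frequencies (illustrative only).
positions = torch.arange(16, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling
print(cos.shape)  # torch.Size([16, 128]) since head_dim = 4096 // 32 = 128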