optimum_rbln-0.9.3.post1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in the public registry.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
--- /dev/null
+++ b/optimum/rbln/ops/flash_attn.py
@@ -0,0 +1,350 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused flash attention with KV cache for decoding.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_decode.register_fake
+def paged_flash_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_decode_kv_fp8.register_fake
+def paged_flash_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused flash attention with KV cache for prefill.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_prefill.register_fake
+def paged_flash_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_prefill_kv_fp8.register_fake
+def paged_flash_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_decode.register_fake
+def paged_flash_causal_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_decode_kv_fp8.register_fake
+def paged_flash_causal_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_prefill.register_fake
+def paged_flash_causal_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_prefill_kv_fp8.register_fake
+def paged_flash_causal_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
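All of the ops above are pattern stubs rather than working kernels: in eager mode each body just returns torch.empty_like(q), while the paired register_fake functions give PyTorch's tracing machinery a matching shape function so the RBLN compiler can recognize the node and emit a fused NPU kernel. Below is a minimal sketch of calling one stub through the torch.ops namespace; every shape, the block-table layout, and the registration import are illustrative assumptions, not values mandated by optimum-rbln.

import torch

import optimum.rbln.ops  # assumed: importing the ops package registers the custom ops above

n_heads, head_dim, block_size, partition = 8, 64, 128, 2048

q = torch.randn(1, n_heads, 1, 1, head_dim)  # a single decode token
k = torch.randn(1, n_heads, 1, 1, head_dim)
v = torch.randn(1, n_heads, 1, 1, head_dim)
mask = torch.zeros(1, 1, 1, 1, block_size)
kcache = torch.zeros(4, n_heads, 1, block_size, head_dim)  # paged KV-cache blocks
vcache = torch.zeros_like(kcache)
seq = torch.tensor([17])  # tokens already in the cache
scale = torch.tensor(head_dim**-0.5)
block_table = torch.zeros(1, 1, dtype=torch.int16)  # maps the sequence to cache blocks

out = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
    q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition
)
assert out.shape == q.shape  # on the host the stub only returns an empty tensor shaped like q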
--- /dev/null
+++ b/optimum/rbln/ops/kv_cache_update.py
@@ -0,0 +1,29 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
+def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
+    # Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
+    # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
+    # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
+    return torch.empty_like(cache)
+
+
+@rbln_cache_update.register_fake
+def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
+    return torch.empty_like(cache)
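As a host-side illustration of the semantics described in the comment above (optimum-rbln does not ship this helper; it is a hypothetical reference only), the update amounts to a slice write along one axis:

import torch
from torch import Tensor

def cache_update_reference(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
    # Hypothetical reference: write `state` into `cache` starting at `position` along `axis`.
    dim, start = int(axis), int(position)
    cache.narrow(dim, start, state.shape[dim]).copy_(state)
    return cache

cache = torch.zeros(1, 4, 10, 8)  # e.g. [batch, heads, max_seq, head_dim]
state = torch.randn(1, 4, 3, 8)   # three new positions
cache_update_reference(cache, state, torch.tensor(5), torch.tensor(2))
assert torch.equal(cache[:, :, 5:8], state)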
--- /dev/null
+++ b/optimum/rbln/ops/linear.py
@@ -0,0 +1,32 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op("rbln_custom_ops::linear", mutates_args=())
+def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+    output_shape = list(input.shape[:-1])
+    output_shape += [weight.shape[0]]
+    return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
+
+
+@linear.register_fake
+def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+    output_shape = list(input.shape[:-1])
+    output_shape += [weight.shape[0]]
+    return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
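The shape arithmetic above implies the same contract as torch.nn.functional.linear: weight is [out_features, in_features], and the last input dimension is replaced by out_features. A quick shape check with illustrative sizes, assuming the op is registered by importing the package:

import torch

import optimum.rbln.ops  # assumed import path for registering the op

x = torch.randn(2, 7, 16)  # [..., in_features]
w = torch.randn(32, 16)    # [out_features, in_features]
b = torch.randn(32)

y = torch.ops.rbln_custom_ops.linear(x, w, b)
# The stub returns an empty tensor; only the shape is meaningful on the host.
assert y.shape == torch.nn.functional.linear(x, w, b).shape == (2, 7, 32)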
--- /dev/null
+++ b/optimum/rbln/ops/sliding_window_attn.py
@@ -0,0 +1,111 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_sliding_window_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_sliding_window_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    cache_seq_len: Tensor,
+    cache_offset: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    is_bidirectional: bool,
+) -> Tensor:
+    """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+    IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+    a single optimized NPU operation. It is NOT meant for CPU execution.
+
+    Key differences from decode pattern:
+    - Handles prefill phase with multiple input tokens
+    - Takes explicit batch index for continuous batching
+
+    Expected tensor shapes:
+    - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+    - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+    - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+    - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+    - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+    - cache_seq_len: [] - The sequence length of the cached states that were seen by the model
+    - cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states
+    - scale: [] - Attention scale factor
+    - is_bidirectional: [] - Whether the attention is bidirectional
+    Returns:
+        Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+    """
+    return torch.empty_like(q)
+
+
+@paged_sliding_window_attn_prefill.register_fake
+def paged_sliding_window_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    cache_seq_len: Tensor,
+    cache_offset: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    is_bidirectional: bool,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_sliding_window_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_sliding_window_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    cache_seq_len: Tensor,
+    cache_offset: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_sliding_window_attn_decode.register_fake
+def paged_sliding_window_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    cache_seq_len: Tensor,
+    cache_offset: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
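To make the documented shapes concrete, here is a sketch of a prefill call with toy sizes. The window size, block-table layout, and dtypes are assumptions for illustration, and on CPU the stub again returns only an empty tensor shaped like q.

import torch

import optimum.rbln.ops  # assumed: importing the ops package registers the custom ops

n_heads, n_groups, seq_len, head_dim, window = 8, 4, 32, 64, 128

q = torch.randn(1, n_heads, n_groups, seq_len, head_dim)
k = torch.randn(1, n_heads, 1, seq_len, head_dim)
v = torch.randn(1, n_heads, 1, seq_len, head_dim)
kcache = torch.zeros(2, n_heads, 1, window, head_dim)  # [batch_size, n_heads, 1, max_seq_len, head_dim]
vcache = torch.zeros_like(kcache)
cache_seq_len = torch.tensor(0)       # nothing cached before this prefill
cache_offset = torch.tensor(seq_len)  # valid length after writing the new keys
scale = torch.tensor(head_dim**-0.5)
block_table = torch.zeros(1, 1, dtype=torch.int16)

out = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
    q, k, v, kcache, vcache, cache_seq_len, cache_offset, scale,
    block_table, window, False,       # block_size=window, is_bidirectional=False
)
assert out.shape == q.shape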