optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,455 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import Tensor
19
+
20
+
21
+ @torch.library.custom_op(
22
+ "rbln_custom_ops::paged_attn_decode",
23
+ mutates_args=(["kcache", "vcache"]),
24
+ )
25
+ def paged_attn_decode(
26
+ q: Tensor,
27
+ k: Tensor,
28
+ v: Tensor,
29
+ mask: Tensor,
30
+ kcache: Tensor,
31
+ vcache: Tensor,
32
+ seq: Tensor,
33
+ scale: Tensor,
34
+ block_table: Tensor,
35
+ block_size: int,
36
+ ) -> Tensor:
37
+ return torch.empty_like(q)
38
+
39
+
40
@paged_attn_decode.register_fake
def paged_attn_decode_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_attn_decode``.

    Registered through ``register_fake`` so the op can be traced without running
    its real implementation. It only describes the output metadata: a tensor with
    the shape, dtype and device of ``q``. All other arguments exist solely to
    match the op signature.
    """
    result = torch.empty_like(q)
    return result
54
+
55
+
56
+ @torch.library.custom_op(
57
+ "rbln_custom_ops::paged_attn_decode_kv_fp8",
58
+ mutates_args=(["kcache", "vcache"]),
59
+ )
60
+ def paged_attn_decode_kv_fp8(
61
+ q: Tensor,
62
+ k: Tensor,
63
+ v: Tensor,
64
+ mask: Tensor,
65
+ kcache: Tensor,
66
+ vcache: Tensor,
67
+ seq: Tensor,
68
+ scale: Tensor,
69
+ block_table: Tensor,
70
+ block_size: int,
71
+ k_scale: Tensor,
72
+ v_scale: Tensor,
73
+ ) -> Tensor:
74
+ return torch.empty_like(q)
75
+
76
+
77
@paged_attn_decode_kv_fp8.register_fake
def paged_attn_decode_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_attn_decode_kv_fp8``.

    Registered through ``register_fake``. It only reports the output metadata,
    which matches ``q`` (shape, dtype, device). The remaining arguments exist
    solely to mirror the real op's signature.
    """
    result = torch.empty_like(q)
    return result
93
+
94
+
95
+ @torch.library.custom_op(
96
+ "rbln_custom_ops::paged_attn_prefill",
97
+ mutates_args=(["kcache", "vcache"]),
98
+ )
99
+ def paged_attn_prefill(
100
+ q: Tensor,
101
+ k: Tensor,
102
+ v: Tensor,
103
+ mask: Tensor,
104
+ kcache: Tensor,
105
+ vcache: Tensor,
106
+ seq: Tensor,
107
+ scale: Tensor,
108
+ block_table: Tensor,
109
+ block_size: int,
110
+ ) -> Tensor:
111
+ """Defines the computation pattern for prefill phase attention with KV cache updates.
112
+
113
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
114
+ a single optimized NPU operation. It is NOT meant for CPU execution.
115
+
116
+ Key differences from decode pattern:
117
+ - Handles prefill phase with multiple input tokens
118
+ - Takes explicit batch index for continuous batching
119
+
120
+ Expected tensor shapes:
121
+ - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
122
+ - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
123
+ - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
124
+ - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
125
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
126
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
127
+ - seq: [1, 1] - Starting sequence position
128
+ - scale: [] - Attention scale factor
129
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
130
+ - block_size: [] - Number of tokens per block
131
+
132
+ Returns:
133
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
134
+ """
135
+ return torch.empty_like(q)
136
+
137
+
138
@paged_attn_prefill.register_fake
def paged_attn_prefill_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_attn_prefill``.

    Registered through ``register_fake``. It only describes the output: a
    tensor with the shape, dtype and device of ``q``. Every other argument is
    accepted purely to match the op signature.
    """
    result = torch.empty_like(q)
    return result
152
+
153
+
154
+ @torch.library.custom_op(
155
+ "rbln_custom_ops::paged_attn_prefill_kv_fp8",
156
+ mutates_args=(["kcache", "vcache"]),
157
+ )
158
+ def paged_attn_prefill_kv_fp8(
159
+ q: Tensor,
160
+ k: Tensor,
161
+ v: Tensor,
162
+ mask: Tensor,
163
+ kcache: Tensor,
164
+ vcache: Tensor,
165
+ seq: Tensor,
166
+ scale: Tensor,
167
+ block_table: Tensor,
168
+ block_size: int,
169
+ k_scale: Tensor,
170
+ v_scale: Tensor,
171
+ ) -> Tensor:
172
+ return torch.empty_like(q)
173
+
174
+
175
@paged_attn_prefill_kv_fp8.register_fake
def paged_attn_prefill_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_attn_prefill_kv_fp8``.

    Registered through ``register_fake``. It only reports output metadata
    matching ``q`` (shape, dtype, device). All remaining arguments exist solely
    to mirror the real op's signature.
    """
    result = torch.empty_like(q)
    return result
191
+
192
+
193
+ @torch.library.custom_op(
194
+ "rbln_custom_ops::paged_causal_attn_decode",
195
+ mutates_args=(["kcache", "vcache"]),
196
+ )
197
+ def paged_causal_attn_decode(
198
+ q: Tensor,
199
+ k: Tensor,
200
+ v: Tensor,
201
+ kcache: Tensor,
202
+ vcache: Tensor,
203
+ seq: Tensor,
204
+ scale: Tensor,
205
+ block_table: Tensor,
206
+ block_size: int,
207
+ mask: Optional[Tensor] = None,
208
+ ) -> Tensor:
209
+ """Defines the computation pattern for fused attention with KV cache updates.
210
+
211
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
212
+ a single optimized NPU operation. It is NOT meant for CPU execution.
213
+
214
+ Pattern components that compiler fuses into a single op:
215
+ 1. KV cache updates with new key/value states
216
+ 2. Scaled dot-product attention computation
217
+ 3. Causal masked softmax operation
218
+ 4. Final attention output computation
219
+
220
+ Expected tensor shapes:
221
+ - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
222
+ - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
223
+ - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
224
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
225
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
226
+ - seq: [1, 1] - Starting sequence position
227
+ - scale: [] - Attention scale factor
228
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
229
+ - block_size: [] - Number of tokens per block
230
+ - mask: [batch=1, max_seq_len] - attention mask when use position_ids
231
+
232
+ Returns:
233
+ Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
234
+ """
235
+ return torch.empty_like(q)
236
+
237
+
238
@paged_causal_attn_decode.register_fake
def paged_causal_attn_decode_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_causal_attn_decode``.

    Registered through ``register_fake``. It returns an empty tensor carrying
    ``q``'s shape, dtype and device. The other arguments, including the optional
    ``mask``, are accepted only to match the op signature.
    """
    result = torch.empty_like(q)
    return result
252
+
253
+
254
+ @torch.library.custom_op(
255
+ "rbln_custom_ops::paged_causal_attn_prefill",
256
+ mutates_args=(["kcache", "vcache"]),
257
+ )
258
+ def paged_causal_attn_prefill(
259
+ q: Tensor,
260
+ k: Tensor,
261
+ v: Tensor,
262
+ kcache: Tensor,
263
+ vcache: Tensor,
264
+ seq: Tensor,
265
+ scale: Tensor,
266
+ block_table: Tensor,
267
+ block_size: int,
268
+ is_bidirectional: bool,
269
+ mask: Optional[Tensor] = None,
270
+ ) -> Tensor:
271
+ """Defines the computation pattern for prefill phase attention with KV cache updates.
272
+
273
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
274
+ a single optimized NPU operation. It is NOT meant for CPU execution.
275
+
276
+ Key differences from decode pattern:
277
+ - Handles prefill phase with multiple input tokens
278
+ - Takes explicit batch index for continuous batching
279
+
280
+ Expected tensor shapes:
281
+ - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
282
+ - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
283
+ - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
284
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
285
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
286
+ - batch: [1] - Batch index for cache access
287
+ - seq: [1, 1] - Starting sequence position
288
+ - scale: [] - Attention scale factor
289
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
290
+ - block_size: [] - Number of tokens per block
291
+ - is_bidirectional: [] - Whether the attention is bidirectional at current sequence position
292
+ - mask: [batch=1, max_seq_len] - attention mask when use position_ids
293
+
294
+ Returns:
295
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
296
+ """
297
+ return torch.empty_like(q)
298
+
299
+
300
@paged_causal_attn_prefill.register_fake
def paged_causal_attn_prefill_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    is_bidirectional: bool,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_causal_attn_prefill``.

    Registered through ``register_fake``. It only supplies output metadata,
    identical to ``q``'s (shape, dtype, device). The remaining arguments exist
    to match the op signature.
    """
    result = torch.empty_like(q)
    return result
315
+
316
+
317
+ @torch.library.custom_op(
318
+ "rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
319
+ mutates_args=(["kcache", "vcache"]),
320
+ )
321
+ def paged_causal_attn_decode_kv_fp8(
322
+ q: Tensor,
323
+ k: Tensor,
324
+ v: Tensor,
325
+ kcache: Tensor,
326
+ vcache: Tensor,
327
+ seq: Tensor,
328
+ scale: Tensor,
329
+ block_table: Tensor,
330
+ block_size: int,
331
+ k_scale: Tensor,
332
+ v_scale: Tensor,
333
+ mask: Optional[Tensor] = None,
334
+ ) -> Tensor:
335
+ return torch.empty_like(q)
336
+
337
+
338
@paged_causal_attn_decode_kv_fp8.register_fake
def paged_causal_attn_decode_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_causal_attn_decode_kv_fp8``.

    Registered through ``register_fake``. It reports an output with ``q``'s
    shape, dtype and device; every other argument only mirrors the real op's
    signature.
    """
    result = torch.empty_like(q)
    return result
354
+
355
+
356
+ @torch.library.custom_op(
357
+ "rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
358
+ mutates_args=(["kcache", "vcache"]),
359
+ )
360
+ def paged_causal_attn_prefill_kv_fp8(
361
+ q: Tensor,
362
+ k: Tensor,
363
+ v: Tensor,
364
+ kcache: Tensor,
365
+ vcache: Tensor,
366
+ seq: Tensor,
367
+ scale: Tensor,
368
+ block_table: Tensor,
369
+ block_size: int,
370
+ is_bidirectional: bool,
371
+ k_scale: Tensor,
372
+ v_scale: Tensor,
373
+ mask: Optional[Tensor] = None,
374
+ ) -> Tensor:
375
+ return torch.empty_like(q)
376
+
377
+
378
@paged_causal_attn_prefill_kv_fp8.register_fake
def paged_causal_attn_prefill_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    is_bidirectional: bool,
    k_scale: Tensor,
    v_scale: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_causal_attn_prefill_kv_fp8``.

    Registered through ``register_fake``. Its only job is to describe the
    output: an empty tensor with ``q``'s shape, dtype and device. All other
    arguments are present to match the op signature.
    """
    result = torch.empty_like(q)
    return result
395
+
396
+
397
+ @torch.library.custom_op(
398
+ "rbln_custom_ops::paged_add_softmax_attn_decode",
399
+ mutates_args=(["kcache", "vcache"]),
400
+ )
401
+ def paged_add_softmax_attn_decode(
402
+ q: Tensor,
403
+ k: Tensor,
404
+ v: Tensor,
405
+ mask: Tensor,
406
+ kcache: Tensor,
407
+ vcache: Tensor,
408
+ seq: Tensor,
409
+ scale: Tensor,
410
+ block_table: Tensor,
411
+ block_size: int,
412
+ ) -> Tensor:
413
+ """Defines the computation pattern for fused attention with KV cache updates.
414
+
415
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
416
+ a single optimized NPU operation. It is NOT meant for CPU execution.
417
+
418
+ Pattern components that compiler fuses into a single op:
419
+ 1. KV cache updates with new key/value states
420
+ 2. Scaled dot-product attention computation
421
+ 3. add-softmax operation
422
+ 4. Final attention output computation
423
+
424
+ Expected tensor shapes:
425
+ - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
426
+ - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
427
+ - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
428
+ - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
429
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
430
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
431
+ - seq: [1] - Current sequence position
432
+ - scale: [] - Attention scale factor
433
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
434
+ - block_size: [] - Number of tokens per block
435
+
436
+ Returns:
437
+ Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
438
+ """
439
+ return torch.empty_like(q)
440
+
441
+
442
@paged_add_softmax_attn_decode.register_fake
def paged_add_softmax_attn_decode_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
) -> Tensor:
    """Fake-tensor kernel for ``rbln_custom_ops::paged_add_softmax_attn_decode``.

    Registered through ``register_fake``. It only describes the output
    metadata, which equals ``q``'s (shape, dtype, device). The other arguments
    exist solely to match the op signature.
    """
    result = torch.empty_like(q)
    return result