optimum-rbln 0.7.4a2__tar.gz → 0.7.4a3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/deploy-on-tag.yaml +1 -1
  2. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/rbln_optimum_pytest.yaml +3 -1
  3. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/PKG-INFO +1 -1
  4. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/__version__.py +1 -1
  5. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/modeling.py +8 -1
  6. optimum_rbln-0.7.4a3/src/optimum/rbln/ops/__init__.py +18 -0
  7. optimum_rbln-0.7.4a3/src/optimum/rbln/ops/attn.py +287 -0
  8. optimum_rbln-0.7.4a3/src/optimum/rbln/ops/flash_attn.py +176 -0
  9. optimum_rbln-0.7.4a3/src/optimum/rbln/ops/kv_cache_update.py +24 -0
  10. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/bart/__init__.py +1 -0
  11. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/decoderonly/__init__.py +10 -0
  12. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +80 -94
  13. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +17 -13
  14. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  15. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/t5/__init__.py +1 -0
  16. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/t5/modeling_t5.py +3 -37
  17. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/t5/t5_architecture.py +3 -4
  18. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  19. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +12 -22
  20. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  21. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -1
  22. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/whisper/whisper_architecture.py +20 -32
  23. optimum_rbln-0.7.4a2/src/optimum/rbln/ops/__init__.py +0 -22
  24. optimum_rbln-0.7.4a2/src/optimum/rbln/ops/attn.py +0 -223
  25. optimum_rbln-0.7.4a2/src/optimum/rbln/ops/flash_attn.py +0 -82
  26. optimum_rbln-0.7.4a2/src/optimum/rbln/ops/kv_cache_update.py +0 -60
  27. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  28. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  29. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  30. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/ISSUE_TEMPLATE/model_request.md +0 -0
  31. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/pull_request_template.md +0 -0
  32. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/scripts/auto_code_review.py +0 -0
  33. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/scripts/validate_pr_checklist.py +0 -0
  34. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/version.yaml +0 -0
  35. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/auto_code_review.yml +0 -0
  36. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/check_code_quality.yml +0 -0
  37. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/deploy.yaml +0 -0
  38. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/pr-title-check.yaml +0 -0
  39. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/pr_checklist_validator.yml +0 -0
  40. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/rbln_check_compiler.yaml +0 -0
  41. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/rbln_dispatch_pytest.yaml +0 -0
  42. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/rbln_optimum_inference_test.yaml +0 -0
  43. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/rbln_scheduled_test.yaml +0 -0
  44. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.github/workflows/rbln_trigger_on_pr.yaml +0 -0
  45. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/.gitignore +0 -0
  46. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/CODE_OF_CONDUCT.md +0 -0
  47. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/CONTRIBUTING.md +0 -0
  48. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/LICENSE +0 -0
  49. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/README.md +0 -0
  50. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/assets/rbln_logo.png +0 -0
  51. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/audio-classification/run_ast_audio_classification.py +0 -0
  52. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/depth-estimation/run_dpt.py +0 -0
  53. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/image-classification/run_image_classification.py +0 -0
  54. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/image-classification/run_vit_image_classification.py +0 -0
  55. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/image-to-text/run_llava_next_image_to_text.py +0 -0
  56. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/kandinsky2_2/run_kandinsky2_2.py +0 -0
  57. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/kandinsky2_2/run_kandinsky2_2_combined.py +0 -0
  58. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/kandinsky2_2/run_kandinsky2_2_img2img.py +0 -0
  59. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/kandinsky2_2/run_kandinsky2_2_img2img_combined.py +0 -0
  60. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/kandinsky2_2/run_kandinsky2_2_inpaint.py +0 -0
  61. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/kandinsky2_2/run_kandinsky2_2_inpaint_combined.py +0 -0
  62. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/kandinsky2_2/run_kandinsky2_2_prior_interpolate.py +0 -0
  63. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/question-answering/run_question_answering.py +0 -0
  64. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/speech-recognition/run_wav2vec2.py +0 -0
  65. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/speech-recognition/run_whisper.py +0 -0
  66. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/stable-diffusion/run_stable_diffusion.py +0 -0
  67. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/stable-diffusion/run_stable_diffusion_controlnet.py +0 -0
  68. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/stable-diffusion/run_stable_diffusion_img2img.py +0 -0
  69. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/stable-diffusion/run_stable_diffusion_img2img_controlnet.py +0 -0
  70. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/stable-diffusion/run_stable_diffusion_inpaint.py +0 -0
  71. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/stable-diffusion/run_stable_diffusion_lora.py +0 -0
  72. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/stable-diffusion/run_stable_diffusion_multicontrolnet.py +0 -0
  73. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text-classification/run_bge_m3_text_classification.py +0 -0
  74. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text-classification/run_bge_reranker_v2_m3_text_classification.py +0 -0
  75. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text-classification/run_secureBERT.py +0 -0
  76. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text-classification/run_t5_classification.py +0 -0
  77. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text-classification/run_twitter_roberta_text_classification.py +0 -0
  78. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text2text-generation/run_bart_text2text_generation.py +0 -0
  79. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text2text-generation/run_llama_peft.py +0 -0
  80. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/text2text-generation/run_llama_text2text_generation.py +0 -0
  81. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/examples/time-series-forecasting/run_time_series_forecasting.py +0 -0
  82. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/pyproject.toml +0 -0
  83. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/scripts/uv-lock.sh +0 -0
  84. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/scripts/uv-sync.sh +0 -0
  85. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/__init__.py +0 -0
  86. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/__init__.py +0 -0
  87. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/modeling_diffusers.py +0 -0
  88. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/__init__.py +0 -0
  89. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/autoencoders/__init__.py +0 -0
  90. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +0 -0
  91. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/autoencoders/vae.py +0 -0
  92. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/autoencoders/vq_model.py +0 -0
  93. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/controlnet.py +0 -0
  94. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/transformers/__init__.py +0 -0
  95. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/transformers/prior_transformer.py +0 -0
  96. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py +0 -0
  97. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/unets/__init__.py +0 -0
  98. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py +0 -0
  99. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/__init__.py +0 -0
  100. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/controlnet/__init__.py +0 -0
  101. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +0 -0
  102. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +0 -0
  103. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +0 -0
  104. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +0 -0
  105. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +0 -0
  106. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +0 -0
  107. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +0 -0
  108. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +0 -0
  109. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +0 -0
  110. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +0 -0
  111. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +0 -0
  112. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -0
  113. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +0 -0
  114. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -0
  115. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +0 -0
  116. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +0 -0
  117. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +0 -0
  118. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +0 -0
  119. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +0 -0
  120. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +0 -0
  121. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -0
  122. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -0
  123. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -0
  124. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/modeling_base.py +0 -0
  125. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/modeling_config.py +0 -0
  126. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/ops/linear.py +0 -0
  127. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/__init__.py +0 -0
  128. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/modeling_alias.py +0 -0
  129. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/modeling_generic.py +0 -0
  130. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/modeling_rope_utils.py +0 -0
  131. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/__init__.py +0 -0
  132. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/auto/__init__.py +0 -0
  133. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/auto/auto_factory.py +0 -0
  134. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/auto/modeling_auto.py +0 -0
  135. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/bart/bart_architecture.py +0 -0
  136. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/bart/modeling_bart.py +0 -0
  137. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/bert/__init__.py +0 -0
  138. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/bert/modeling_bert.py +0 -0
  139. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/clip/__init__.py +0 -0
  140. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/clip/modeling_clip.py +0 -0
  141. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +0 -0
  142. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/dpt/__init__.py +0 -0
  143. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -0
  144. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/exaone/__init__.py +0 -0
  145. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -0
  146. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/exaone/modeling_exaone.py +0 -0
  147. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/gemma/__init__.py +0 -0
  148. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/gemma/gemma_architecture.py +0 -0
  149. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/gemma/modeling_gemma.py +0 -0
  150. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/gpt2/__init__.py +0 -0
  151. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +0 -0
  152. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +0 -0
  153. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/llama/__init__.py +0 -0
  154. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/llama/llama_architecture.py +0 -0
  155. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/llama/modeling_llama.py +0 -0
  156. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/llava_next/__init__.py +0 -0
  157. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +0 -0
  158. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/midm/__init__.py +0 -0
  159. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/midm/midm_architecture.py +0 -0
  160. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/midm/modeling_midm.py +0 -0
  161. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/mistral/__init__.py +0 -0
  162. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -0
  163. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -0
  164. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/phi/__init__.py +0 -0
  165. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/phi/modeling_phi.py +0 -0
  166. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/phi/phi_architecture.py +0 -0
  167. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/qwen2/__init__.py +0 -0
  168. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -0
  169. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -0
  170. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/seq2seq/__init__.py +0 -0
  171. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +0 -0
  172. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/wav2vec2/__init__.py +0 -0
  173. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -0
  174. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/whisper/generation_whisper.py +0 -0
  175. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/xlm_roberta/__init__.py +0 -0
  176. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +0 -0
  177. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/utils/__init__.py +0 -0
  178. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/transformers/utils/rbln_quantization.py +0 -0
  179. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/__init__.py +0 -0
  180. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/decorator_utils.py +0 -0
  181. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/hub.py +0 -0
  182. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/import_utils.py +0 -0
  183. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/logging.py +0 -0
  184. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/model_utils.py +0 -0
  185. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/runtime_utils.py +0 -0
  186. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/save_utils.py +0 -0
  187. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/src/optimum/rbln/utils/submodule.py +0 -0
  188. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/__init__.py +0 -0
  189. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/psnr.py +0 -0
  190. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/requirements_sdxl.txt +0 -0
  191. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/run_stable_diffusion_xl_base.py +0 -0
  192. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/test_base.py +0 -0
  193. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/test_diffusers.py +0 -0
  194. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/test_llm.py +0 -0
  195. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/tests/test_transformers.py +0 -0
  196. {optimum_rbln-0.7.4a2 → optimum_rbln-0.7.4a3}/uv.lock +0 -0
@@ -28,7 +28,7 @@ jobs:
28
28
  workflow_id: 'rebel_dispatch_model_generation_for_vllm.yaml',
29
29
  ref: 'dev',
30
30
  inputs: {
31
- optimum_rbln_version: ${{ github.ref_name }},
31
+ optimum_rbln_version: "${{ github.ref_name }}",
32
32
  }
33
33
  })
34
34
  console.log(result)
@@ -43,7 +43,9 @@ jobs:
43
43
  if: ${{ inputs.commit_message == '' }}
44
44
  run: |
45
45
  COMMIT_MESSAGE=$(git log -1 --pretty=%B)
46
- echo "message=$COMMIT_MESSAGE" >> $GITHUB_OUTPUT
46
+ echo "message<<EOF" >> $GITHUB_OUTPUT
47
+ echo "$COMMIT_MESSAGE" >> $GITHUB_OUTPUT
48
+ echo "EOF" >> $GITHUB_OUTPUT
47
49
 
48
50
  - name: Setup uv
49
51
  uses: astral-sh/setup-uv@v3
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: optimum-rbln
3
- Version: 0.7.4a2
3
+ Version: 0.7.4a3
4
4
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
5
5
  Project-URL: Homepage, https://rebellions.ai
6
6
  Project-URL: Documentation, https://docs.rbln.ai
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.7.4a2'
20
+ __version__ = version = '0.7.4a3'
21
21
  __version_tuple__ = version_tuple = (0, 7, 4)
@@ -123,8 +123,15 @@ class RBLNModel(RBLNBaseModel):
123
123
  config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)
124
124
 
125
125
  if hasattr(model, "can_generate") and model.can_generate():
126
+ import json
127
+
126
128
  generation_config = model.generation_config
127
- generation_config.save_pretrained(save_dir_path / subfolder)
129
+ generation_config_path = save_dir_path / subfolder / "generation_config.json"
130
+
131
+ generation_config.save_pretrained(generation_config_path.parent)
132
+ local_config = json.loads(generation_config_path.read_text(encoding="utf-8"))
133
+ local_config["transformers_version"] = generation_config.transformers_version
134
+ generation_config_path.write_text(json.dumps(local_config, indent=2) + "\n", encoding="utf-8")
128
135
 
129
136
  if not isinstance(config, PretrainedConfig): # diffusers config
130
137
  config = PretrainedConfig(**config)
@@ -0,0 +1,18 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .attn import *
16
+ from .flash_attn import *
17
+ from .kv_cache_update import *
18
+ from .linear import linear
@@ -0,0 +1,287 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ from torch import Tensor
18
+
19
+
20
+ @torch.library.custom_op(
21
+ "rbln_custom_ops::paged_attn_decode",
22
+ mutates_args=(["kcache", "vcache"]),
23
+ )
24
+ def paged_attn_decode(
25
+ q: Tensor,
26
+ k: Tensor,
27
+ v: Tensor,
28
+ mask: Tensor,
29
+ kcache: Tensor,
30
+ vcache: Tensor,
31
+ seq: Tensor,
32
+ scale: Tensor,
33
+ block_table: Tensor,
34
+ block_size: int,
35
+ ) -> Tensor:
36
+ return torch.empty_like(q)
37
+
38
+
39
+ @paged_attn_decode.register_fake
40
+ def paged_attn_decode_fake(
41
+ q: Tensor,
42
+ k: Tensor,
43
+ v: Tensor,
44
+ mask: Tensor,
45
+ kcache: Tensor,
46
+ vcache: Tensor,
47
+ seq: Tensor,
48
+ scale: Tensor,
49
+ block_table: Tensor,
50
+ block_size: int,
51
+ ) -> Tensor:
52
+ return torch.empty_like(q)
53
+
54
+
55
+ @torch.library.custom_op(
56
+ "rbln_custom_ops::paged_attn_prefill",
57
+ mutates_args=(["kcache", "vcache"]),
58
+ )
59
+ def paged_attn_prefill(
60
+ q: Tensor,
61
+ k: Tensor,
62
+ v: Tensor,
63
+ mask: Tensor,
64
+ kcache: Tensor,
65
+ vcache: Tensor,
66
+ seq: Tensor,
67
+ scale: Tensor,
68
+ block_table: Tensor,
69
+ block_size: int,
70
+ ) -> Tensor:
71
+ """Defines the computation pattern for prefill phase attention with KV cache updates.
72
+
73
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
74
+ a single optimized NPU operation. It is NOT meant for CPU execution.
75
+
76
+ Key differences from decode pattern:
77
+ - Handles prefill phase with multiple input tokens
78
+ - Takes explicit batch index for continuous batching
79
+
80
+ Expected tensor shapes:
81
+ - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
82
+ - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
83
+ - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
84
+ - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
85
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
86
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
87
+ - seq: [1, 1] - Starting sequence position
88
+ - scale: [] - Attention scale factor
89
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
90
+ - block_size: [] - Number of tokens per block
91
+
92
+ Returns:
93
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
94
+ """
95
+ return torch.empty_like(q)
96
+
97
+
98
+ @paged_attn_prefill.register_fake
99
+ def paged_attn_prefill_fake(
100
+ q: Tensor,
101
+ k: Tensor,
102
+ v: Tensor,
103
+ mask: Tensor,
104
+ kcache: Tensor,
105
+ vcache: Tensor,
106
+ seq: Tensor,
107
+ scale: Tensor,
108
+ block_table: Tensor,
109
+ block_size: int,
110
+ ) -> Tensor:
111
+ return torch.empty_like(q)
112
+
113
+
114
+ @torch.library.custom_op(
115
+ "rbln_custom_ops::paged_causal_attn_decode",
116
+ mutates_args=(["kcache", "vcache"]),
117
+ )
118
+ def paged_causal_attn_decode(
119
+ q: Tensor,
120
+ k: Tensor,
121
+ v: Tensor,
122
+ kcache: Tensor,
123
+ vcache: Tensor,
124
+ seq: Tensor,
125
+ scale: Tensor,
126
+ block_table: Tensor,
127
+ block_size: int,
128
+ ) -> Tensor:
129
+ """Defines the computation pattern for fused attention with KV cache updates.
130
+
131
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
132
+ a single optimized NPU operation. It is NOT meant for CPU execution.
133
+
134
+ Pattern components that compiler fuses into a single op:
135
+ 1. KV cache updates with new key/value states
136
+ 2. Scaled dot-product attention computation
137
+ 3. Causal masked softmax operation
138
+ 4. Final attention output computation
139
+
140
+ Expected tensor shapes:
141
+ - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
142
+ - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
143
+ - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
144
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
145
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
146
+ - seq: [1, 1] - Starting sequence position
147
+ - scale: [] - Attention scale factor
148
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
149
+ - block_size: [] - Number of tokens per block
150
+
151
+ Returns:
152
+ Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
153
+ """
154
+ return torch.empty_like(q)
155
+
156
+
157
+ @paged_causal_attn_decode.register_fake
158
+ def paged_causal_attn_decode_fake(
159
+ q: Tensor,
160
+ k: Tensor,
161
+ v: Tensor,
162
+ kcache: Tensor,
163
+ vcache: Tensor,
164
+ seq: Tensor,
165
+ scale: Tensor,
166
+ block_table: Tensor,
167
+ block_size: int,
168
+ ) -> Tensor:
169
+ return torch.empty_like(q)
170
+
171
+
172
+ @torch.library.custom_op(
173
+ "rbln_custom_ops::paged_causal_attn_prefill",
174
+ mutates_args=(["kcache", "vcache"]),
175
+ )
176
+ def paged_causal_attn_prefill(
177
+ q: Tensor,
178
+ k: Tensor,
179
+ v: Tensor,
180
+ kcache: Tensor,
181
+ vcache: Tensor,
182
+ seq: Tensor,
183
+ scale: Tensor,
184
+ block_table: Tensor,
185
+ block_size: int,
186
+ ) -> Tensor:
187
+ """Defines the computation pattern for prefill phase attention with KV cache updates.
188
+
189
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
190
+ a single optimized NPU operation. It is NOT meant for CPU execution.
191
+
192
+ Key differences from decode pattern:
193
+ - Handles prefill phase with multiple input tokens
194
+ - Takes explicit batch index for continuous batching
195
+
196
+ Expected tensor shapes:
197
+ - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
198
+ - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
199
+ - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
200
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
201
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
202
+ - batch: [1] - Batch index for cache access
203
+ - seq: [1, 1] - Starting sequence position
204
+ - scale: [] - Attention scale factor
205
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
206
+ - block_size: [] - Number of tokens per block
207
+
208
+ Returns:
209
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
210
+ """
211
+ return torch.empty_like(q)
212
+
213
+
214
+ @paged_causal_attn_prefill.register_fake
215
+ def paged_causal_attn_prefill_fake(
216
+ q: Tensor,
217
+ k: Tensor,
218
+ v: Tensor,
219
+ kcache: Tensor,
220
+ vcache: Tensor,
221
+ seq: Tensor,
222
+ scale: Tensor,
223
+ block_table: Tensor,
224
+ block_size: int,
225
+ ) -> Tensor:
226
+ return torch.empty_like(q)
227
+
228
+
229
+ @torch.library.custom_op(
230
+ "rbln_custom_ops::paged_add_softmax_attn_decode",
231
+ mutates_args=(["kcache", "vcache"]),
232
+ )
233
+ def paged_add_softmax_attn_decode(
234
+ q: Tensor,
235
+ k: Tensor,
236
+ v: Tensor,
237
+ mask: Tensor,
238
+ kcache: Tensor,
239
+ vcache: Tensor,
240
+ seq: Tensor,
241
+ scale: Tensor,
242
+ block_table: Tensor,
243
+ block_size: int,
244
+ ) -> Tensor:
245
+ """Defines the computation pattern for fused attention with KV cache updates.
246
+
247
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
248
+ a single optimized NPU operation. It is NOT meant for CPU execution.
249
+
250
+ Pattern components that compiler fuses into a single op:
251
+ 1. KV cache updates with new key/value states
252
+ 2. Scaled dot-product attention computation
253
+ 3. add-softmax operation
254
+ 4. Final attention output computation
255
+
256
+ Expected tensor shapes:
257
+ - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
258
+ - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
259
+ - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
260
+ - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
261
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
262
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
263
+ - seq: [1] - Current sequence position
264
+ - scale: [] - Attention scale factor
265
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
266
+ - block_size: [] - Number of tokens per block
267
+
268
+ Returns:
269
+ Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
270
+ """
271
+ return torch.empty_like(q)
272
+
273
+
274
+ @paged_add_softmax_attn_decode.register_fake
275
+ def paged_add_softmax_attn_decode_fake(
276
+ q: Tensor,
277
+ k: Tensor,
278
+ v: Tensor,
279
+ mask: Tensor,
280
+ kcache: Tensor,
281
+ vcache: Tensor,
282
+ seq: Tensor,
283
+ scale: Tensor,
284
+ block_table: Tensor,
285
+ block_size: int,
286
+ ) -> Tensor:
287
+ return torch.empty_like(q)
@@ -0,0 +1,176 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import Tensor
17
+
18
+
19
+ @torch.library.custom_op(
20
+ "rbln_custom_ops::paged_flash_attn_decode",
21
+ mutates_args=(["kcache", "vcache"]),
22
+ )
23
+ def paged_flash_attn_decode(
24
+ q: Tensor,
25
+ k: Tensor,
26
+ v: Tensor,
27
+ mask: Tensor,
28
+ kcache: Tensor,
29
+ vcache: Tensor,
30
+ seq: Tensor,
31
+ scale: Tensor,
32
+ block_table: Tensor,
33
+ block_size: int,
34
+ partition: int,
35
+ ) -> Tensor:
36
+ """Defines the computation pattern for fused flash attention with KV cache for decoding.
37
+
38
+ Returns a tensor with the same shape as q.
39
+ """
40
+ return torch.empty_like(q)
41
+
42
+
43
@paged_flash_attn_decode.register_fake
def paged_flash_attn_decode_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    partition: int,
) -> Tensor:
    """Fake (meta) kernel for ``paged_flash_attn_decode``, used while tracing.

    Only output metadata is produced: shape, dtype, device and strides are copied
    from ``q``. No other argument is inspected.
    """
    placeholder = torch.empty_like(q, memory_format=torch.preserve_format)
    return placeholder
58
+
59
+
60
+ @torch.library.custom_op(
61
+ "rbln_custom_ops::paged_flash_attn_prefill",
62
+ mutates_args=(["kcache", "vcache"]),
63
+ )
64
+ def paged_flash_attn_prefill(
65
+ q: Tensor,
66
+ k: Tensor,
67
+ v: Tensor,
68
+ mask: Tensor,
69
+ kcache: Tensor,
70
+ vcache: Tensor,
71
+ seq: Tensor,
72
+ scale: Tensor,
73
+ block_table: Tensor,
74
+ block_size: int,
75
+ partition: int,
76
+ ) -> Tensor:
77
+ """Defines the computation pattern for fused flash attention with KV cache for prefill.
78
+
79
+ Returns a tensor with the same shape as q.
80
+ """
81
+ return torch.empty_like(q)
82
+
83
+
84
@paged_flash_attn_prefill.register_fake
def paged_flash_attn_prefill_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    partition: int,
) -> Tensor:
    """Fake (meta) kernel for ``paged_flash_attn_prefill``, used while tracing.

    Only output metadata is produced: shape, dtype, device and strides are copied
    from ``q``. No other argument is inspected.
    """
    placeholder = torch.empty_like(q, memory_format=torch.preserve_format)
    return placeholder
99
+
100
+
101
+ @torch.library.custom_op(
102
+ "rbln_custom_ops::paged_flash_causal_attn_decode",
103
+ mutates_args=(["kcache", "vcache"]),
104
+ )
105
+ def paged_flash_causal_attn_decode(
106
+ q: Tensor,
107
+ k: Tensor,
108
+ v: Tensor,
109
+ kcache: Tensor,
110
+ vcache: Tensor,
111
+ seq: Tensor,
112
+ scale: Tensor,
113
+ block_table: Tensor,
114
+ block_size: int,
115
+ partition: int,
116
+ ) -> Tensor:
117
+ """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
118
+
119
+ Returns a tensor with the same shape as q.
120
+ """
121
+ return torch.empty_like(q)
122
+
123
+
124
@paged_flash_causal_attn_decode.register_fake
def paged_flash_causal_attn_decode_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    partition: int,
) -> Tensor:
    """Fake (meta) kernel for ``paged_flash_causal_attn_decode``, used while tracing.

    Only output metadata is produced: shape, dtype, device and strides are copied
    from ``q``. No other argument is inspected.
    """
    placeholder = torch.empty_like(q, memory_format=torch.preserve_format)
    return placeholder
138
+
139
+
140
+ @torch.library.custom_op(
141
+ "rbln_custom_ops::paged_flash_causal_attn_prefill",
142
+ mutates_args=(["kcache", "vcache"]),
143
+ )
144
+ def paged_flash_causal_attn_prefill(
145
+ q: Tensor,
146
+ k: Tensor,
147
+ v: Tensor,
148
+ kcache: Tensor,
149
+ vcache: Tensor,
150
+ seq: Tensor,
151
+ scale: Tensor,
152
+ block_table: Tensor,
153
+ block_size: int,
154
+ partition: int,
155
+ ) -> Tensor:
156
+ """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
157
+
158
+ Returns a tensor with the same shape as q.
159
+ """
160
+ return torch.empty_like(q)
161
+
162
+
163
@paged_flash_causal_attn_prefill.register_fake
def paged_flash_causal_attn_prefill_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    partition: int,
) -> Tensor:
    """Fake (meta) kernel for ``paged_flash_causal_attn_prefill``, used while tracing.

    Only output metadata is produced: shape, dtype, device and strides are copied
    from ``q``. No other argument is inspected.
    """
    placeholder = torch.empty_like(q, memory_format=torch.preserve_format)
    return placeholder
@@ -0,0 +1,24 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import Tensor
17
+
18
+
19
+ @torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
20
+ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
21
+ # Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
22
+ # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
23
+ # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
24
+ return torch.empty_like(cache)
@@ -12,4 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from ....ops import paged_attn_decode, paged_causal_attn_decode
15
16
  from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
@@ -12,4 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from ....ops import (
16
+ paged_attn_decode,
17
+ paged_attn_prefill,
18
+ paged_causal_attn_decode,
19
+ paged_causal_attn_prefill,
20
+ paged_flash_attn_decode,
21
+ paged_flash_attn_prefill,
22
+ paged_flash_causal_attn_decode,
23
+ paged_flash_causal_attn_prefill,
24
+ )
15
25
  from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM