optimum-rbln 0.7.2rc2__tar.gz → 0.7.3a0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178)
  1. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/scripts/auto_code_review.py +0 -51
  2. optimum_rbln-0.7.3a0/.github/workflows/auto_code_review.yml +72 -0
  3. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/rbln_trigger_on_pr.yaml +1 -1
  4. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/PKG-INFO +1 -1
  5. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/__version__.py +2 -2
  6. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/modeling_diffusers.py +4 -6
  7. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/modeling.py +1 -1
  8. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/modeling_base.py +15 -3
  9. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/ops/__init__.py +6 -2
  10. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/ops/attn.py +95 -7
  11. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/ops/flash_attn.py +43 -6
  12. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/modeling_generic.py +3 -3
  13. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/bart/bart_architecture.py +1 -1
  14. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/bart/modeling_bart.py +1 -1
  15. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
  16. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +186 -78
  17. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +55 -17
  18. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -3
  19. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -3
  20. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -3
  21. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/midm/midm_architecture.py +3 -3
  22. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/phi/phi_architecture.py +2 -2
  23. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  24. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/t5/modeling_t5.py +1 -1
  25. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -1
  26. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/import_utils.py +7 -0
  27. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/test_base.py +26 -31
  28. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/test_llm.py +16 -13
  29. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/test_transformers.py +3 -3
  30. optimum_rbln-0.7.2rc2/.github/workflows/auto_code_review.yml +0 -33
  31. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  32. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  33. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  34. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/ISSUE_TEMPLATE/model_request.md +0 -0
  35. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/pull_request_template.md +0 -0
  36. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/scripts/validate_pr_checklist.py +0 -0
  37. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/check_code_quality.yml +0 -0
  38. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/deploy-on-tag.yaml +0 -0
  39. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/deploy.yaml +0 -0
  40. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/pr-title-check.yaml +0 -0
  41. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/pr_checklist_validator.yml +0 -0
  42. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/rbln_dispatch_pytest.yaml +0 -0
  43. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/rbln_optimum_inference_test.yaml +0 -0
  44. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.github/workflows/rbln_optimum_pytest.yaml +0 -0
  45. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/.gitignore +0 -0
  46. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/CODE_OF_CONDUCT.md +0 -0
  47. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/CONTRIBUTING.md +0 -0
  48. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/LICENSE +0 -0
  49. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/README.md +0 -0
  50. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/assets/rbln_logo.png +0 -0
  51. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/audio-classification/run_ast_audio_classification.py +0 -0
  52. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/depth-estimation/run_dpt.py +0 -0
  53. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/image-classification/run_image_classification.py +0 -0
  54. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/image-classification/run_vit_image_classification.py +0 -0
  55. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/image-to-text/run_llava_next_image_to_text.py +0 -0
  56. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/kandinsky2_2/run_kandinsky2_2_inpaint.py +0 -0
  57. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/kandinsky2_2/run_kandinsky2_2_inpaint_combined.py +0 -0
  58. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/question-answering/run_question_answering.py +0 -0
  59. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/speech-recognition/run_wav2vec2.py +0 -0
  60. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/speech-recognition/run_whisper.py +0 -0
  61. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/stable-diffusion/run_stable_diffusion.py +0 -0
  62. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/stable-diffusion/run_stable_diffusion_controlnet.py +0 -0
  63. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/stable-diffusion/run_stable_diffusion_img2img.py +0 -0
  64. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/stable-diffusion/run_stable_diffusion_img2img_controlnet.py +0 -0
  65. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/stable-diffusion/run_stable_diffusion_inpaint.py +0 -0
  66. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/stable-diffusion/run_stable_diffusion_lora.py +0 -0
  67. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/stable-diffusion/run_stable_diffusion_multicontrolnet.py +0 -0
  68. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text-classification/run_bge_m3_text_classification.py +0 -0
  69. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text-classification/run_bge_reranker_v2_m3_text_classification.py +0 -0
  70. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text-classification/run_secureBERT.py +0 -0
  71. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text-classification/run_t5_classification.py +0 -0
  72. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text-classification/run_twitter_roberta_text_classification.py +0 -0
  73. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text2text-generation/run_bart_text2text_generation.py +0 -0
  74. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text2text-generation/run_llama_peft.py +0 -0
  75. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/examples/text2text-generation/run_llama_text2text_generation.py +0 -0
  76. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/pyproject.toml +0 -0
  77. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/scripts/uv-lock.sh +0 -0
  78. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/scripts/uv-sync.sh +0 -0
  79. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/__init__.py +0 -0
  80. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/__init__.py +0 -0
  81. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/__init__.py +0 -0
  82. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/autoencoders/__init__.py +0 -0
  83. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +0 -0
  84. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/autoencoders/vae.py +0 -0
  85. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/autoencoders/vq_model.py +0 -0
  86. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/controlnet.py +0 -0
  87. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/transformers/__init__.py +0 -0
  88. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/transformers/prior_transformer.py +0 -0
  89. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py +0 -0
  90. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/unets/__init__.py +0 -0
  91. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py +0 -0
  92. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/__init__.py +0 -0
  93. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/controlnet/__init__.py +0 -0
  94. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +0 -0
  95. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +0 -0
  96. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +0 -0
  97. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +0 -0
  98. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +0 -0
  99. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +0 -0
  100. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +0 -0
  101. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +0 -0
  102. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +0 -0
  103. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -0
  104. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +0 -0
  105. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -0
  106. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +0 -0
  107. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +0 -0
  108. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +0 -0
  109. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +0 -0
  110. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +0 -0
  111. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +0 -0
  112. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -0
  113. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -0
  114. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -0
  115. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/modeling_config.py +0 -0
  116. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/ops/kv_cache_update.py +0 -0
  117. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/__init__.py +0 -0
  118. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/modeling_alias.py +0 -0
  119. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/modeling_rope_utils.py +0 -0
  120. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/__init__.py +0 -0
  121. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/auto/__init__.py +0 -0
  122. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/auto/auto_factory.py +0 -0
  123. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/auto/modeling_auto.py +0 -0
  124. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/bart/__init__.py +0 -0
  125. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/bert/__init__.py +0 -0
  126. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/clip/__init__.py +0 -0
  127. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/clip/modeling_clip.py +0 -0
  128. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/decoderonly/__init__.py +0 -0
  129. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/dpt/__init__.py +0 -0
  130. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -0
  131. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/exaone/__init__.py +0 -0
  132. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/exaone/modeling_exaone.py +0 -0
  133. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/gemma/__init__.py +0 -0
  134. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/gemma/modeling_gemma.py +0 -0
  135. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/gpt2/__init__.py +0 -0
  136. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +0 -0
  137. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/llama/__init__.py +0 -0
  138. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/llama/llama_architecture.py +0 -0
  139. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/llama/modeling_llama.py +0 -0
  140. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/llava_next/__init__.py +0 -0
  141. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +0 -0
  142. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/midm/__init__.py +0 -0
  143. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/midm/modeling_midm.py +0 -0
  144. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/mistral/__init__.py +0 -0
  145. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -0
  146. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -0
  147. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/phi/__init__.py +0 -0
  148. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/phi/modeling_phi.py +0 -0
  149. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/qwen2/__init__.py +0 -0
  150. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -0
  151. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -0
  152. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/seq2seq/__init__.py +0 -0
  153. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +0 -0
  154. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/t5/__init__.py +0 -0
  155. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/t5/t5_architecture.py +0 -0
  156. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/wav2vec2/__init__.py +0 -0
  157. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -0
  158. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/whisper/__init__.py +0 -0
  159. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/whisper/generation_whisper.py +0 -0
  160. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -0
  161. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -0
  162. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/models/xlm_roberta/__init__.py +0 -0
  163. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/utils/__init__.py +0 -0
  164. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/transformers/utils/rbln_quantization.py +0 -0
  165. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/__init__.py +0 -0
  166. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/decorator_utils.py +0 -0
  167. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/hub.py +0 -0
  168. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/logging.py +0 -0
  169. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/model_utils.py +0 -0
  170. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/runtime_utils.py +0 -0
  171. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/save_utils.py +0 -0
  172. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/src/optimum/rbln/utils/submodule.py +0 -0
  173. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/__init__.py +0 -0
  174. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/psnr.py +0 -0
  175. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/requirements_sdxl.txt +0 -0
  176. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/run_stable_diffusion_xl_base.py +0 -0
  177. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/tests/test_diffusers.py +0 -0
  178. {optimum_rbln-0.7.2rc2 → optimum_rbln-0.7.3a0}/uv.lock +0 -0
.github/scripts/auto_code_review.py
@@ -97,50 +97,6 @@ def remove_file_from_diff(diff_content, file_to_remove):
     return "\n".join(result)
 
 
-def skip_bot(pr):
-    global force_review
-    """
-    Review if
-    1. last commit messages starts with "[autoreview]"
-    2. last comment contains "/autoreview"
-    """
-
-    # Check commit message
-    commits = list(pr.get_commits())
-    if len(commits) == 0:
-        return True
-
-    last_commit = commits[-1]
-    try:
-        commit_message = last_commit.raw_data["commit"]["message"]
-    except KeyError:
-        commit_message = ""
-
-    if commit_message.startswith("[autoreview]"):
-        return False
-
-    # Check the last comment
-    comments = list(pr.get_issue_comments())
-    if len(comments) == 0:
-        return True
-
-    last = comments[-1]
-    if last.user.login.find("github-actions") != -1:
-        return True
-
-    if last.body.find("/autoreview") == -1:
-        return True
-
-    if last.reactions["heart"] > 0:
-        return True
-
-    if last.body.find("force") != -1:
-        force_review = True
-
-    last.create_reaction("heart")
-    return False
-
-
 def main():
     github_token = os.getenv("GITHUB_TOKEN")
     pr_number = os.getenv("PR_NUMBER")
@@ -155,13 +111,6 @@ def main():
     repo = g.get_repo(os.getenv("GITHUB_REPOSITORY"))
     pr = repo.get_pull(int(pr_number))
 
-    if skip_bot(pr):
-        print(
-            "To invoke review, Write '/autoreview' and re-run github actions,"
-            " or start the commit message with '[autoreview]'. "
-        )
-        sys.exit(0)
-
     # Get PR diff
     diff = get_pr_diff()
     diff = remove_file_from_diff(diff, "uv.lock")
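
With skip_bot gone, the script no longer decides whether to run; the new workflow below gates it, and the script executes unconditionally from the environment variables visible in main() above. A minimal sketch of driving it by hand (token, repository, and PR number are placeholders):

    import os
    import subprocess

    # Illustrative values only; the real workflow injects these from secrets/outputs.
    os.environ["GITHUB_TOKEN"] = "<token>"
    os.environ["GITHUB_REPOSITORY"] = "owner/repo"
    os.environ["PR_NUMBER"] = "123"
    subprocess.run(["python", ".github/scripts/auto_code_review.py"], check=True)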

.github/workflows/auto_code_review.yml
@@ -0,0 +1,72 @@
+name: Auto Code Review
+
+on:
+  pull_request:
+  issue_comment:
+    types: [created]
+  push:
+    branches:
+      - '**'
+
+env:
+  GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+  GOOGLE_MODEL_ID: ${{ vars.GOOGLE_MODEL_ID }}
+
+jobs:
+  auto-review:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - name: Check if review should run
+        id: check
+        run: |
+          PR_NUMBER=""
+          SHOULD_RUN="false"
+
+          # For push events, check commit message
+          if [[ "${{ github.event_name }}" == "push" ]]; then
+            if [[ "${{ contains(github.event.head_commit.message, '[autoreview]') }}" == "true" ]]; then
+              SHOULD_RUN="true"
+              # Use GitHub CLI to find PR associated with this commit
+              PR_NUMBER=$(gh pr list --head ${{ github.ref_name }} --json number --jq '.[0].number')
+            fi
+
+          # For PR events
+          elif [[ "${{ github.event_name }}" == "pull_request" ]]; then
+            PR_NUMBER="${{ github.event.pull_request.number }}"
+
+          # For comment events, check if it's "/autoreview"
+          elif [[ "${{ github.event_name }}" == "issue_comment" ]]; then
+            if [[ "${{ github.event.issue.pull_request != null }}" == "true" && "${{ contains(github.event.comment.body, '/autoreview') }}" == "true" ]]; then
+              SHOULD_RUN="true"
+              PR_NUMBER="${{ github.event.issue.number }}"
+            fi
+          fi
+
+          echo "should_run=$SHOULD_RUN" >> $GITHUB_OUTPUT
+          echo "pr_number=$PR_NUMBER" >> $GITHUB_OUTPUT
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Set up Python
+        if: steps.check.outputs.should_run == 'true' && steps.check.outputs.pr_number != ''
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.x'
+
+      - name: Install dependencies
+        if: steps.check.outputs.should_run == 'true' && steps.check.outputs.pr_number != ''
+        run: |
+          python -m pip install --upgrade pip
+          pip install google-generativeai PyGithub
+
+      - name: Run Auto Code Review
+        if: steps.check.outputs.should_run == 'true' && steps.check.outputs.pr_number != ''
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          PR_NUMBER: ${{ steps.check.outputs.pr_number }}
+        run: python .github/scripts/auto_code_review.py
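
With pull_request and push handled automatically, an on-demand review on an open PR now comes from posting "/autoreview" as a comment, which takes the issue_comment branch of the check step above. A sketch using PyGithub (already installed by the review job; the repository name and PR number are placeholders):

    import os
    from github import Github  # PyGithub

    gh = Github(os.environ["GITHUB_TOKEN"])
    pr = gh.get_repo("owner/repo").get_pull(123)       # hypothetical repo/PR
    pr.create_issue_comment("/autoreview")             # fires the issue_comment trigger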

.github/workflows/rbln_trigger_on_pr.yaml
@@ -9,7 +9,7 @@ env:
   REBEL_PYPI_ENDPOINT: ${{ vars.REBEL_PYPI_INTERNAL_ENDPOINT }}
   REBEL_PYPI_USERNAME: ${{ secrets.REBEL_PYPI_USERNAME }}
   REBEL_PYPI_PASSWORD: ${{ secrets.REBEL_PYPI_PASSWORD }}
-  REBEL_COMPILER_VERSION: 0.7.2.dev213+g4f75bbe9
+  REBEL_COMPILER_VERSION: 0.7.3.dev100+g3fd6ed0a
 
 jobs:
   check-rebel-compiler-version:

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.2rc2
+Version: 0.7.3a0
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai

src/optimum/rbln/__version__.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.2rc2'
-__version_tuple__ = version_tuple = (0, 7, 2)
+__version__ = version = '0.7.3a0'
+__version_tuple__ = version_tuple = (0, 7, 3)

src/optimum/rbln/diffusers/modeling_diffusers.py
@@ -71,13 +71,11 @@ class RBLNDiffusionMixin:
     _prefix = {}
 
     @classmethod
-    @property
-    def img2img_pipeline(cls):
+    def is_img2img_pipeline(cls):
         return "Img2Img" in cls.__name__
 
     @classmethod
-    @property
-    def inpaint_pipeline(cls):
+    def is_inpaint_pipeline(cls):
         return "Inpaint" in cls.__name__
 
     @classmethod
@@ -100,8 +98,8 @@ class RBLNDiffusionMixin:
         submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
         submodule_config.update(
             {
-                "img2img_pipeline": cls.img2img_pipeline,
-                "inpaint_pipeline": cls.is_img2img_pipeline() if False else cls.is_inpaint_pipeline(),
+                "img2img_pipeline": cls.is_img2img_pipeline(),
+                "inpaint_pipeline": cls.is_inpaint_pipeline(),
             }
         )
         submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
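
This change, like the hf_class one below, retires the @classmethod/@property stacking: chaining those two decorators was deprecated in Python 3.11 and removed in 3.13, so the accessors become plain classmethods invoked with (). A minimal sketch of the resulting call pattern (class names hypothetical):

    class RBLNPipe:
        @classmethod
        def is_img2img_pipeline(cls) -> bool:
            return "Img2Img" in cls.__name__

    class RBLNStableDiffusionImg2ImgPipe(RBLNPipe):  # hypothetical subclass
        pass

    assert RBLNStableDiffusionImg2ImgPipe.is_img2img_pipeline()
    assert not RBLNPipe.is_img2img_pipeline()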

src/optimum/rbln/modeling.py
@@ -196,7 +196,7 @@ class RBLNModel(RBLNBaseModel):
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
-        return cls.hf_class.from_pretrained(
+        return cls.get_hf_class().from_pretrained(
             model_id,
             subfolder=subfolder,
             revision=revision,

src/optimum/rbln/modeling_base.py
@@ -389,8 +389,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return rbln_config
 
     @classmethod
-    @property
-    def hf_class(cls):
+    def get_hf_class(cls):
         """
         Lazily loads and caches the corresponding Hugging Face model class.
         Removes 'RBLN' prefix from the class name to get the original class name
@@ -416,7 +415,20 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return self.forward(*args, **kwargs)
 
     def __repr__(self):
-        return repr(self.model) + repr(self.rbln_submodules)
+        has_submodules = len(self.rbln_submodules) > 0
+        repr_str: str = f"<{self.__class__.__name__}>\n"
+        repr_str += f"- Total {len(self.model)} Runtimes"
+        repr_str += f" and {len(self.rbln_submodules)} Submodules\n" if has_submodules else "\n"
+        repr_str += "[Runtimes]\n"
+        repr_str += "\n".join([repr(model) for model in self.model])
+        repr_str += "\n"
+
+        if has_submodules > 0:
+            for i, submodule in enumerate(self.rbln_submodules):
+                repr_str += f"[Submodules {i} : {self._rbln_submodules[i]['name']}]\n"
+                repr_str += repr(submodule) + "\n"
+
+        return repr_str
 
     def __post_init__(self, **kwargs):
         pass
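
The rewritten __repr__ replaces raw concatenation of runtime reprs with a structured summary. A standalone mock showing the shape of the output (not the real classes; purely illustrative):

    # Prints something along the lines of:
    # <FakeRBLNModel>
    # - Total 2 Runtimes
    # [Runtimes]
    # Runtime(device=0)
    # Runtime(device=1)
    class FakeRuntime:
        def __init__(self, device):
            self.device = device

        def __repr__(self):
            return f"Runtime(device={self.device})"

    class FakeRBLNModel:
        def __init__(self):
            self.model = [FakeRuntime(0), FakeRuntime(1)]

        def __repr__(self):
            header = f"<{type(self).__name__}>\n- Total {len(self.model)} Runtimes\n[Runtimes]\n"
            return header + "\n".join(repr(m) for m in self.model)

    print(FakeRBLNModel())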

src/optimum/rbln/ops/__init__.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .attn import register_rbln_custom_attention, register_rbln_custom_attention_add_softmax
-from .flash_attn import register_rbln_custom_flash_attention
+from .attn import (
+    register_rbln_custom_attention_add_softmax,
+    register_rbln_custom_causal_masked_attention,
+    register_rbln_custom_masked_attention,
+)
+from .flash_attn import register_rbln_custom_flash_causal_masked_attention, register_rbln_custom_flash_masked_attention
 from .kv_cache_update import register_rbln_custom_cache_update
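
For downstream code, the import surface changes as follows: the masked names replace the old unqualified ones, and the causal variants, which drop the explicit mask tensor from the op signature, are new in this release. A migration sketch:

    # Old (0.7.2rc2)                        -> New (0.7.3a0)
    # register_rbln_custom_attention       -> register_rbln_custom_masked_attention
    # register_rbln_custom_flash_attention -> register_rbln_custom_flash_masked_attention
    from optimum.rbln.ops import (
        register_rbln_custom_causal_masked_attention,  # new: causal mask handled in-op
        register_rbln_custom_masked_attention,
    )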

src/optimum/rbln/ops/attn.py
@@ -25,13 +25,13 @@ else:
 
 
 @lru_cache
-def register_rbln_custom_attention():
+def register_rbln_custom_masked_attention():
     torch.library.define(
-        "rbln_custom_ops::attn_decode",
+        "rbln_custom_ops::masked_attn_decode",
         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
     )
 
-    @torch.library.impl("rbln_custom_ops::attn_decode", "cpu")
+    @torch.library.impl("rbln_custom_ops::masked_attn_decode", "cpu")
     def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
         """Defines the computation pattern for fused attention with KV cache updates.
 
@@ -66,7 +66,7 @@ def register_rbln_custom_attention():
             torch.empty(*vcache.shape, device=vcache.device),
         )
 
-    @register_fake("rbln_custom_ops::attn_decode")
+    @register_fake("rbln_custom_ops::masked_attn_decode")
     def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
         return (
             q,
@@ -75,11 +75,11 @@ def register_rbln_custom_attention():
         )
 
     torch.library.define(
-        "rbln_custom_ops::attn_prefill",
+        "rbln_custom_ops::masked_attn_prefill",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
     )
 
-    @torch.library.impl("rbln_custom_ops::attn_prefill", "cpu")
+    @torch.library.impl("rbln_custom_ops::masked_attn_prefill", "cpu")
     def attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
         """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -109,11 +109,99 @@ def register_rbln_custom_attention():
         """
         return q, kcache, vcache
 
-    @register_fake("rbln_custom_ops::attn_prefill")
+    @register_fake("rbln_custom_ops::masked_attn_prefill")
     def attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
         return q, kcache, vcache
 
 
+@lru_cache
+def register_rbln_custom_causal_masked_attention():
+    torch.library.define(
+        "rbln_custom_ops::causal_masked_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::causal_masked_attn_decode", "cpu")
+    def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale):
+        """Defines the computation pattern for fused attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Pattern components that compiler fuses into a single op:
+        1. KV cache updates with new key/value states
+        2. Scaled dot-product attention computation
+        3. Causal masked softmax operation
+        4. Final attention output computation
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+        - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - seq: [1] - Current sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+            - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
+            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
+            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+        """
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::causal_masked_attn_decode")
+    def attn_decode_abstract(q, k, v, kcache, vcache, seq, partition):
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    torch.library.define(
+        "rbln_custom_ops::causal_masked_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::causal_masked_attn_prefill", "cpu")
+    def attn_prefill_cpu(q, k, v, kcache, vcache, batch, seq, scale):
+        """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Key differences from decode pattern:
+        - Handles prefill phase with multiple input tokens
+        - Takes explicit batch index for continuous batching
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+        - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - batch: [1] - Batch index for cache access
+        - seq: [1] - Starting sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+            - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+            - empty_kcache: Same shape as input kcache - Placeholder for compiler
+            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+        """
+        return q, kcache, vcache
+
+    @register_fake("rbln_custom_ops::causal_masked_attn_prefill")
+    def attn_prefill_abstract(q, k, v, kcache, vcache, batch, seq, partition):
+        return q, kcache, vcache
+
+
 @lru_cache
 def register_rbln_custom_attention_add_softmax():
     torch.library.define(
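
Once registered, the renamed decode op can be exercised end to end; the CPU body only returns pattern placeholders, as its docstring stresses. A minimal sketch, assuming the torch.library registration above succeeds on your torch build (all sizes illustrative, shapes per the docstring):

    import torch
    from optimum.rbln.ops import register_rbln_custom_causal_masked_attention

    register_rbln_custom_causal_masked_attention()

    n_heads, n_groups, head_dim, max_seq_len = 4, 2, 64, 128
    q = torch.randn(1, n_heads, n_groups, 1, head_dim)   # single decode token
    k = torch.randn(1, n_heads, 1, 1, head_dim)
    v = torch.randn(1, n_heads, 1, 1, head_dim)
    kcache = torch.zeros(1, n_heads, 1, max_seq_len, head_dim)
    vcache = torch.zeros(1, n_heads, 1, max_seq_len, head_dim)
    seq = torch.tensor([0])                              # current sequence position
    scale = torch.tensor(head_dim**-0.5)

    # On CPU this dispatches to the stub above and returns placeholder tensors.
    attn_out, _, _ = torch.ops.rbln_custom_ops.causal_masked_attn_decode(
        q, k, v, kcache, vcache, seq, scale
    )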

src/optimum/rbln/ops/flash_attn.py
@@ -25,13 +25,13 @@ else:
 
 
 @lru_cache
-def register_rbln_custom_flash_attention():
+def register_rbln_custom_flash_masked_attention():
     torch.library.define(
-        "rbln_custom_ops::flash_attn_decode",
+        "rbln_custom_ops::flash_masked_attn_decode",
         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
     )
 
-    @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
+    @torch.library.impl("rbln_custom_ops::flash_masked_attn_decode", "cpu")
     def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
         return (
             q,
@@ -39,7 +39,7 @@ def register_rbln_custom_flash_attention():
             torch.empty(*vcache.shape, device=vcache.device),
         )
 
-    @register_fake("rbln_custom_ops::flash_attn_decode")
+    @register_fake("rbln_custom_ops::flash_masked_attn_decode")
     def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, partition):
         return (
             q,
@@ -48,7 +48,7 @@ def register_rbln_custom_flash_attention():
         )
 
     torch.library.define(
-        "rbln_custom_ops::flash_attn_prefill",
+        "rbln_custom_ops::flash_masked_attn_prefill",
         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
     )
 
@@ -56,6 +56,43 @@ def register_rbln_custom_flash_attention():
     def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale, partition):
         return q, kcache, vcache
 
-    @register_fake("rbln_custom_ops::flash_attn_prefill")
+    @register_fake("rbln_custom_ops::flash_masked_attn_prefill")
     def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, scale, partition):
         return q, kcache, vcache
+
+
+@lru_cache
+def register_rbln_custom_flash_causal_masked_attention():
+    torch.library.define(
+        "rbln_custom_ops::flash_causal_masked_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::flash_causal_masked_attn_decode", "cpu")
+    def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, partition):
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::flash_causal_masked_attn_decode")
+    def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, partition):
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    torch.library.define(
+        "rbln_custom_ops::flash_causal_masked_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::flash_causal_masked_attn_prefill", "cpu")
+    def flash_attn_prefill_cpu(q, k, v, kcache, vcache, batch, seq, scale, partition):
+        return q, kcache, vcache
+
+    @register_fake("rbln_custom_ops::flash_causal_masked_attn_prefill")
+    def flash_attn_prefill_abstract(q, k, v, kcache, vcache, batch, seq, scale, partition):
+        return q, kcache, vcache
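
The flash ops mirror the masked/causal split in attn.py, each carrying one extra trailing int: the KV-cache partition size. A companion sketch for the causal flash decode op, under the same assumptions and shape conventions as the attn.py sketch above (the partition value of 128 is illustrative):

    import torch
    from optimum.rbln.ops import register_rbln_custom_flash_causal_masked_attention

    register_rbln_custom_flash_causal_masked_attention()

    n_heads, head_dim, max_seq_len = 4, 64, 256
    q = torch.randn(1, n_heads, 1, 1, head_dim)          # one query token, n_groups=1
    k = torch.randn(1, n_heads, 1, 1, head_dim)
    v = torch.randn(1, n_heads, 1, 1, head_dim)
    kcache = torch.zeros(1, n_heads, 1, max_seq_len, head_dim)
    vcache = torch.zeros(1, n_heads, 1, max_seq_len, head_dim)

    attn_out, _, _ = torch.ops.rbln_custom_ops.flash_causal_masked_attn_decode(
        q, k, v, kcache, vcache, torch.tensor([0]), torch.tensor(head_dim**-0.5), 128
    )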

src/optimum/rbln/transformers/modeling_generic.py
@@ -73,7 +73,7 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -289,7 +289,7 @@ class RBLNModelForSequenceClassification(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -362,7 +362,7 @@ class RBLNModelForMaskedLM(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:

src/optimum/rbln/transformers/models/bart/bart_architecture.py
@@ -142,7 +142,7 @@ class BartSelfAttention(Seq2SeqSelfAttention):
         self.num_heads = self._original_mod.num_heads
         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
         self.scaling = self.head_dim**-0.5
-        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode
+        self.attn_decode = torch.ops.rbln_custom_ops.masked_attn_decode
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states) * self.scaling

src/optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -58,7 +58,7 @@ class RBLNBartModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:

src/optimum/rbln/transformers/models/bert/modeling_bert.py
@@ -56,7 +56,7 @@ class RBLNBertModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
        if rbln_model_input_names is None:
             for tokenizer in preprocessors:
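
Each of these heads applies the same hf_class -> get_hf_class() substitution ahead of the signature probe. What that probe computes, sketched against a stock transformers class (the candidate list is illustrative, not the library's actual filtering code):

    import inspect
    from transformers import BertModel

    signature_params = inspect.signature(BertModel.forward).parameters.keys()
    candidates = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
    rbln_model_input_names = [name for name in candidates if name in signature_params]
    print(rbln_model_input_names)  # ['input_ids', 'attention_mask', 'token_type_ids']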