optimum-rbln 0.2.1a3.tar.gz → 0.2.1a5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/PKG-INFO +1 -1
  2. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/__version__.py +1 -1
  3. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/modeling_base.py +10 -9
  4. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
  5. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +200 -154
  6. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -7
  7. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +59 -37
  8. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
  9. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  10. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  11. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  12. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/ISSUE_TEMPLATE/model_request.md +0 -0
  13. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/pull_request_template.md +0 -0
  14. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/scripts/auto_code_review.py +0 -0
  15. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/scripts/validate_pr_checklist.py +0 -0
  16. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/auto_code_review.yml +0 -0
  17. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/check_code_quality.yml +0 -0
  18. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/deploy-on-tag.yaml +0 -0
  19. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/deploy.yaml +0 -0
  20. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/pr-title-check.yaml +0 -0
  21. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/pr_checklist_validator.yml +0 -0
  22. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/rbln_dispatch_pytest.yaml +0 -0
  23. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/rbln_optimum_inference_test.yaml +0 -0
  24. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/rbln_optimum_pytest.yaml +0 -0
  25. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.github/workflows/rbln_trigger_on_pr.yaml +0 -0
  26. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/.gitignore +0 -0
  27. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/CODE_OF_CONDUCT.md +0 -0
  28. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/CONTRIBUTING.md +0 -0
  29. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/LICENSE +0 -0
  30. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/README.md +0 -0
  31. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/assets/rbln_logo.png +0 -0
  32. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/audio-classification/run_ast_audio_classification.py +0 -0
  33. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/depth-estimation/run_dpt.py +0 -0
  34. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/image-classification/run_image_classification.py +0 -0
  35. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/image-classification/run_vit_image_classification.py +0 -0
  36. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/image-to-text/run_llava_next_image_to_text.py +0 -0
  37. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/question-answering/run_question_answering.py +0 -0
  38. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/speech-recognition/run_wav2vec2.py +0 -0
  39. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/speech-recognition/run_whisper.py +0 -0
  40. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/stable-diffusion/run_stable_diffusion.py +0 -0
  41. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/stable-diffusion/run_stable_diffusion_controlnet.py +0 -0
  42. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/stable-diffusion/run_stable_diffusion_img2img.py +0 -0
  43. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/stable-diffusion/run_stable_diffusion_img2img_controlnet.py +0 -0
  44. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/stable-diffusion/run_stable_diffusion_inpaint.py +0 -0
  45. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/stable-diffusion/run_stable_diffusion_lora.py +0 -0
  46. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/stable-diffusion/run_stable_diffusion_multicontrolnet.py +0 -0
  47. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text-classification/run_bge_m3_text_classification.py +0 -0
  48. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text-classification/run_bge_reranker_v2_m3_text_classification.py +0 -0
  49. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text-classification/run_secureBERT.py +0 -0
  50. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text-classification/run_t5_classification.py +0 -0
  51. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text-classification/run_twitter_roberta_text_classification.py +0 -0
  52. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text2text-generation/run_bart_text2text_generation.py +0 -0
  53. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text2text-generation/run_llama_peft.py +0 -0
  54. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/examples/text2text-generation/run_llama_text2text_generation.py +0 -0
  55. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/pyproject.toml +0 -0
  56. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/scripts/uv-lock.sh +0 -0
  57. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/scripts/uv-sync.sh +0 -0
  58. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/__init__.py +0 -0
  59. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/__init__.py +0 -0
  60. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/modeling_diffusers.py +0 -0
  61. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/__init__.py +0 -0
  62. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/autoencoders/__init__.py +0 -0
  63. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +0 -0
  64. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/autoencoders/vae.py +0 -0
  65. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/controlnet.py +0 -0
  66. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/transformers/__init__.py +0 -0
  67. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py +0 -0
  68. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/unets/__init__.py +0 -0
  69. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py +0 -0
  70. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/__init__.py +0 -0
  71. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/controlnet/__init__.py +0 -0
  72. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +0 -0
  73. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +0 -0
  74. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +0 -0
  75. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +0 -0
  76. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +0 -0
  77. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -0
  78. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +0 -0
  79. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -0
  80. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +0 -0
  81. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +0 -0
  82. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +0 -0
  83. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +0 -0
  84. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +0 -0
  85. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +0 -0
  86. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -0
  87. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -0
  88. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -0
  89. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/modeling.py +0 -0
  90. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/modeling_config.py +0 -0
  91. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/ops/__init__.py +0 -0
  92. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/ops/attn.py +0 -0
  93. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/ops/flash_attn.py +0 -0
  94. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/ops/kv_cache_update.py +0 -0
  95. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/__init__.py +0 -0
  96. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/modeling_alias.py +0 -0
  97. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/modeling_generic.py +0 -0
  98. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/modeling_rope_utils.py +0 -0
  99. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/__init__.py +0 -0
  100. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/auto/__init__.py +0 -0
  101. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/auto/auto_factory.py +0 -0
  102. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/auto/modeling_auto.py +0 -0
  103. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/bart/__init__.py +0 -0
  104. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/bart/bart_architecture.py +0 -0
  105. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/bart/modeling_bart.py +0 -0
  106. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/bert/__init__.py +0 -0
  107. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/bert/modeling_bert.py +0 -0
  108. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/clip/__init__.py +0 -0
  109. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/clip/modeling_clip.py +0 -0
  110. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/decoderonly/__init__.py +0 -0
  111. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/dpt/__init__.py +0 -0
  112. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -0
  113. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/exaone/__init__.py +0 -0
  114. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -0
  115. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/exaone/modeling_exaone.py +0 -0
  116. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/gemma/__init__.py +0 -0
  117. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/gemma/gemma_architecture.py +0 -0
  118. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/gemma/modeling_gemma.py +0 -0
  119. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/gpt2/__init__.py +0 -0
  120. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +0 -0
  121. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +0 -0
  122. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/llama/__init__.py +0 -0
  123. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/llama/llama_architecture.py +0 -0
  124. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/llama/modeling_llama.py +0 -0
  125. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/llava_next/__init__.py +0 -0
  126. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/midm/__init__.py +0 -0
  127. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/midm/midm_architecture.py +0 -0
  128. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/midm/modeling_midm.py +0 -0
  129. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/mistral/__init__.py +0 -0
  130. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -0
  131. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -0
  132. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/phi/__init__.py +0 -0
  133. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/phi/modeling_phi.py +0 -0
  134. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/phi/phi_architecture.py +0 -0
  135. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/qwen2/__init__.py +0 -0
  136. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -0
  137. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -0
  138. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/seq2seq/__init__.py +0 -0
  139. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/t5/__init__.py +0 -0
  140. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/t5/modeling_t5.py +0 -0
  141. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/t5/t5_architecture.py +0 -0
  142. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/wav2vec2/__init__.py +0 -0
  143. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -0
  144. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/whisper/__init__.py +0 -0
  145. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/whisper/generation_whisper.py +0 -0
  146. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -0
  147. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -0
  148. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/xlm_roberta/__init__.py +0 -0
  149. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +0 -0
  150. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/utils/__init__.py +0 -0
  151. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/utils/rbln_quantization.py +0 -0
  152. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/__init__.py +0 -0
  153. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/decorator_utils.py +0 -0
  154. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/hub.py +0 -0
  155. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/import_utils.py +0 -0
  156. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/logging.py +0 -0
  157. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/model_utils.py +0 -0
  158. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/runtime_utils.py +0 -0
  159. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/save_utils.py +0 -0
  160. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/utils/submodule.py +0 -0
  161. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/__init__.py +0 -0
  162. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/psnr.py +0 -0
  163. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/requirements_sdxl.txt +0 -0
  164. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/run_stable_diffusion_xl_base.py +0 -0
  165. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/test_base.py +0 -0
  166. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/test_diffusers.py +0 -0
  167. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/test_llm.py +0 -0
  168. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/tests/test_transformers.py +0 -0
  169. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/uv.lock +0 -0
{optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.2.1a3
+Version: 0.2.1a5
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
{optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/__version__.py
@@ -12,5 +12,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.1a3'
+__version__ = version = '0.2.1a5'
 __version_tuple__ = version_tuple = (0, 2, 1)
{optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/modeling_base.py
@@ -442,8 +442,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
-        real_save_dir = self.model_save_dir / self.subfolder
-        save_directory_path = Path(save_directory)
+        # Normalize paths to handle relative paths and symlinks
+        real_save_dir = Path(self.model_save_dir).resolve() / self.subfolder
+        save_directory_path = Path(save_directory).resolve()
 
         if not os.path.exists(real_save_dir) or not os.path.isdir(real_save_dir):
             raise FileNotFoundError(
@@ -452,13 +453,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 f"Please ensure the model directory exists and you have the necessary permissions to access it."
             )
 
-        if save_directory_path.absolute() == real_save_dir.absolute():
+        if save_directory_path == real_save_dir:
             raise FileExistsError(
                 f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
            )
 
-        # Create a temporary directory next to the target directory
-        tmp_dir = save_directory + ".tmp"
+        # Create a temporary directory with normalized path
+        tmp_dir = str(save_directory_path) + ".tmp"
        try:
            # Remove temporary directory if it exists from a previous failed attempt
            if os.path.exists(tmp_dir):
@@ -473,9 +474,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             self.generation_config.save_pretrained(tmp_dir)
 
             # If everything succeeded, atomically replace the target directory
-            if os.path.exists(save_directory):
-                shutil.rmtree(save_directory)
-            os.rename(tmp_dir, save_directory)
+            if os.path.exists(save_directory_path):
+                shutil.rmtree(save_directory_path)
+            os.rename(tmp_dir, save_directory_path)
 
         except Exception as e:
             # Clean up the temporary directory if anything fails
@@ -484,7 +485,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             raise e  # Re-raise the exception after cleanup
 
         if push_to_hub:
-            return super().push_to_hub(save_directory, **kwargs)
+            return super().push_to_hub(str(save_directory_path), **kwargs)
 
     @staticmethod
     def _raise_missing_compiled_file_error(missing_files: List[str]):
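For context on the `.resolve()` change above: `Path.absolute()` only prepends the working directory and neither collapses `..` components nor follows symlinks, so two spellings of the same directory can compare unequal, while `Path.resolve()` normalizes both sides before the equality check. A minimal illustrative sketch (the paths are hypothetical, not taken from this package):

    from pathlib import Path

    # .absolute() keeps ".." components, so equivalent spellings compare unequal.
    a = Path("models/../models/llama").absolute()
    b = Path("models/llama").absolute()
    print(a == b)  # False

    # .resolve() normalizes both (and follows symlinks), so they compare equal.
    print(Path("models/../models/llama").resolve() == Path("models/llama").resolve())  # True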
{optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -427,12 +427,14 @@ class DecoderOnlyModel(nn.Module):
             cos, sin = None, None
 
         # (batch, seq_len) -> (batch,)
-        seq_positions = cache_position[:, 0]
         if self.attn_impl == "flash_attn":
+            seq_positions = cache_position[:, 0]
             max_seq_len = past_key_values[0][0].shape[-2]
             seq_positions = self.convert_sequence_positions_for_flash_attn(
                 seq_positions=seq_positions, max_seq_len=max_seq_len
             )
+        else:
+            seq_positions = cache_position[:, :1]
 
         present_key_values = past_key_values
         for layer in self.layers:
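For reference, the two indexing forms above produce different shapes: `cache_position[:, 0]` drops the sequence dimension, while the new `cache_position[:, :1]` keeps it. A tiny illustrative check (tensor values are arbitrary):

    import torch

    cache_position = torch.zeros(4, 1, dtype=torch.int32)  # (batch, seq_len)
    print(cache_position[:, 0].shape)   # torch.Size([4])    -> 1-D positions, flash_attn path
    print(cache_position[:, :1].shape)  # torch.Size([4, 1]) -> keeps the trailing dim, default path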
{optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -38,34 +38,188 @@ from .decoderonly_architecture import (
 logger = get_logger()
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name", "embed_tokens"]
 
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        phase: str,
+        batch_size: int,
+        dec_attn_mask: torch.Tensor,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.phase = phase
+        self.batch_size = batch_size
+
+        # shared tensor between prefill and decode phase
+        self.dec_attn_mask = dec_attn_mask
+
+        if self.phase == "prefill":
+            vocab_size = kwargs.pop("vocab_size")
+            self.max_seq_len = kwargs.pop("max_seq_len")
+            self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
+            self.output_size = [1, 1, vocab_size]
+            self.causal_mask = 1 - torch.triu(
+                torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+            )
+
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        inputs_embeds: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        **kwargs,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: Optional[int] = None,
     ):
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
+
         if inputs_embeds is None:
-            inp = input_ids
+            inputs = input_ids
             if self.embed_tokens is not None:
-                inp = self.embed_tokens(inp)
+                inputs = self.embed_tokens(inputs)
         else:
-            inp = inputs_embeds
+            inputs = inputs_embeds
 
-        return super().forward(
-            inp,
-            attention_mask,
+        if self.phase == "decode":
+            return self.decode_forward(
+                inputs,
+                cache_position,
+                attention_mask=attention_mask,
+            )
+        else:
+            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+
+    def decode_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        if attention_mask is None:
+            for b_idx in range(batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+        logits = super().forward(
+            inputs,
+            self.dec_attn_mask if attention_mask is None else attention_mask,
             cache_position,
-            **kwargs,
         )
 
+        return logits
+
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: int = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+
+        if batch_idx is None or batch_idx >= self.batch_size:
+            raise RuntimeError(
+                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
+            )
+
+        # Handle continuous batching in a compiled graph by extracting valid inputs
+        # If an attention mask is provided, select only the valid (non-masked) inputs
+        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+        query_length = inputs.shape[1]
+        if query_length > self.max_seq_len:
+            raise ValueError(
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+            )
+
+        # Initialize attention mask for chunked processing
+        chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+
+        # Buffer for storing output logits
+        out_buffers = [
+            torch.empty(
+                size=self.output_size,
+                dtype=torch.float32,
+                device="cpu",
+            )
+        ]
+
+        # Process input in chunks of size `prefill_chunk_size`
+        for step in range(0, query_length, self.prefill_chunk_size):
+            # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+            if (step + self.prefill_chunk_size) > query_length:
+                padding_size = step + self.prefill_chunk_size - query_length
+                # inputs_embeds
+                if inputs.dim() == 3:
+                    inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+                # inputs_ids
+                else:
+                    inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+
+                cache_position = torch.cat(
+                    [
+                        cache_position,
+                        torch.arange(
+                            query_length,
+                            step + self.prefill_chunk_size,
+                            dtype=torch.int32,
+                        ).unsqueeze(0),
+                    ],
+                    dim=-1,
+                )
+
+            # Extract the current chunk of inputs and cache positions
+            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+
+            # Update attention mask to ensure proper causal behavior
+            if step >= self.prefill_chunk_size:
+                chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+            chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+            # Define batch position and query position
+            batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+            query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+
+            # Forward pass for the current chunk
+            logits = super().forward(
+                input_chunk,
+                chunked_attention_mask,
+                cache_pos_chunk,
+                batch_position,
+                query_position,
+                out=out_buffers,
+            )
+
+        # Update decoder attention mask with processed KV-cache length from prefill phase
+        self.dec_attn_mask[batch_idx].fill_(0)
+        self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+
+        return logits
+
 
 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
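The chunked prefill loop in `prefill_forward` above walks the prompt in `prefill_chunk_size` steps and zero-pads the final chunk so every call into the compiled graph sees a fixed shape. A standalone sketch of just that chunking arithmetic, for illustration only (the helper name and sizes are not part of the package):

    import torch

    def chunk_and_pad(inputs: torch.Tensor, chunk_size: int) -> list:
        """Split a (1, seq_len) tensor into fixed-size chunks, zero-padding the last one."""
        query_length = inputs.shape[1]
        chunks = []
        for step in range(0, query_length, chunk_size):
            chunk = inputs[:, step : step + chunk_size]
            if chunk.shape[1] < chunk_size:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_size - chunk.shape[1]))
            chunks.append(chunk)
        return chunks

    # A 300-token prompt with a 128-token prefill chunk yields chunks of 128, 128, and 44 padded to 128.
    chunks = chunk_and_pad(torch.arange(300).unsqueeze(0), 128)
    assert [c.shape[1] for c in chunks] == [128, 128, 128]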
@@ -103,13 +257,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
 
-        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
-        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
-
         main_input_name = self.main_input_name
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
             main_input_name = "inputs_embeds"
@@ -124,11 +271,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             self.embed_tokens = None
 
+        dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
         self.prefill_decoder = RBLNRuntimeModel(
-            runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+            runtime=self.model[0],
+            main_input_name=main_input_name,
+            embed_tokens=self.embed_tokens,
+            phase="prefill",
+            batch_size=self.batch_size,
+            dec_attn_mask=dec_attn_mask,
+            vocab_size=self.config.vocab_size,
+            max_seq_len=self.max_seq_len,
+            prefill_chunk_size=self.prefill_chunk_size,
         )
         self.decoder = RBLNRuntimeModel(
-            runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+            runtime=self.model[1],
+            main_input_name=main_input_name,
+            embed_tokens=self.embed_tokens,
+            phase="decode",
+            batch_size=self.batch_size,
+            dec_attn_mask=dec_attn_mask,
         )
 
     @classmethod
@@ -155,7 +316,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
-        config: Optional[PretrainedConfig] = None,
+        config: Optional["PretrainedConfig"] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -496,32 +657,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        # prefll
+        """
+        Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+        For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+        A for-loop ensures synchronization with the HuggingFace generate API.
+        The decoder stage operates as usual, processing inputs in batch mode.
+        """
+        # Prefll
         if cache_position is None:
             logits = []
-            input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
-            batch_size = input_tensors.shape[0]
+            inputs = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = inputs.shape[0]
 
             for b_idx in range(batch_size):
-                # Transform inputs as vllm format
-                if attention_mask is not None:
-                    input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
-                else:
-                    input_tensor = input_tensors[b_idx : b_idx + 1]
-
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
-                logit = self._forward_prefill(
-                    input_ids=input_tensor if inputs_embeds is None else None,
-                    inputs_embeds=input_tensor if inputs_embeds is not None else None,
+                logit = self.prefill_decoder(
+                    input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                    inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+                    attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                 )
                 logits.append(logit)
+
             logits = torch.cat(logits, dim=0)
-        # decoder
+        # Decoder
         else:
-            logits = self._forward_decoder(
+            logits = self.decoder(
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
@@ -531,119 +693,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             logits=logits,
             generate_idx=generate_idx,
         )
-
-    def _forward_prefill(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        batch_idx: int = None,
-    ) -> torch.FloatTensor:
-        if batch_idx is None or batch_idx >= self.batch_size:
-            raise RuntimeError(
-                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
-            )
-
-        out_buffers = [
-            torch.empty(
-                size=[
-                    1,
-                    1,
-                    self.config.vocab_size,
-                ],
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
-        query_length = input_tensors.shape[1]
-        if query_length > self.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
-            )
-
-        _attention_mask = self.prefill_attention_mask.clone()
-
-        for step in range(0, query_length, self.prefill_chunk_size):
-            # pad input_tensors & cache_position for prefill_chunk
-            if (step + self.prefill_chunk_size) > query_length:
-                pad_to_chunk = step + self.prefill_chunk_size - query_length
-                if inputs_embeds is not None:
-                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
-                else:
-                    input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
-
-                cache_position = torch.cat(
-                    [
-                        cache_position,
-                        torch.arange(
-                            query_length,
-                            step + self.prefill_chunk_size,
-                            dtype=torch.int32,
-                        ).unsqueeze(0),
-                    ],
-                    dim=-1,
-                )
-
-            # slice input_tensor & cache_position with prefill_chunk_size
-            _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
-            _cache_position = cache_position[:, step : step + self.prefill_chunk_size]
-
-            # update attention_mask
-            if step >= self.prefill_chunk_size:
-                _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-            _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-            query_position = (query_length - 1) % self.prefill_chunk_size
-
-            logits = self.prefill_decoder(
-                input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
-                inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
-                attention_mask=_attention_mask.contiguous(),
-                cache_position=_cache_position.contiguous(),
-                batch_position=torch.tensor(batch_idx, dtype=torch.int16),
-                query_position=torch.tensor(query_position, dtype=torch.int16),
-                out=out_buffers,
-            )
-
-        # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
-        self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
-        self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
-        return logits
-
-    def _forward_decoder(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-    ) -> torch.FloatTensor:
-        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
-        if input_tensors is None:
-            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
-        batch_size = input_tensors.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        for b_idx in range(batch_size):
-            decoding_step = cache_position[b_idx].item()
-            if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-        logits = self.decoder(
-            input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
-            inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
-            attention_mask=self.dec_attn_mask.contiguous(),
-            cache_position=cache_position.contiguous(),
-        )
-
-        return logits
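Taken together, the prefill/decode split above is driven through the ordinary HuggingFace `generate()` loop: `forward()` is first called with `cache_position=None` (per-sequence prefill via `batch_idx`), then repeatedly with `cache_position` set (batched decode). A minimal usage sketch under that assumption; the checkpoint id and the `rbln_*`/`export` keyword names are illustrative and may differ from the released API:

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNLlamaForCausalLM

    model_id = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = RBLNLlamaForCausalLM.from_pretrained(
        model_id,
        export=True,            # compile for the RBLN NPU on first load (assumed flag)
        rbln_batch_size=1,      # compiled batch size checked in decode_forward
        rbln_max_seq_len=4096,  # bound enforced in prefill_forward
    )

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))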
{optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a5}/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -25,7 +25,6 @@ from transformers import (
     PreTrainedModel,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPooling
-from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast
 
 from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -337,7 +336,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         **kwargs,
-    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+    ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         vision_feature_layer = (
             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
         )
@@ -378,7 +377,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
             inputs_embeds = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(batch_size)]
             for batch_idx in range(batch_size):
                 generate_idx[batch_idx] = inputs_embeds[batch_idx].shape[-2]
-                logit = self.language_model._forward_prefill(
+                logit = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[batch_idx],
                     batch_idx=batch_idx,
                     cache_position=torch.arange(
@@ -390,15 +389,13 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
-            outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
         else:
-            outputs: RBLNDecoderOnlyOutput = self.language_model(
+            logits = self.language_model.decoder(
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
-                generate_idx=generate_idx,
             )
 
-        return outputs
+        return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
 
     # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):