optimum-rbln 0.2.1a3.tar.gz → 0.2.1a4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/PKG-INFO +1 -1
  2. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/__version__.py +1 -1
  3. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +200 -154
  4. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -7
  5. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +59 -37
  6. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  7. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  8. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  9. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/ISSUE_TEMPLATE/model_request.md +0 -0
  10. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/pull_request_template.md +0 -0
  11. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/scripts/auto_code_review.py +0 -0
  12. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/scripts/validate_pr_checklist.py +0 -0
  13. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/auto_code_review.yml +0 -0
  14. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/check_code_quality.yml +0 -0
  15. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/deploy-on-tag.yaml +0 -0
  16. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/deploy.yaml +0 -0
  17. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/pr-title-check.yaml +0 -0
  18. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/pr_checklist_validator.yml +0 -0
  19. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/rbln_dispatch_pytest.yaml +0 -0
  20. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/rbln_optimum_inference_test.yaml +0 -0
  21. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/rbln_optimum_pytest.yaml +0 -0
  22. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.github/workflows/rbln_trigger_on_pr.yaml +0 -0
  23. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/.gitignore +0 -0
  24. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/CODE_OF_CONDUCT.md +0 -0
  25. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/CONTRIBUTING.md +0 -0
  26. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/LICENSE +0 -0
  27. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/README.md +0 -0
  28. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/assets/rbln_logo.png +0 -0
  29. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/audio-classification/run_ast_audio_classification.py +0 -0
  30. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/depth-estimation/run_dpt.py +0 -0
  31. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/image-classification/run_image_classification.py +0 -0
  32. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/image-classification/run_vit_image_classification.py +0 -0
  33. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/image-to-text/run_llava_next_image_to_text.py +0 -0
  34. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/question-answering/run_question_answering.py +0 -0
  35. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/speech-recognition/run_wav2vec2.py +0 -0
  36. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/speech-recognition/run_whisper.py +0 -0
  37. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/stable-diffusion/run_stable_diffusion.py +0 -0
  38. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/stable-diffusion/run_stable_diffusion_controlnet.py +0 -0
  39. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/stable-diffusion/run_stable_diffusion_img2img.py +0 -0
  40. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/stable-diffusion/run_stable_diffusion_img2img_controlnet.py +0 -0
  41. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/stable-diffusion/run_stable_diffusion_inpaint.py +0 -0
  42. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/stable-diffusion/run_stable_diffusion_lora.py +0 -0
  43. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/stable-diffusion/run_stable_diffusion_multicontrolnet.py +0 -0
  44. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text-classification/run_bge_m3_text_classification.py +0 -0
  45. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text-classification/run_bge_reranker_v2_m3_text_classification.py +0 -0
  46. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text-classification/run_secureBERT.py +0 -0
  47. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text-classification/run_t5_classification.py +0 -0
  48. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text-classification/run_twitter_roberta_text_classification.py +0 -0
  49. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text2text-generation/run_bart_text2text_generation.py +0 -0
  50. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text2text-generation/run_llama_peft.py +0 -0
  51. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/examples/text2text-generation/run_llama_text2text_generation.py +0 -0
  52. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/pyproject.toml +0 -0
  53. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/scripts/uv-lock.sh +0 -0
  54. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/scripts/uv-sync.sh +0 -0
  55. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/__init__.py +0 -0
  56. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/__init__.py +0 -0
  57. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/modeling_diffusers.py +0 -0
  58. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/__init__.py +0 -0
  59. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/autoencoders/__init__.py +0 -0
  60. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +0 -0
  61. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/autoencoders/vae.py +0 -0
  62. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/controlnet.py +0 -0
  63. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/transformers/__init__.py +0 -0
  64. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py +0 -0
  65. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/unets/__init__.py +0 -0
  66. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py +0 -0
  67. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/__init__.py +0 -0
  68. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/controlnet/__init__.py +0 -0
  69. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +0 -0
  70. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +0 -0
  71. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +0 -0
  72. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +0 -0
  73. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +0 -0
  74. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -0
  75. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +0 -0
  76. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -0
  77. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +0 -0
  78. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +0 -0
  79. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +0 -0
  80. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +0 -0
  81. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +0 -0
  82. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +0 -0
  83. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -0
  84. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -0
  85. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -0
  86. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/modeling.py +0 -0
  87. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/modeling_base.py +0 -0
  88. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/modeling_config.py +0 -0
  89. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/ops/__init__.py +0 -0
  90. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/ops/attn.py +0 -0
  91. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/ops/flash_attn.py +0 -0
  92. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/ops/kv_cache_update.py +0 -0
  93. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/__init__.py +0 -0
  94. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/modeling_alias.py +0 -0
  95. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/modeling_generic.py +0 -0
  96. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/modeling_rope_utils.py +0 -0
  97. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/__init__.py +0 -0
  98. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/auto/__init__.py +0 -0
  99. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/auto/auto_factory.py +0 -0
  100. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/auto/modeling_auto.py +0 -0
  101. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/bart/__init__.py +0 -0
  102. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/bart/bart_architecture.py +0 -0
  103. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/bart/modeling_bart.py +0 -0
  104. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/bert/__init__.py +0 -0
  105. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/bert/modeling_bert.py +0 -0
  106. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/clip/__init__.py +0 -0
  107. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/clip/modeling_clip.py +0 -0
  108. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/decoderonly/__init__.py +0 -0
  109. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +0 -0
  110. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/dpt/__init__.py +0 -0
  111. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -0
  112. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/exaone/__init__.py +0 -0
  113. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -0
  114. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/exaone/modeling_exaone.py +0 -0
  115. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/gemma/__init__.py +0 -0
  116. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/gemma/gemma_architecture.py +0 -0
  117. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/gemma/modeling_gemma.py +0 -0
  118. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/gpt2/__init__.py +0 -0
  119. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +0 -0
  120. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +0 -0
  121. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/llama/__init__.py +0 -0
  122. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/llama/llama_architecture.py +0 -0
  123. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/llama/modeling_llama.py +0 -0
  124. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/llava_next/__init__.py +0 -0
  125. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/midm/__init__.py +0 -0
  126. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/midm/midm_architecture.py +0 -0
  127. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/midm/modeling_midm.py +0 -0
  128. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/mistral/__init__.py +0 -0
  129. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -0
  130. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -0
  131. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/phi/__init__.py +0 -0
  132. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/phi/modeling_phi.py +0 -0
  133. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/phi/phi_architecture.py +0 -0
  134. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/qwen2/__init__.py +0 -0
  135. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -0
  136. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -0
  137. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/seq2seq/__init__.py +0 -0
  138. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +0 -0
  139. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/t5/__init__.py +0 -0
  140. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/t5/modeling_t5.py +0 -0
  141. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/t5/t5_architecture.py +0 -0
  142. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/wav2vec2/__init__.py +0 -0
  143. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -0
  144. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/whisper/__init__.py +0 -0
  145. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/whisper/generation_whisper.py +0 -0
  146. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -0
  147. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -0
  148. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/xlm_roberta/__init__.py +0 -0
  149. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +0 -0
  150. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/utils/__init__.py +0 -0
  151. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/transformers/utils/rbln_quantization.py +0 -0
  152. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/__init__.py +0 -0
  153. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/decorator_utils.py +0 -0
  154. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/hub.py +0 -0
  155. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/import_utils.py +0 -0
  156. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/logging.py +0 -0
  157. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/model_utils.py +0 -0
  158. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/runtime_utils.py +0 -0
  159. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/save_utils.py +0 -0
  160. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/src/optimum/rbln/utils/submodule.py +0 -0
  161. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/__init__.py +0 -0
  162. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/psnr.py +0 -0
  163. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/requirements_sdxl.txt +0 -0
  164. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/run_stable_diffusion_xl_base.py +0 -0
  165. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/test_base.py +0 -0
  166. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/test_diffusers.py +0 -0
  167. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/test_llm.py +0 -0
  168. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/tests/test_transformers.py +0 -0
  169. {optimum_rbln-0.2.1a3 → optimum_rbln-0.2.1a4}/uv.lock +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.2.1a3
+ Version: 0.2.1a4
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai
src/optimum/rbln/__version__.py
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.2.1a3'
+ __version__ = version = '0.2.1a4'
  __version_tuple__ = version_tuple = (0, 2, 1)
src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -38,34 +38,188 @@ from .decoderonly_architecture import (
  logger = get_logger()
 
  if TYPE_CHECKING:
- from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
  class RBLNRuntimeModel(RBLNPytorchRuntime):
  mandatory_members = ["main_input_name", "embed_tokens"]
 
+ def __init__(
+ self,
+ runtime: rebel.Runtime,
+ phase: str,
+ batch_size: int,
+ dec_attn_mask: torch.Tensor,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(runtime, **kwargs)
+ self.phase = phase
+ self.batch_size = batch_size
+
+ # shared tensor between prefill and decode phase
+ self.dec_attn_mask = dec_attn_mask
+
+ if self.phase == "prefill":
+ vocab_size = kwargs.pop("vocab_size")
+ self.max_seq_len = kwargs.pop("max_seq_len")
+ self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
+ self.output_size = [1, 1, vocab_size]
+ self.causal_mask = 1 - torch.triu(
+ torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+ )
+
  def forward(
  self,
- input_ids: torch.LongTensor,
- inputs_embeds: torch.Tensor,
- attention_mask: torch.Tensor,
- cache_position: torch.Tensor,
- **kwargs,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ cache_position: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ batch_idx: Optional[int] = None,
  ):
+ if input_ids is None and inputs_embeds is None:
+ raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
+
  if inputs_embeds is None:
- inp = input_ids
+ inputs = input_ids
  if self.embed_tokens is not None:
- inp = self.embed_tokens(inp)
+ inputs = self.embed_tokens(inputs)
  else:
- inp = inputs_embeds
+ inputs = inputs_embeds
 
- return super().forward(
- inp,
- attention_mask,
+ if self.phase == "decode":
+ return self.decode_forward(
+ inputs,
+ cache_position,
+ attention_mask=attention_mask,
+ )
+ else:
+ return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+
+ def decode_forward(
+ self,
+ inputs: torch.Tensor,
+ cache_position: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size = inputs.shape[0]
+ if batch_size != self.batch_size:
+ raise RuntimeError(
+ f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+ )
+
+ if batch_size != cache_position.shape[0]:
+ raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+ if attention_mask is None:
+ for b_idx in range(batch_size):
+ decoding_step = cache_position[b_idx].item()
+ if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+ raise ValueError(
+ f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+ )
+ self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+ logits = super().forward(
+ inputs,
+ self.dec_attn_mask if attention_mask is None else attention_mask,
  cache_position,
- **kwargs,
  )
 
+ return logits
+
+ def prefill_forward(
+ self,
+ inputs: torch.Tensor,
+ cache_position: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ batch_idx: int = None,
+ ) -> torch.FloatTensor:
+ """
+ Performs chunked prefill for efficient KV-cache updates and memory optimization.
+ Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+ and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+ """
+
+ if batch_idx is None or batch_idx >= self.batch_size:
+ raise RuntimeError(
+ f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
+ )
+
+ # Handle continuous batching in a compiled graph by extracting valid inputs
+ # If an attention mask is provided, select only the valid (non-masked) inputs
+ inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+ query_length = inputs.shape[1]
+ if query_length > self.max_seq_len:
+ raise ValueError(
+ f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+ )
+
+ # Initialize attention mask for chunked processing
+ chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+
+ # Buffer for storing output logits
+ out_buffers = [
+ torch.empty(
+ size=self.output_size,
+ dtype=torch.float32,
+ device="cpu",
+ )
+ ]
+
+ # Process input in chunks of size `prefill_chunk_size`
+ for step in range(0, query_length, self.prefill_chunk_size):
+ # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+ if (step + self.prefill_chunk_size) > query_length:
+ padding_size = step + self.prefill_chunk_size - query_length
+ # inputs_embeds
+ if inputs.dim() == 3:
+ inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+ # inputs_ids
+ else:
+ inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+
+ cache_position = torch.cat(
+ [
+ cache_position,
+ torch.arange(
+ query_length,
+ step + self.prefill_chunk_size,
+ dtype=torch.int32,
+ ).unsqueeze(0),
+ ],
+ dim=-1,
+ )
+
+ # Extract the current chunk of inputs and cache positions
+ input_chunk = inputs[:, step : step + self.prefill_chunk_size]
+ cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+
+ # Update attention mask to ensure proper causal behavior
+ if step >= self.prefill_chunk_size:
+ chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+ chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+ # Define batch position and query position
+ batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+ query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+
+ # Forward pass for the current chunk
+ logits = super().forward(
+ input_chunk,
+ chunked_attention_mask,
+ cache_pos_chunk,
+ batch_position,
+ query_position,
+ out=out_buffers,
+ )
+
+ # Update decoder attention mask with processed KV-cache length from prefill phase
+ self.dec_attn_mask[batch_idx].fill_(0)
+ self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+
+ return logits
+
 
  @dataclass
  class RBLNDecoderOnlyOutput(ModelOutput):
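
To make the chunked-prefill flow introduced in this hunk easier to follow, here is a minimal, self-contained sketch of the same padding, chunk-slicing, and mask bookkeeping on plain tensors. The concrete sizes are arbitrary and the RBLN runtime call is omitted, so this is an illustration of the loop structure, not the actual compiled path.

# Sketch of the chunked-prefill loop, assuming prefill_chunk_size=4,
# max_seq_len=16, and a 10-token prompt (all values illustrative).
import torch

prefill_chunk_size = 4
max_seq_len = 16
query_length = 10

inputs = torch.arange(query_length).unsqueeze(0)                         # [1, query_length]
cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)

causal_mask = 1 - torch.triu(
    torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
)
chunked_attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len)

for step in range(0, query_length, prefill_chunk_size):
    # Pad the last chunk up to the chunk size, as the hunk above does.
    if step + prefill_chunk_size > query_length:
        padding_size = step + prefill_chunk_size - query_length
        inputs = torch.nn.functional.pad(inputs, (0, padding_size))
        cache_position = torch.cat(
            [cache_position,
             torch.arange(query_length, step + prefill_chunk_size, dtype=torch.int32).unsqueeze(0)],
            dim=-1,
        )

    input_chunk = inputs[:, step : step + prefill_chunk_size]
    cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size]

    # Earlier chunks become fully visible; the current chunk stays causal.
    if step >= prefill_chunk_size:
        chunked_attention_mask[:, :, :, step - prefill_chunk_size : step] = 1
    chunked_attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask

    print(step, input_chunk.shape, cache_pos_chunk.tolist())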
@@ -103,13 +257,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
  self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
 
- self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
- self.causal_mask = 1 - torch.triu(
- torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
- )
- self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
- self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
-
  main_input_name = self.main_input_name
  if self.rbln_config.model_cfg["use_inputs_embeds"]:
  main_input_name = "inputs_embeds"
@@ -124,11 +271,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  else:
  self.embed_tokens = None
 
+ dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
  self.prefill_decoder = RBLNRuntimeModel(
- runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+ runtime=self.model[0],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="prefill",
+ batch_size=self.batch_size,
+ dec_attn_mask=dec_attn_mask,
+ vocab_size=self.config.vocab_size,
+ max_seq_len=self.max_seq_len,
+ prefill_chunk_size=self.prefill_chunk_size,
  )
  self.decoder = RBLNRuntimeModel(
- runtime=self.model[1],
+ runtime=self.model[1],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="decode",
+ batch_size=self.batch_size,
+ dec_attn_mask=dec_attn_mask,
  )
 
  @classmethod
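
A note on the `dec_attn_mask` wiring above: both runtime wrappers receive the same tensor, so mask rows written during the prefill phase are visible to the decode phase without any copying. The tiny sketch below illustrates that sharing with a stand-in class (not the real runtime wrapper) and an arbitrary 2x8 mask.

# Two phases hold references to the *same* mask tensor.
import torch

class Phase:
    def __init__(self, dec_attn_mask):
        self.dec_attn_mask = dec_attn_mask  # shared reference, not a copy

dec_attn_mask = torch.zeros(2, 1, 1, 8)
prefill, decode = Phase(dec_attn_mask), Phase(dec_attn_mask)

prefill.dec_attn_mask[0, :, :, :5] = 1   # prefill marks 5 cached positions for batch row 0
print(decode.dec_attn_mask[0, 0, 0])     # decode sees the same row: tensor([1., 1., 1., 1., 1., 0., 0., 0.])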
@@ -155,7 +316,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  def get_quantized_model(
  cls,
  model_id: str,
- config: Optional[PretrainedConfig] = None,
+ config: Optional["PretrainedConfig"] = None,
  use_auth_token: Optional[Union[bool, str]] = None,
  revision: Optional[str] = None,
  force_download: bool = False,
@@ -496,32 +657,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  generate_idx: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
- # prefll
+ """
+ Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+ For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+ A for-loop ensures synchronization with the HuggingFace generate API.
+ The decoder stage operates as usual, processing inputs in batch mode.
+ """
+ # Prefll
  if cache_position is None:
  logits = []
- input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
- batch_size = input_tensors.shape[0]
+ inputs = inputs_embeds if inputs_embeds is not None else input_ids
+ batch_size = inputs.shape[0]
 
  for b_idx in range(batch_size):
- # Transform inputs as vllm format
- if attention_mask is not None:
- input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
- else:
- input_tensor = input_tensors[b_idx : b_idx + 1]
-
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
- logit = self._forward_prefill(
- input_ids=input_tensor if inputs_embeds is None else None,
- inputs_embeds=input_tensor if inputs_embeds is not None else None,
+ logit = self.prefill_decoder(
+ input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+ inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+ attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
  cache_position=cache_position,
  batch_idx=b_idx,
  )
  logits.append(logit)
+
  logits = torch.cat(logits, dim=0)
- # decoder
+ # Decoder
  else:
- logits = self._forward_decoder(
+ logits = self.decoder(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
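
The docstring above describes how this forward method plugs into the Hugging Face generate API: generate() first hits the prefill branch (cache_position is None, one sequence at a time), then the decode branch once per new token. For context, the following is a hypothetical usage sketch; the model class is part of optimum-rbln, but the model id and the rbln_* export kwargs are assumptions for illustration, not taken from this diff.

# Hypothetical usage sketch exercising the prefill/decode split via generate().
from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"   # assumed example model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,              # compile for the RBLN NPU (assumed kwargs below)
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# generate() triggers the prefill path once, then the decode path per token.
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))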
@@ -531,119 +693,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  logits=logits,
  generate_idx=generate_idx,
  )
-
- def _forward_prefill(
- self,
- input_ids: torch.LongTensor = None,
- inputs_embeds: torch.Tensor = None,
- cache_position: torch.Tensor = None,
- batch_idx: int = None,
- ) -> torch.FloatTensor:
- if batch_idx is None or batch_idx >= self.batch_size:
- raise RuntimeError(
- f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
- )
-
- out_buffers = [
- torch.empty(
- size=[
- 1,
- 1,
- self.config.vocab_size,
- ],
- dtype=torch.float32,
- device="cpu",
- )
- ]
-
- input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
- query_length = input_tensors.shape[1]
- if query_length > self.max_seq_len:
- raise ValueError(
- f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
- )
-
- _attention_mask = self.prefill_attention_mask.clone()
-
- for step in range(0, query_length, self.prefill_chunk_size):
- # pad input_tensors & cache_position for prefill_chunk
- if (step + self.prefill_chunk_size) > query_length:
- pad_to_chunk = step + self.prefill_chunk_size - query_length
- if inputs_embeds is not None:
- input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
- else:
- input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
-
- cache_position = torch.cat(
- [
- cache_position,
- torch.arange(
- query_length,
- step + self.prefill_chunk_size,
- dtype=torch.int32,
- ).unsqueeze(0),
- ],
- dim=-1,
- )
-
- # slice input_tensor & cache_position with prefill_chunk_size
- _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
- _cache_position = cache_position[:, step : step + self.prefill_chunk_size]
-
- # update attention_mask
- if step >= self.prefill_chunk_size:
- _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
- _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
- query_position = (query_length - 1) % self.prefill_chunk_size
-
- logits = self.prefill_decoder(
- input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
- inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
- attention_mask=_attention_mask.contiguous(),
- cache_position=_cache_position.contiguous(),
- batch_position=torch.tensor(batch_idx, dtype=torch.int16),
- query_position=torch.tensor(query_position, dtype=torch.int16),
- out=out_buffers,
- )
-
- # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
- self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
- self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
- return logits
-
- def _forward_decoder(
- self,
- input_ids: torch.LongTensor = None,
- inputs_embeds: torch.Tensor = None,
- cache_position: torch.Tensor = None,
- ) -> torch.FloatTensor:
- input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
- if input_tensors is None:
- raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
- batch_size = input_tensors.shape[0]
- if batch_size != self.batch_size:
- raise RuntimeError(
- f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
- )
-
- if batch_size != cache_position.shape[0]:
- raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
- for b_idx in range(batch_size):
- decoding_step = cache_position[b_idx].item()
- if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
- raise ValueError(
- f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
- )
- self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
- logits = self.decoder(
- input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
- inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
- attention_mask=self.dec_attn_mask.contiguous(),
- cache_position=cache_position.contiguous(),
- )
-
- return logits
src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -25,7 +25,6 @@ from transformers import (
  PreTrainedModel,
  )
  from transformers.modeling_outputs import BaseModelOutputWithPooling
- from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast
 
  from ....modeling import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -337,7 +336,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  generate_idx: Optional[torch.Tensor] = None,
  batch_idx: Optional[int] = None,
  **kwargs,
- ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+ ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
  vision_feature_layer = (
  vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  )
@@ -378,7 +377,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  inputs_embeds = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(batch_size)]
  for batch_idx in range(batch_size):
  generate_idx[batch_idx] = inputs_embeds[batch_idx].shape[-2]
- logit = self.language_model._forward_prefill(
+ logit = self.language_model.prefill_decoder(
  inputs_embeds=inputs_embeds[batch_idx],
  batch_idx=batch_idx,
  cache_position=torch.arange(
@@ -390,15 +389,13 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
  logits.append(logit)
  logits = torch.cat(logits, dim=0)
- outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
  else:
- outputs: RBLNDecoderOnlyOutput = self.language_model(
+ logits = self.language_model.decoder(
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
- generate_idx=generate_idx,
  )
 
- return outputs
+ return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
 
  # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
  def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
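
The LLaVA-Next path above reduces a padded batch to per-sequence embeddings (`attention_mask[i].bool()` slicing) before calling the one-at-a-time prefill. A minimal sketch of just that slicing step, with made-up shapes and plain tensors (no RBLN runtime):

# Each padded row is reduced to its valid tokens before prefill.
import torch

inputs_embeds = torch.randn(2, 6, 8)                      # [batch, padded_len, hidden]
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

per_sequence = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(2)]
print([t.shape for t in per_sequence])                     # [1, 3, 8] and [1, 5, 8]
generate_idx = torch.tensor([t.shape[-2] for t in per_sequence])  # prompt lengths: tensor([3, 5])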
src/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
+ from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
  from ....modeling import RBLNModel
@@ -31,12 +31,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime
  logger = get_logger(__name__)
 
  if TYPE_CHECKING:
- from transformers import (
- AutoFeatureExtractor,
- AutoProcessor,
- AutoTokenizer,
- PretrainedConfig,
- )
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, GenerationConfig, PretrainedConfig
 
 
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -50,9 +45,50 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
  mandatory_members = ["main_input_name"]
 
- def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
- outputs = super().forward(*args, **kwargs)
- return Seq2SeqLMOutput(logits=outputs)
+ def __init__(
+ self,
+ runtime: rebel.Runtime,
+ batch_size: int,
+ dec_max_seq_len: int,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(runtime, **kwargs)
+ self.batch_size = batch_size
+ self.dec_max_seq_len = dec_max_seq_len
+
+ def forward(
+ self,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor]:
+ batch_size = decoder_input_ids.shape[0]
+ if batch_size != self.batch_size:
+ raise RuntimeError(
+ f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+ )
+
+ if batch_size != cache_position.shape[0]:
+ raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+ for b_idx in range(self.batch_size):
+ decoding_step = cache_position[b_idx].item()
+ if not (0 <= decoding_step < self.dec_max_seq_len):
+ raise ValueError(
+ f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+ )
+ decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+ lm_logits = super().forward(
+ decoder_input_ids,
+ decoder_attention_mask,
+ attention_mask,
+ cache_position,
+ )
+
+ return Seq2SeqLMOutput(logits=lm_logits)
 
 
  class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
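
A tiny illustration of the per-step decoder mask update shown above: at decoding step t, positions 0..t are marked visible for that batch row. Standalone sketch with plain tensors and arbitrary sizes.

# Decoder attention mask update per decoding step.
import torch

dec_max_seq_len = 8
decoder_attention_mask = torch.zeros(2, dec_max_seq_len)
cache_position = torch.tensor([[3], [5]])   # current decoding step for each batch row

for b_idx in range(2):
    decoding_step = cache_position[b_idx].item()
    decoder_attention_mask[b_idx, : decoding_step + 1] = 1

print(decoder_attention_mask)
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 0., 0.]])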
@@ -72,8 +108,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  auto_model_class = AutoModelForSeq2SeqLM
 
  def __post_init__(self, **kwargs):
- self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
- self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
+ batch_size = self.rbln_config.model_cfg["batch_size"]
+ dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+ self.encoder = RBLNRuntimeEncoder(
+ runtime=self.model[0],
+ main_input_name="input_ids",
+ )
+ self.decoder = RBLNRuntimeDecoder(
+ runtime=self.model[1], main_input_name="input_ids", batch_size=batch_size, dec_max_seq_len=dec_max_seq_len
+ )
 
  @classmethod
  @torch.inference_mode()
@@ -304,46 +347,24 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
  def forward(
  self,
- input_ids: torch.LongTensor = None,
+ decoder_input_ids: torch.LongTensor = None,
  cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
  # common decoder
  cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
- logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs).logits
+ logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
 
  return Seq2SeqLMOutput(
  logits=logits,
  )
 
- def _forward_decoder(
- self,
- attention_mask: Optional[torch.FloatTensor] = None,
- decoder_input_ids: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
- cache_position: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> Tuple[torch.FloatTensor]:
- dec_attention_mask = decoder_attention_mask.clone()
- for b_idx in range(self.rbln_config.model_cfg["batch_size"]):
- dec_attention_mask[b_idx, : cache_position[b_idx] + 1] = 1
-
- decoder_output = self.decoder(
- input_ids=decoder_input_ids,
- attention_mask=dec_attention_mask,
- encoder_attention_mask=attention_mask,
- cache_position=cache_position,
- )
- lm_logits = decoder_output.logits
-
- return Seq2SeqLMOutput(logits=lm_logits)
-
  def _prepare_encoder_decoder_kwargs_for_generation(
  self,
  inputs_tensor: torch.Tensor,
  model_kwargs,
  model_input_name: Optional[str] = None,
- generation_config: Optional[GenerationConfig] = None,
+ generation_config: Optional["GenerationConfig"] = None,
  ) -> Dict[str, Any]:
  # 1. get encoder
  encoder = self.get_encoder()
@@ -373,6 +394,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  )
 
  # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
  encoder_kwargs["return_dict"] = True
  encoder_kwargs["output_hidden_states"] = False
  encoder_kwargs["output_attentions"] = False