optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
@@ -0,0 +1,508 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import deque
+from typing import Any, Optional
+
+import rebel
+import torch
+import torch.nn.functional as F
+
+from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNPageTableManager:
+    EMPTY_BLOCK = -1
+    NO_BLOCKS_ERROR = (
+        "No memory blocks are available for allocation. "
+        "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
+        "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
+        "Using vllm-rbln should fix this issue and enhance inference performance."
+    )
+
+    def __init__(self, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+        self.rbln_config = rbln_config
+        self.block_tables = torch.zeros(
+            self.rbln_config.batch_size,
+            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+            dtype=torch.int16,
+        ).fill_(self.EMPTY_BLOCK)
+        self.free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
+    def update_block(self, batch_idx: int, block_idx: int):
+        """
+        If the block is empty (EMPTY_BLOCK), allocates a block from the free_block_pool.
+        """
+        if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+            raise IndexError(
+                f"Invalid index (batch_idx={batch_idx}, block_idx={block_idx}): "
+                f"BlockTable shape (batch_axis, block_axis): {self.block_tables.shape}, BlockSize: {self.rbln_config.kvcache_block_size}"
+            )
+
+        if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
+            if self.free_block_pool:
+                block = self.free_block_pool.popleft()
+                self.block_tables[batch_idx][block_idx] = block
+            else:
+                raise RuntimeError(self.NO_BLOCKS_ERROR)
+
+    def replace_empty_block(self, block_tables: torch.Tensor):
+        """
+        Replaces all occurrences of `self.EMPTY_BLOCK` in `block_tables` with a dummy block from `self.free_block_pool`.
+        """
+        if not torch.any(block_tables == self.EMPTY_BLOCK):
+            return block_tables.clone()
+        elif self.free_block_pool:
+            _free_block = self.free_block_pool[0]
+            return torch.where(block_tables == self.EMPTY_BLOCK, _free_block, block_tables)
+        else:
+            raise RuntimeError(self.NO_BLOCKS_ERROR)
+
+    def get_block_tables(
+        self, cache_position: torch.Tensor, batch_idx: int = None, batch_size: int = None, phase: str = "prefill"
+    ) -> torch.Tensor:
+        """
+        Manages and returns the KV cache block tables.
+        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+        Args:
+            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+            batch_size (int, optional): Number of batch items, used when phase is 'decode'.
+            phase (str): Either 'prefill' or 'decode'.
+
+        Returns:
+            Updated block tables.
+        """
+
+        def get_global_block_tables():
+            if not self.rbln_config.use_global_attention:
+                return None
+
+            if phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.EMPTY_BLOCK].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.EMPTY_BLOCK)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    self.update_block(batch_idx, block_idx)
+
+                return self.replace_empty_block(self.block_tables[batch_idx])
+            # Case for the 'decode' phase: iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    self.update_block(b_idx, block_idx)
+
+                return self.replace_empty_block(self.block_tables)
+
+        def get_local_block_tables():
+            if not self.rbln_config.use_local_attention:
+                return None
+            else:
+                return (
+                    torch.tensor([batch_idx], dtype=torch.int16)
+                    if phase == "prefill"
+                    else torch.arange(batch_size, dtype=torch.int16).view(batch_size, -1)
+                )
+
+        return get_global_block_tables(), get_local_block_tables()
+
+    # Whether block_tables and local_block_tables are provided by the user
+    def is_external_block_tables(
+        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+    ):
+        if self.rbln_config.cache_impl == "static" and block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "hybrid":
+            if (block_tables is not None) != (local_block_tables is not None):
+                raise ValueError(
+                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
+                )
+            elif block_tables is None and local_block_tables is None:
+                return False
+
+        return True
+
+    def get_block_tables_if_needed(
+        self,
+        batch_size,
+        cache_position: torch.Tensor,
+        batch_idx: int = None,
+        phase: str = "prefill",
+        block_tables: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+    ):
+        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+        if not is_external_block_tables:
+            block_tables, local_block_tables = self.get_block_tables(
+                cache_position, batch_idx=batch_idx, batch_size=batch_size, phase=phase
+            )
+
+        return block_tables, local_block_tables, is_external_block_tables
+
+
+class RBLNRuntimeModel(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name", "embed_tokens"]
+
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        phase: str,
+        batch_size: int,
+        dec_attn_mask: torch.Tensor,
+        page_table_manager: RBLNPageTableManager,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        out_buffers: Optional[torch.Tensor] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.phase = phase
+        self.batch_size = batch_size
+        self.rbln_config = rbln_config
+
+        # shared resources between prefill and decode phase
+        self.dec_attn_mask = dec_attn_mask
+        self.page_table_manager = page_table_manager
+
+        if self.phase == "prefill":
+            self.out_buffers = out_buffers
+            self.causal_mask = 1 - torch.triu(
+                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
+            )
+
+        self.lora_int_ids = None
+
+    def inputs_embeddings_if_needed(
+        self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+    ):
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
+
+        if self.rbln_config.use_inputs_embeds:
+            return self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
+        else:
+            return input_ids
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: Optional[int] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
+    ):
+        inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
+        block_tables, local_block_tables, is_external_block_tables = (
+            self.page_table_manager.get_block_tables_if_needed(
+                self.batch_size,
+                cache_position,
+                batch_idx=batch_idx,
+                phase=self.phase,
+                block_tables=block_tables,
+                local_block_tables=local_block_tables,
+            )
+        )
+
+        if self.phase == "decode":
+            return self.decode_forward(
+                inputs,
+                cache_position,
+                block_tables,
+                is_external_block_tables,
+                attention_mask=attention_mask,
+                position_embed=position_embed,
+                position_ids=position_ids,
+                local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
+            )
+        else:
+            return self.prefill_forward(
+                inputs,
+                cache_position,
+                attention_mask,
+                batch_idx,
+                block_tables,
+                is_external_block_tables=is_external_block_tables,
+                position_embed=position_embed,
+                token_type_ids=token_type_ids,
+                local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
+            )
+
+    def decode_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            lora_int_ids = self.lora_int_ids
+
+        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
+        if self.batch_size != cache_position.shape[0]:
+            raise RuntimeError(
+                f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
+            )
+
+        if self.rbln_config.use_attention_mask and attention_mask is None:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+
+                if is_external_block_tables:
+                    self.dec_attn_mask[b_idx].fill_(0)
+                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                else:
+                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+            attention_mask = self.dec_attn_mask
+
+        logits = super().forward(
+            inputs,
+            cache_position,
+            block_tables,
+            local_block_tables,
+            position_embed,
+            attention_mask if self.rbln_config.use_attention_mask else None,
+            position_ids if self.rbln_config.use_position_ids else None,
+            lora_int_ids if self.rbln_config.use_lora else None,
+        )
+
+        return RBLNDecoderOnlyOutput(logits=logits)
+
+    def _prepare_prefill_inputs(
+        self,
+        inputs: torch.Tensor,
+        cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Prepare inputs for prefill phase.
+        """
+        # Handle continuous batching in a compiled graph by extracting valid inputs
+        # If an attention mask is provided, select only the valid (non-masked) inputs
+        if attention_mask is not None:
+            inputs = inputs[:, attention_mask.bool()]
+            position_embed = None if position_embed is None else position_embed[:, :, :, attention_mask.bool(), :]
+            token_type_ids = None if token_type_ids is None else token_type_ids[:, attention_mask.bool()]
+
+        query_length = inputs.shape[1]
+        if query_length > self.rbln_config.max_seq_len:
+            raise ValueError(
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
+            )
+
+        # Initialize attention mask for chunked processing
+        chunked_attention_mask = (
+            torch.zeros(
+                1,
+                1,
+                self.rbln_config.prefill_chunk_size,
+                self.rbln_config.max_seq_len,
+                dtype=self.rbln_config.torch_dtype,
+            )
+            if self.rbln_config.use_attention_mask
+            else None
+        )
+
+        cache_position = (
+            torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
+        )
+        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        if padding_size > 0:
+            inputs = (
+                F.pad(inputs, (0, 0, 0, padding_size))
+                if self.rbln_config.use_inputs_embeds
+                else F.pad(inputs, (0, padding_size))
+            )
+            position_embed = F.pad(position_embed, (0, 0, 0, padding_size)) if position_embed is not None else None
+            token_type_ids = F.pad(token_type_ids, (0, padding_size), value=-1) if token_type_ids is not None else None
+            cache_position = F.pad(cache_position, (0, padding_size))
+
+        # Overwrite position_ids and padded_cache_lengths
+        position_ids = cache_position.clone() if self.rbln_config.use_position_ids else None
+        padded_cache_lengths = 0
+
+        return (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        )
+
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: Optional[int] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        is_external_block_tables: Optional[bool] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            if batch_idx is not None:
+                lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+            else:
+                lora_int_ids = self.lora_int_ids.clone()
+
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        ) = self._prepare_prefill_inputs(
+            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+        )
+
+        # It is assumed that prefix caching was performed externally if cache_position doesn't start from 0.
+        prefix_cached_len = cache_position[0][0].item()
+        if prefix_cached_len > 0:
+            if prefix_cached_len % self.rbln_config.prefill_chunk_size != 0:
+                raise NotImplementedError(
+                    "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
+                )
+            if self.rbln_config.use_attention_mask:
+                chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
+
+        # Process input in chunks of size `prefill_chunk_size`
+        output_logits = []
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+            s, e = step, step + self.rbln_config.prefill_chunk_size
+            # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
+            input_chunk = inputs[:, s:e]
+            cache_pos_chunk = cache_position[:, s:e]
+            position_ids_chunk = position_ids[:, s:e] if self.rbln_config.use_position_ids else None
+            position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None
+
+            # Update attention mask to ensure proper causal behavior
+            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
+                if step > 0:  # update previous chunk
+                    chunked_attention_mask[
+                        :,
+                        :,
+                        :,
+                        s - self.rbln_config.prefill_chunk_size + prefix_cached_len : e
+                        - self.rbln_config.prefill_chunk_size
+                        + prefix_cached_len,
+                    ] = 1
+                chunked_attention_mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = self.causal_mask
+
+            # Calculate query position if needed
+            if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
+                query_position = (
+                    torch.tensor((query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16)
+                    if e >= query_length
+                    else torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
+                )
+            else:
+                query_position = None
+
+            # Forward pass for the current chunk
+            output_logit = super().forward(
+                input_chunk,
+                cache_pos_chunk,
+                block_tables,
+                local_block_tables,
+                position_embed_chunk,
+                query_position,
+                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                position_ids_chunk,
+                lora_int_ids if self.rbln_config.use_lora else None,
+                out=self.out_buffers,
+            )
+            output_logits.append(output_logit)
+
+        # Aggregate output_logits
+        output_logits = torch.concat(output_logits, dim=-2)
+        if self.rbln_config.logits_to_keep > 0:
+            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
+        else:
+            output_logits = output_logits[:, :query_length, :]
+            # index copy for masked output_logits
+            if attention_mask is not None:
+                new_output_logits = torch.full(
+                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                    fill_value=1e-10,
+                    dtype=output_logits.dtype,
+                )
+                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+
+                output_logits = new_output_logits
+
+        # Update decoder attention mask with processed KV-cache length from prefill phase
+        if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
+            self.dec_attn_mask[batch_idx].fill_(0)
+            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+
+        return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
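
The paged KV-cache bookkeeping above reduces to one mapping: position `p` in a sequence's cache lives in logical block `p // kvcache_block_size`, and a logical block is lazily bound to a physical block popped from `free_block_pool` the first time it is touched. A minimal standalone sketch of that arithmetic, with illustrative sizes (the real values come from RBLNDecoderOnlyModelForCausalLMConfig):

from collections import deque

import torch

# Illustrative sizes only; not values from the package.
BATCH_SIZE, MAX_SEQ_LEN, BLOCK_SIZE, NUM_BLOCKS = 2, 32, 8, 8
EMPTY_BLOCK = -1

block_tables = torch.full((BATCH_SIZE, MAX_SEQ_LEN // BLOCK_SIZE), EMPTY_BLOCK, dtype=torch.int16)
free_block_pool = deque(range(NUM_BLOCKS))

def touch(batch_idx: int, position: int) -> None:
    # Same mapping as RBLNPageTableManager.update_block: bind the logical
    # block covering `position` to a free physical block on first use.
    block_idx = position // BLOCK_SIZE
    if block_tables[batch_idx, block_idx] == EMPTY_BLOCK:
        block_tables[batch_idx, block_idx] = free_block_pool.popleft()

# Prefill of a 20-token prompt for sequence 0 touches logical blocks 0..2,
# mirroring the range(s, e + 1, kvcache_block_size) loop in get_block_tables().
for pos in range(0, 20, BLOCK_SIZE):
    touch(0, pos)
# Decode at position 20 reuses block 2; position 24 binds a new block.
touch(0, 20)
touch(0, 24)
print(block_tables[0])  # tensor([0, 1, 2, 3], dtype=torch.int16)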
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py
@@ -0,0 +1,119 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+import torch
+from transformers import GenerationConfig
+from transformers.generation.utils import GenerationMixin
+from transformers.modeling_outputs import ModelOutput
+
+
+if TYPE_CHECKING:
+    from ...modeling_outputs import RBLNDecoderOnlyOutput
+
+
+class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
+    _supports_cache_class = False  # Needed for GenerationMixin
+    _is_stateful = False  # Needed for GenerationMixin
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        raise NotImplementedError
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generate_idx: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        is_prefill_phase = generate_idx is None
+
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            padded_cache_lengths = torch.zeros_like(generate_idx)
+            cache_position = None
+            position_ids = None
+        else:
+            if inputs_embeds is not None:
+                # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
+                inputs_embeds = None
+
+            input_ids = input_ids[:, -1:]
+            position_ids = generate_idx
+            cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        if inputs_embeds is not None:
+            if self.rbln_config.use_inputs_embeds:
+                model_inputs.update({"inputs_embeds": inputs_embeds})
+            else:
+                raise ValueError(
+                    "Specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                )
+        else:
+            model_inputs.update({"input_ids": input_ids})
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
+                "position_ids": position_ids,
+                "padded_cache_lengths": padded_cache_lengths,
+            }
+        )
+
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(
+        self, outputs: "RBLNDecoderOnlyOutput", model_kwargs: Dict[str, Any], **kwargs
+    ) -> Dict[str, Any]:
+        # update generate_idx
+        model_kwargs["generate_idx"] = outputs.generate_idx
+        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
+        return model_kwargs
+
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        **kwargs,
+    ) -> Union[ModelOutput, torch.LongTensor]:
+        """
+        The generate function is used in its standard form, as in the HuggingFace transformers library. Users can call this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
+
+        Args:
+            input_ids (torch.LongTensor): The input ids to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            generation_config (GenerationConfig, optional): The generation configuration to be used as the base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+        Returns:
+            A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a torch.LongTensor.
+        """
+        if generation_config is not None:
+            kwargs["generation_config"] = generation_config
+        if attention_mask is not None:
+            kwargs["attention_mask"] = attention_mask
+
+        return super().generate(input_ids, **kwargs)
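
RBLNDecoderOnlyGenerationMixin.generate() keeps the standard transformers interface, so a compiled decoder-only model is driven like any HuggingFace model: prepare_inputs_for_generation() derives generate_idx from the attention mask on the prefill step and then feeds one token per decode step. A usage sketch, assuming a Llama checkpoint (the checkpoint id and rbln_* options below are illustrative, not prescribed by this file):

from transformers import AutoTokenizer

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,            # compile the checkpoint for RBLN NPUs
    rbln_batch_size=1,      # illustrative compile-time options
    rbln_max_seq_len=4096,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Paged attention keeps the KV cache in", return_tensors="pt")
# generate() follows the GenerationMixin contract defined above; the
# attention mask seeds generate_idx in prepare_inputs_for_generation().
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))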