optimum-rbln 0.9.3.post1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -0,0 +1,823 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+
+ import rebel
+ import torch
+ from rebel.compile_context import CompileContext
+ from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+ from transformers.modeling_utils import no_init_weights
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...modeling_attention_utils import (
+     RBLNDecoderOnlyFlashAttentionMixin,
+     set_default_values,
+     validate_attention_method,
+     validate_sliding_window,
+ )
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ...utils.rbln_quantization import get_quantized_model
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .decoderonly_architecture import DecoderOnlyWrapper
+ from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+ from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+
+
+ logger = get_logger()
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+
+ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
+     """
+     A base class for decoder-only transformer models outputting raw hidden states without any specific head on top.
+     It is used for RBLN-optimized models that are not causal language models, and serves as the foundation for
+     various decoder-only architectures like GPT, LLaMA, etc.
+
+     The class provides core functionality for:
+
+     1. Converting pre-trained transformer models to an RBLN-optimized format
+     2. Handling the compilation process for RBLN devices
+     3. Managing inference operations for decoder-only architectures
+
+     This class inherits from RBLNModel and implements the specific methods required for
+     decoder-only architectures.
+
+     Note:
+         - This class is designed to be subclassed by specific model implementations
+           (e.g., RBLNLlamaModel, RBLNQwen2Model)
+         - Subclasses should implement model-specific conversion logic.
+         - The class handles RBLN-specific optimizations automatically during compilation.
+     """
+
+     _tp_support = True
+
+     main_input_name = "input_ids"
+     auto_model_class = AutoModel
+     _decoder_wrapper_cls = DecoderOnlyWrapper
+     _use_rotary_emb = True
+     _supports_non_fp32 = True
+
+     def __post_init__(self, **kwargs):
+         if self.rbln_config.use_inputs_embeds:
+             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+             self.embed_tokens = self._create_embedding_layer()
+             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         else:
+             self.embed_tokens = None
+
+         self.setup_runtime()
+
+     def setup_runtime(self):
+         # Initialize resources to be shared across Runtime instances (prefill and decode phases)
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
+         out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }
+         self.prefill_decoder = RBLNRuntimeModel(
+             runtime=self.model[0],
+             phase="prefill",
+             batch_size=self.rbln_config.batch_size,
+             out_buffers=out_buffers,
+             **common_kwargs,
+         )
+         if self.can_generate():
+             self.decoders = {}
+             for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                 self.decoders[batch_size] = RBLNRuntimeModel(
+                     runtime=self.model[i + 1],
+                     phase="decode",
+                     batch_size=batch_size,
+                     **common_kwargs,
+                 )
+
+             # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+             self.decoder = self.decoders[self.rbln_config.batch_size]
+
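For orientation, a short annotated sketch (values assumed, not taken from this release) of how the compiled runtimes map onto these wrappers when `rbln_config.batch_size == 1` and `rbln_config.decoder_batch_sizes == [1, 4]`:

    # Assumed: batch_size == 1, decoder_batch_sizes == [1, 4]
    # self.model[0] -> prefill runtime         -> self.prefill_decoder
    # self.model[1] -> decode runtime (bs = 1) -> self.decoders[1]
    # self.model[2] -> decode runtime (bs = 4) -> self.decoders[4]
    # self.decoder == self.decoders[1]         # matches rbln_config.batch_size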
+     @property
+     def prefill_output_size(self):
+         return (
+             1,
+             self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+             self.config.hidden_size,
+         )
+
+     @classmethod
+     def get_quantized_model(
+         cls,
+         model_id: str,
+         config: Optional[PretrainedConfig] = None,
+         use_auth_token: Optional[Union[bool, str]] = None,
+         revision: Optional[str] = None,
+         force_download: bool = False,
+         cache_dir: Optional[str] = None,
+         subfolder: str = "",
+         local_files_only: bool = False,
+         trust_remote_code: bool = False,
+         rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+         **kwargs,
+     ):
+         kwargs = cls.update_kwargs(kwargs)
+
+         return get_quantized_model(
+             cls.auto_model_class,
+             model_id,
+             use_auth_token=use_auth_token,
+             revision=revision,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             local_files_only=local_files_only,
+             rbln_quantization=rbln_config.quantization,
+             **kwargs,
+         )
+
+     def __getattr__(self, __name: str) -> Any:
+         # Special method to delegate attribute access to the original Huggingface LM class.
+         # This method is called when an attribute is not found in the current instance's dictionary.
+         # It enables transparent access to the original model's attributes and methods while maintaining
+         # proper method binding.
+
+         # The method implements a delegation pattern that:
+
+         # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+         # 2. For other attributes: Returns them directly from the original class
+
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
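The delegation above mirrors a standard Python pattern. A minimal, self-contained sketch (toy classes, not part of this package) showing why the `redirect` wrapper is needed when borrowing plain functions off a class:

    import inspect
    from typing import Any, Callable

    class HFOriginal:  # stands in for the original Hugging Face class
        def describe(self, suffix: str) -> str:
            return f"{self.__class__.__name__}-{suffix}"

    class RBLNProxy:
        def __getattr__(self, __name: str) -> Any:
            def redirect(func):
                # Bind this proxy instance as the `self` of the borrowed function
                return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

            val = getattr(HFOriginal, __name)  # a plain function, not a bound method
            if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
                return redirect(val)
            return val

    print(RBLNProxy().describe("demo"))  # -> "RBLNProxy-demo"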
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: PreTrainedModel,
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ):
+         # If you unavoidably need to run parts of the model on a CPU rather than an RBLN device,
+         # store the torch tensors, weights, etc. in this function.
+         if rbln_config.use_inputs_embeds:
+             save_dict = {}
+             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = torch.nn.Embedding(
+                 self.config.vocab_size,
+                 self.config.hidden_size,
+                 self.config.pad_token_id,
+             )
+         return embed_tokens
+
+     def get_decoder(self):
+         if not self.can_generate():
+             raise ValueError("Decode stage is not supported in this model.")
+         return self.decoder
+
+     def can_generate(self):
+         return self.rbln_config.can_generate
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.kvcache_num_blocks
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
+         return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
+
+     @classmethod
+     def _compile_model(
+         cls,
+         wrapped_model,
+         compile_config,
+         example_inputs,
+         compile_context,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+         quantization=None,
+         phase: str = "prefill",
+     ):
+         try:
+             wrapped_model.phase = phase
+             if quantization:
+                 quantization.maybe_set_quantization_env()
+             original_linear = torch.nn.functional.linear
+             torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+             compiled_model = cls.compile(
+                 wrapped_model,
+                 compile_config,
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device,
+                 example_inputs=example_inputs,
+                 compile_context=compile_context,
+             )
+             return compiled_model
+         finally:
+             torch.nn.functional.linear = original_linear
+             if quantization:
+                 quantization.maybe_reset_quantization_env()
+
+     @classmethod
+     def _get_compile_context(
+         cls,
+         compile_config: RBLNCompileConfig,
+         example_inputs: List[torch.Tensor],
+     ):
+         context = CompileContext(use_weight_sharing=True)
+
+         # Mark static tensors (self KV states)
+         static_tensors = {}
+         idx = 0
+         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
+             if "past_key_values" in name:
+                 static_tensors[name] = tensor
+                 context.mark_static_address(tensor, f"kv_cache_{idx}")
+                 idx += 1
+
+         return context, static_tensors
+
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
+         prefill_compile_config = rbln_config.compile_cfgs[0]
+
+         # Here we use meta tensors, for memory efficiency.
+         meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+         context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
+
+         compiled_models = {}
+         compiled_models["prefill"] = cls._compile_model(
+             wrapped_model,
+             prefill_compile_config,
+             prefill_example_inputs,
+             context,
+             rbln_config,
+             rbln_config.quantization,
+             phase="prefill",
+         )
+
+         if rbln_config.can_generate:
+             wrapped_model.phase = "decode"
+             for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+                 dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+                 compiled_decoder = cls._compile_model(
+                     wrapped_model,
+                     dec_compile_config,
+                     dec_example_inputs,
+                     context,
+                     rbln_config,
+                     rbln_config.quantization,
+                     phase="decode",
+                 )
+                 compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+         # Check whether memory is sufficient to hold additional blocks
+         required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+         if rbln_config.kvcache_num_blocks < required_num_blocks:
+             cls.maybe_suggest_kvcache_num_blocks(
+                 compiled_models=compiled_models,
+                 model_config=model.config,
+                 rbln_config=rbln_config,
+             )
+
+         return compiled_models
+
+     @classmethod
+     def get_pytorch_model(
+         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+     ) -> PreTrainedModel:
+         if rbln_config and rbln_config.quantization:
+             model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
+         else:
+             model = super().get_pytorch_model(*args, **kwargs)
+
+         return model
+
+     @classmethod
+     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+         return use_local_attention
+
+     @classmethod
+     def get_input_info(
+         cls,
+         batch_size: int,
+         query_length: int,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+         model_config: PretrainedConfig,
+     ):
+         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
+         is_prefill = query_length > 1
+
+         input_info = []
+         if rbln_config.use_inputs_embeds:
+             input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
+         else:
+             input_info.append(("input_ids", [batch_size, query_length], "int64"))
+
+         input_info.append(("cache_position", [batch_size, query_length], "int32"))
+
+         if rbln_config.use_global_attention:
+             max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
+             input_info.append(
+                 ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
+             )
+         if rbln_config.use_local_attention:
+             input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))
+
+         if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+             input_info.append(("query_position", [], "int16"))
+
+         if rbln_config.use_attention_mask:
+             if rbln_config.use_position_ids:
+                 input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+             else:
+                 input_info.append(
+                     ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                 )
+
+         if rbln_config.use_position_ids:
+             input_info.append(("position_ids", [batch_size, query_length], "int32"))
+
+         if rbln_config.use_lora:
+             input_info.append(("lora_int_ids", [batch_size], "int32"))
+
+         kvcache_dtype = rbln_config.torch_dtype
+         if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+             kvcache_dtype = "float8_e4m3fn"
+
+         global_kvcache_shape = [
+             rbln_config.kvcache_num_blocks,
+             num_key_value_heads,
+             rbln_config.kvcache_block_size,
+             head_dim,
+         ]
+         local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
+         input_info.extend(
+             [
+                 (
+                     f"past_key_values_{i}",
+                     local_kvcache_shape
+                     if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
+                     else global_kvcache_shape,
+                     kvcache_dtype,
+                 )
+                 for i in range(num_hidden_layers * 2)
+             ]
+         )
+
+         return input_info
+
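To illustrate the shape of the result, consider a hypothetical prefill graph for an RBLNDecoderOnlyModel, assuming batch_size=1, query_length=128, max_seq_len=4096, kvcache_block_size=4096, kvcache_num_blocks=1, a 2-layer model with 8 KV heads, head_dim=64, float32, global attention only, and no attention mask, position ids, or LoRA (all values assumed, for illustration only). The method would return roughly:

    [
        ("input_ids",         [1, 128],         "int64"),
        ("cache_position",    [1, 128],         "int32"),
        ("block_tables",      [1],              "int16"),    # 1-D during prefill
        ("past_key_values_0", [1, 8, 4096, 64], "float32"),  # layer 0 key cache
        ("past_key_values_1", [1, 8, 4096, 64], "float32"),  # layer 0 value cache
        ("past_key_values_2", [1, 8, 4096, 64], "float32"),  # layer 1 key cache
        ("past_key_values_3", [1, 8, 4096, 64], "float32"),  # layer 1 value cache
    ]
    # RBLNDecoderOnlyModelForCausalLM would additionally emit ("query_position", [], "int16")
    # during prefill, since its use_query_position() returns is_prefill.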
+     @classmethod
+     def _update_sliding_window_config(
+         cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
+         # Update the sliding window configuration for the RBLN model.
+
+         # This method must be implemented by subclasses to handle their specific sliding window configurations,
+         # as Hugging Face models use different configuration keys to represent sliding window layers.
+
+         # Args:
+         #     model_config (PretrainedConfig): The model configuration from Hugging Face.
+         #     rbln_config (RBLNDecoderOnlyModelForCausalLMConfig): The RBLN model configuration.
+
+         # Notes:
+         #     Required configuration settings:
+         #     - `cache_impl`: Must be one of:
+         #         - "static": All layers use global attention (no sliding window)
+         #         - "sliding_window": All layers use sliding window attention
+         #         - "hybrid": A mix of global and sliding window attention layers
+         #     - `sliding_window`: Width of the sliding window (required if cache_impl is "sliding_window" or "hybrid")
+         #     - `sliding_window_layers`: List of layer indices using sliding window attention (required if cache_impl is "hybrid")
+
+         #     Example implementation for a 'sliding_window' model:
+         #     ```python
+         #     rbln_config.cache_impl = "sliding_window"
+         #     rbln_config.sliding_window = model_config.sliding_window
+         #     rbln_config.sliding_window_layers = [i for i in range(model_config.num_hidden_layers)]
+         #     return rbln_config
+         #     ```
+
+         # Returns:
+         #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
+
+         raise NotImplementedError(
+             "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
+             "See method docstring for required configuration details."
+         )
+
+     @classmethod
+     def _update_attention_config(
+         cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
+         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+             attn_impl=rbln_config.attn_impl,
+             kvcache_partition_len=rbln_config.kvcache_partition_len,
+             kvcache_block_size=rbln_config.kvcache_block_size,
+             max_seq_len=rbln_config.max_seq_len,
+         )
+
+         validate_attention_method(
+             attn_impl=rbln_config.attn_impl,
+             kvcache_partition_len=rbln_config.kvcache_partition_len,
+             kvcache_block_size=rbln_config.kvcache_block_size,
+             max_seq_len=rbln_config.max_seq_len,
+         )
+
+         num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+
+         # Update kvcache_num_blocks based on the attention implementation.
+         if rbln_config.attn_impl == "flash_attn":
+             estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
+                 model=model, model_config=model_config, rbln_config=rbln_config
+             )
+
+             if rbln_config.kvcache_num_blocks is None:
+                 if estimated_max_num_blocks < num_full_blocks:
+                     # Lower bound of the number of blocks for flash attention.
+                     min_blocks_for_flash = min(
+                         rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+                     )
+                     if min_blocks_for_flash > estimated_max_num_blocks:
+                         # NOTE: Just try to compile with the lower bound of blocks for flash attention,
+                         # even if it's larger than the estimated maximum number of blocks.
+                         rbln_config.kvcache_num_blocks = min_blocks_for_flash
+                     else:
+                         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+                         rbln_config.kvcache_num_blocks = estimated_max_num_blocks
+
+                     if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+                         raise RuntimeError(
+                             f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
+                             "Ensure the number of blocks is at least equal to the batch size."
+                         )
+                 else:
+                     rbln_config.kvcache_num_blocks = num_full_blocks
+             elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
+                 logger.warning(
+                     f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                     f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
+                     " This can cause a failure during model compilation."
+                 )
+         else:
+             if rbln_config.kvcache_num_blocks is None:
+                 rbln_config.kvcache_num_blocks = num_full_blocks
+             elif rbln_config.kvcache_num_blocks > num_full_blocks:
+                 logger.warning(
+                     f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                     f" than the required number of blocks ({num_full_blocks})."
+                     " This can cause a failure during model compilation."
+                 )
+         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+
+         return rbln_config
+
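To make the block arithmetic concrete, a worked example under assumed values (none of these numbers come from the release itself):

    # Assumed values, for illustration only:
    max_seq_len = 8192
    kvcache_block_size = 1024
    batch_size = 2

    # One KV block per kvcache_block_size tokens, per sequence in the batch:
    num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size  # (8192 // 1024) * 2 = 16

    # Flash attention can fall back to a lower bound of one extra block
    # beyond a single sequence's worth of blocks:
    min_blocks_for_flash = min(max_seq_len // kvcache_block_size + 1, num_full_blocks)  # min(9, 16) = 9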
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+         model: Optional[PreTrainedModel] = None,
+         model_config: Optional[PretrainedConfig] = None,
+         rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+     ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+         if rbln_config.max_seq_len is None:
+             rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                 model_config, "n_positions", None
+             )
+         if rbln_config.max_seq_len is None:
+             raise ValueError("`max_seq_len` should be specified.")
+
+         if getattr(model_config, "sliding_window", None) is not None and getattr(
+             model_config, "use_sliding_window", True
+         ):
+             rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+             if rbln_config.sliding_window is not None:
+                 validate_sliding_window(rbln_config)
+
+         rbln_config = cls._update_attention_config(model, model_config, rbln_config)
+
+         prefill_input_info = cls.get_input_info(
+             batch_size=1,
+             query_length=rbln_config.prefill_chunk_size,
+             rbln_config=rbln_config,
+             model_config=model_config,
+         )
+
+         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+         compile_cfgs = [prefill_compile_config]
+
+         if rbln_config.can_generate:
+             for batch_size in rbln_config.decoder_batch_sizes:
+                 dec_input_info = cls.get_input_info(
+                     batch_size=batch_size,
+                     query_length=1,
+                     rbln_config=rbln_config,
+                     model_config=model_config,
+                 )
+                 compile_cfgs.append(
+                     RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+                 )
+         rbln_config.set_compile_cfgs(compile_cfgs)
+
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> List[rebel.Runtime]:
+         expected_model_names = ["prefill"]
+         if rbln_config.can_generate:
+             expected_model_names.extend(
+                 [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+             )
+         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+             cls._raise_missing_compiled_file_error(expected_model_names)
+
+         ret_val = [
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["prefill"],
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             )
+         ]
+         if rbln_config.can_generate:
+             ret_val.extend(
+                 [
+                     rebel.Runtime(
+                         compiled_models[i + 1],
+                         tensor_type="pt",
+                         device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                         activate_profiler=rbln_config.activate_profiler,
+                         timeout=rbln_config.timeout,
+                     )
+                     for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+                 ]
+             )
+         return ret_val
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> BaseModelOutputWithPast:
+         """
+         Args:
+             input_ids (torch.LongTensor, optional): The input IDs to the model.
+             inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
+             attention_mask (torch.LongTensor, optional): The attention mask to the model.
+             kwargs (dict[str, Any], optional): Additional keyword arguments.
+
+         Returns:
+             Dataclass containing the last hidden states of the model.
+         """
+         inputs = inputs_embeds if inputs_embeds is not None else input_ids
+         batch_size = inputs.shape[0]
+         position_embed = kwargs.get("position_embed", None)
+
+         if batch_size != self.rbln_config.batch_size:
+             raise ValueError(
+                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
+             )
+
+         all_last_hidden_states = []
+         for b_idx in range(self.rbln_config.batch_size):
+             query_length = (
+                 attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
+             )
+             cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+             last_hidden_states = self.prefill_decoder(
+                 inputs[b_idx : b_idx + 1],
+                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                 position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+                 cache_position=cache_position,
+                 batch_idx=b_idx,
+             ).logits
+             all_last_hidden_states.append(last_hidden_states)
+
+         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+
+         return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
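A hedged usage sketch for this base class: assuming a pre-compiled subclass such as RBLNLlamaModel (the model id and local path below are assumptions), extracting hidden states looks like standard transformers usage:

    # Hypothetical usage; paths and model names are assumptions.
    from transformers import AutoTokenizer
    from optimum.rbln import RBLNLlamaModel

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    model = RBLNLlamaModel.from_pretrained("./llama-2-7b-rbln")  # pre-compiled artifacts

    batch = tokenizer(["hello world"], return_tensors="pt")
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    hidden = out.last_hidden_state  # produced by the prefill-only path above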
+ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
+     """
+     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+     The class provides core functionality for:
+
+     1. Converting pre-trained transformer models to an RBLN-optimized format
+     2. Handling the compilation process for RBLN devices
+     3. Managing inference operations for causal language modeling
+
+     This class inherits from RBLNModel and implements the specific methods required for
+     decoder-only architectures and causal language modeling tasks.
+
+     Note:
+         - This class is designed to be subclassed by specific model implementations
+           (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+         - Subclasses should implement model-specific conversion logic.
+         - The class handles RBLN-specific optimizations automatically during compilation.
+     """
+
+     auto_model_class = AutoModelForCausalLM
+
+     @property
+     def prefill_output_size(self):
+         return (
+             1,
+             self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+             self.config.vocab_size,
+         )
+
+     @classmethod
+     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+         return is_prefill
+
+     def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+         if isinstance(lora_int_ids, int):
+             lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+         elif isinstance(lora_int_ids, list):
+             lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
+
+         self.lora_int_ids = lora_int_ids
+
+         self.prefill_decoder.lora_int_ids = lora_int_ids
+         if self.rbln_config.can_generate:
+             for batch_size in self.rbln_config.decoder_batch_sizes:
+                 self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+     def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+         """
+         Sets the active adapter(s) for the model using adapter name(s).
+
+         Args:
+             adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                 Can be a single adapter name or a list of adapter names.
+
+         Raises:
+             ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+         """
+         if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+             raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+         # Convert a single adapter name to a list for uniform processing
+         if isinstance(adapter_name, str):
+             adapter_names = [adapter_name]
+         else:
+             adapter_names = adapter_name
+
+         # Validate that all adapter names exist
+         available_adapters = {
+             adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+         }
+         missing_adapters = [name for name in adapter_names if name not in available_adapters]
+         if missing_adapters:
+             raise ValueError(
+                 f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+             )
+
+         # Get the adapter IDs and set them
+         lora_int_ids = [available_adapters[name] for name in adapter_names]
+         self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
+
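A minimal usage sketch for set_adapter, assuming the model was compiled with a LoRA config that registered adapters named "ko_adapter" and "en_adapter" (both names are assumptions):

    # Hypothetical adapter names, registered at compile time via the LoRA config.
    model.set_adapter("ko_adapter")                  # one adapter for every batch row
    model.set_adapter(["ko_adapter", "en_adapter"])  # per-row adapters, for batch_size == 2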
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         padded_cache_lengths: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
+         return_dict: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+         # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+         # A for-loop ensures synchronization with the HuggingFace generate API.
+         # The decoder stage operates as usual, processing inputs in batch mode.
+         if self.rbln_config.use_lora and lora_int_ids is None:
+             if self.lora_int_ids is None:
+                 raise ValueError(
+                     "lora_int_ids is required when using LoRA. "
+                     "You should call set_lora_int_ids() before forward() or pass lora_int_ids to forward()."
+                 )
+             lora_int_ids = self.lora_int_ids
+
+         # For the case where forward() is called directly, outside of generate()
+         if generate_idx is None:
+             generate_idx = (
+                 attention_mask.sum(dim=-1, keepdim=True).int()
+                 if attention_mask is not None
+                 else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+             )
+             padded_cache_lengths = torch.zeros_like(generate_idx)
+
+         # Prefill
+         if cache_position is None:
+             logits = []
+             inputs = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = inputs.shape[0]
+             input_len = inputs.shape[1]
+             if batch_size > self.rbln_config.batch_size:
+                 raise ValueError(
+                     f"Input's batch size ({batch_size}) exceeds the compiled batch_size ({self.rbln_config.batch_size})."
+                 )
+             if input_len > self.rbln_config.max_seq_len:
+                 raise ValueError(
+                     f"Input's length ({input_len}) exceeds the compiled max_seq_len ({self.rbln_config.max_seq_len})."
+                 )
+
+             for b_idx in range(batch_size):
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 output = self.prefill_decoder(
+                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                     cache_position=cache_position,
+                     batch_idx=b_idx,
+                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                     lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
+                 )
+                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
+                 logits.append(output.logits)
+             logits = torch.cat(logits, dim=0)
+         # Decoder
+         else:
+             inputs = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = inputs.shape[0]
+             if batch_size not in self.decoders:
+                 raise ValueError(
+                     f"No decoder runtime available for batch size {batch_size}. "
+                     f"Available batch sizes are: {list(self.decoders.keys())}. "
+                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
+                 )
+             if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
+                 raise ValueError(
+                     f"Cache position exceeds the maximum sequence length.\n"
+                     f" - Current max cache position: {int(torch.max(cache_position).item())}\n"
+                     f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
+                     f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
+                     f"or `max_length` in the generation config."
+                 )
+
+             logits = self.decoders[batch_size](
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                 lora_int_ids=lora_int_ids,
+             ).logits
+
+         if not return_dict:
+             return logits, generate_idx, padded_cache_lengths
+         else:
+             return RBLNDecoderOnlyOutput(
+                 logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+             )
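For context, a hedged end-to-end sketch of how a concrete subclass of RBLNDecoderOnlyModelForCausalLM is typically compiled and used, following the optimum-rbln from_pretrained conventions; the model id, output path, and rbln_* values below are assumptions, not taken from this release:

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNLlamaForCausalLM

    # Compile once (export=True triggers conversion and RBLN compilation) ...
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # assumed model id
        export=True,
        rbln_batch_size=1,
        rbln_max_seq_len=4096,
    )
    model.save_pretrained("./llama-2-7b-chat-rbln")

    # ... then load the compiled artifacts and generate via the HF API,
    # which routes through the prefill/decode forward() above.
    model = RBLNLlamaForCausalLM.from_pretrained("./llama-2-7b-chat-rbln", export=False)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))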