optimum-rbln 0.9.3.post1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic (see the release's advisory page on the registry for details).

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -0,0 +1,611 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import importlib
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import rebel
+ import torch
+ from rebel.compile_context import CompileContext
+ from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.modeling_utils import no_init_weights
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+ from ..decoderonly.modeling_decoderonly import (
+     RBLNDecoderOnlyModelForCausalLM,
+ )
+ from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
+ from .gemma3_architecture import Gemma3ForCausalLMWrapper
+ from .gemma3_runtime_utils import RBLNGemma3RuntimeModel
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
+
+
+ class LoopVisionTower(LoopProcessor):
+     def __init__(self, vision_tower: "RBLNModel"):
+         super().__init__(model=vision_tower)
+
+     def _get_batch_size(self, pixel_values, **kwargs):
+         return pixel_values.shape[0]
+
+     def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+         pixel_values_item = pixel_values[index : index + 1]
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         return ([pixel_values_item], {"out": out_buffer})
+
+     def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+         output = kwargs["out"]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=output[0],
+         )
+
+
+ class LoopProjector(LoopProcessor):
+     def __init__(self, multi_modal_projector: "RBLNModel"):
+         super().__init__(model=multi_modal_projector)
+
+     def _get_batch_size(self, image_feature, **kwargs):
+         return image_feature.shape[0]
+
+     def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+         image_feature_item = image_feature[index : index + 1]
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         return ([image_feature_item], {"out": out_buffer})
+
+     def _process_outputs(self, outputs: list, **kwargs):
+         output = kwargs["out"]
+         return output[0]
+
+
+ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
+     auto_model_class = AutoModelForImageTextToText
+     _rbln_submodules = [
+         {"name": "vision_tower"},
+         {"name": "language_model"},
+     ]
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(Gemma3ForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def can_generate(self):
+         return True
+
+     @classmethod
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+         with no_init_weights():
+             model_cls_name = model.model.language_model.__class__.__name__
+             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+             causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+             new_language_model = causal_model_cls(model.model.language_model.config)
+
+         new_language_model.lm_head = model.lm_head
+         new_language_model.model = model.model.language_model
+         model.model.language_model = new_language_model
+         model.lm_head = None
+         del model.lm_head
+         return model
+
+     def __post_init__(self, **kwargs):
+         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+         self.language_model = self.rbln_submodules[1]
+         self.multi_modal_projector = LoopProjector(self.model[0])
+         self.vocab_size = self.config.text_config.vocab_size
+
+         # Copied from the original class
+         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+         return super().__post_init__(**kwargs)
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.language_model.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.language_model.kvcache_num_blocks
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+         return model.multi_modal_projector
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         image_feature_dim = (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2
+         feature_size = model_config.vision_config.hidden_size
+
+         input_info = [("image_features", [rbln_config.batch_size, image_feature_dim, feature_size], "float32")]
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         inputs_embeds=None,
+         pixel_values=None,
+         image_sizes=None,
+         attention_mask=None,
+         generate_idx=None,
+         padded_cache_lengths=None,
+         token_type_ids=None,
+         **kwargs,
+     ):
+         # Prepare inputs for HF generation
+         is_prefill_phase = generate_idx is None
+
+         model_inputs = self.language_model.prepare_inputs_for_generation(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             generate_idx=generate_idx,  # None in the prefill phase; has no effect there
+             attention_mask=attention_mask,
+             padded_cache_lengths=padded_cache_lengths,
+             **kwargs,
+         )
+
+         if is_prefill_phase:
+             model_inputs.update(
+                 {
+                     "pixel_values": pixel_values,
+                     "image_sizes": image_sizes,
+                     "token_type_ids": token_type_ids,
+                 }
+             )
+
+         model_inputs["attention_mask"] = attention_mask
+
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(
+         self,
+         outputs: RBLNDecoderOnlyOutput,
+         model_kwargs: Dict[str, Any],
+         **kwargs,
+     ) -> Dict[str, Any]:
+         # Update generate_idx and padded_cache_lengths from the model outputs
+         model_kwargs["generate_idx"] = outputs.generate_idx
+         model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
+
+         return model_kwargs
+
+     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         """Projects the last hidden state from the vision model into language model space.
+
+         Args:
+             pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                 The tensors corresponding to the input images.
+
+         Returns:
+             Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+         """
+         vision_out_buffer = []
+         vision_out_size = [
+             pixel_values.shape[0],
+             (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+             self.config.vision_config.hidden_size,
+         ]
+         projector_out_size = [
+             pixel_values.shape[0],
+             self.config.mm_tokens_per_image,
+             self.config.text_config.hidden_size,
+         ]
+         vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+         projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+         vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+         image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
+         return image_features
+
+     def _preprocess_prefill(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ):
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         # Replace the image token id with PAD if it is OOV, to avoid index errors
+         if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+             special_image_mask = input_ids == self.config.image_token_index
+             llm_input_ids = input_ids.clone()
+             llm_input_ids[special_image_mask] = 0
+         else:
+             llm_input_ids = input_ids
+
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+         # Merge text and images
+         if pixel_values is not None:
+             image_features = self.get_image_features(pixel_values)
+             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+             special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+         return inputs_embeds
+
+     def get_padded_cache_position(
+         self,
+         cache_position: torch.Tensor,  # shape: [1, seq_len]
+         token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+     ) -> torch.Tensor:
+         seq_len = cache_position[0][-1].item() + 1
+
+         # Find image start positions
+         image_starts = [
+             s
+             for s in torch.where(token_type_ids == 1)[1]
+             if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+         ]
+
+         # Grow the padded input length so that every image block starts on an
+         # image_prefill_chunk_size boundary
+         padded_input_len = seq_len
+         for image_start in image_starts:
+             pad_needed = (
+                 self.rbln_config.image_prefill_chunk_size
+                 - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+             ) % self.rbln_config.image_prefill_chunk_size
+             padded_input_len += pad_needed
+
+         return torch.cat(
+             [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+             dim=1,
+         )
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
+         pixel_values: torch.FloatTensor = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         padded_cache_lengths: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         **lm_kwargs: Dict[str, Any],
+     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+         # Prefill phase
+         if cache_position is None:
+             logits = []
+             inputs_embeds = self._preprocess_prefill(input_ids, inputs_embeds, pixel_values)
+             batch_size = inputs_embeds.shape[0]
+
+             for b_idx in range(batch_size):
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
+                 output = self.language_model.prefill_decoder(
+                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                     attention_mask=attention_mask[b_idx],
+                     cache_position=cache_position,
+                     batch_idx=b_idx,
+                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # the full row, not the masked token_type_id
+                 )
+                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
+                 logits.append(output.logits)
+
+             logits = torch.cat(logits, dim=0)
+         # Decode phase
+         else:
+             inputs = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = inputs.shape[0]
+             if batch_size not in self.language_model.decoders:
+                 raise ValueError(
+                     f"No decoder runtime available for batch size {batch_size}. "
+                     f"Available batch sizes are: {list(self.language_model.decoders.keys())}. "
+                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
+                 )
+
+             logits = self.language_model.decoders[batch_size](
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
+             ).logits
+
+         return RBLNDecoderOnlyOutput(
+             logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+         )
+
+
+ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+     """
+     The Gemma3 Model transformer with a language modeling head (linear layer) on top.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run a pre-trained transformers-based Gemma3ForCausalLM model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers Gemma3ForCausalLM model into an RBLN transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+     """
+
+     _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+     _supports_non_fp32 = False
+
+     def setup_runtime(self):
+         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
+         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }
+
+         self.prefill_decoder = RBLNGemma3RuntimeModel(
+             runtime=self.model[0],
+             image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
+             phase="prefill",
+             batch_size=self.rbln_config.batch_size,
+             **common_kwargs,
+         )
+
+         self.decoders = {}
+         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
+                 runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
+                 phase="decode",
+                 batch_size=batch_size,
+                 **common_kwargs,
+             )
+
+         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+         self.decoder = self.decoders[self.rbln_config.batch_size]
+
+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = Gemma3TextScaledWordEmbedding(
+                 self.config.vocab_size,
+                 self.config.hidden_size,
+                 self.config.pad_token_id,
+                 embed_scale=self.config.hidden_size**0.5,
+             )
+         return embed_tokens
+
+     @classmethod
+     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+         sliding_window = getattr(model_config, "sliding_window", None)
+         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+         if sliding_window_pattern is None:
+             if hasattr(model_config, "layer_types"):
+                 first_full_attention_index = model_config.layer_types.index("full_attention")
+                 sliding_window_pattern = first_full_attention_index + 1
+             else:
+                 raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
+         if sliding_window_pattern <= model_config.num_hidden_layers:
+             rbln_config.cache_impl = "hybrid"
+             rbln_config.sliding_window = sliding_window
+             rbln_config.sliding_window_layers = [
+                 i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
+             ]
+
+         return rbln_config
+
+     @classmethod
+     def _update_submodule_config(
+         cls,
+         model: "PreTrainedModel",
+         rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
+         if rbln_config.image_prefill_chunk_size is None:
+             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
+
+         if rbln_config.image_prefill_chunk_size != model.config.mm_tokens_per_image:
+             raise ValueError(
+                 f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
+             )
+
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
+     ) -> RBLNGemma3ForCausalLMConfig:
+         # Update rbln_config via the superclass
+         rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+
+         if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
+             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
+
+         if rbln_config.use_image_prefill:
+             if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                 raise NotImplementedError(
+                     "Not implemented for different prefill chunk sizes between text and image prefill."
+                 )
+
+             # Update image prefill compile config
+             img_prefill_input_info = cls.get_input_info(
+                 batch_size=1,
+                 query_length=rbln_config.image_prefill_chunk_size,
+                 rbln_config=rbln_config,
+                 model_config=model_config,
+             )
+             image_prefill_compile_config = RBLNCompileConfig(
+                 compiled_model_name="image_prefill", input_info=img_prefill_input_info
+             )
+             # Insert the image_prefill compile config at index 1
+             compile_cfgs = rbln_config.compile_cfgs
+             compile_cfgs.insert(1, image_prefill_compile_config)
+             rbln_config.set_compile_cfgs(compile_cfgs)
+
+         return rbln_config
+
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
+         wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
+
+         rbln_compile_configs = rbln_config.compile_cfgs
+         prefill_compile_config = rbln_compile_configs[0]
+
+         context = CompileContext(use_weight_sharing=True)
+
+         # Use meta tensors here for memory efficiency.
+         meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+
+         # Mark static tensors (self kv states)
+         static_tensors = {}
+         for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+             if "past_key_values" in name:
+                 static_tensors[name] = tensor
+                 context.mark_static_address(tensor)
+
+         def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
+             try:
+                 if quantization:
+                     quantization.maybe_set_quantization_env()
+                 original_linear = torch.nn.functional.linear
+                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+                 compiled_model = cls.compile(
+                     wrapped_model,
+                     compile_config,
+                     create_runtimes=rbln_config.create_runtimes,
+                     device=rbln_config.device,
+                     example_inputs=example_inputs,
+                     compile_context=compile_context,
+                 )
+                 return compiled_model
+             finally:
+                 torch.nn.functional.linear = original_linear
+                 if quantization:
+                     quantization.maybe_reset_quantization_env()
+
+         wrapped_model.phase = "prefill"
+         compiled_prefill = compile_model(
+             wrapped_model,
+             prefill_compile_config,
+             prefill_example_inputs,
+             context,
+             rbln_config.quantization,
+         )
+         compiled_models = {"prefill": compiled_prefill}
+
+         if rbln_config.use_image_prefill:
+             image_prefill_compile_config = rbln_compile_configs[1]
+             image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                 fill=0, static_tensors=static_tensors
+             )
+             wrapped_model.phase = "image_prefill"
+             compiled_image_prefill = compile_model(
+                 wrapped_model,
+                 image_prefill_compile_config,
+                 image_prefill_example_inputs,
+                 context,
+                 rbln_config.quantization,
+             )
+             compiled_models["image_prefill"] = compiled_image_prefill
+
+         wrapped_model.phase = "decode"
+         for batch_size, dec_compile_config in zip(
+             rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+         ):
+             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+             compiled_decoder = compile_model(
+                 wrapped_model,
+                 dec_compile_config,
+                 dec_example_inputs,
+                 context,
+                 rbln_config.quantization,
+             )
+             compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+         return compiled_models
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNGemma3ForCausalLMConfig,
+     ) -> List[rebel.Runtime]:
+         expected_model_names = [
+             "prefill",
+             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
+         ]
+         if rbln_config.use_image_prefill:
+             expected_model_names.insert(1, "image_prefill")
+
+         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+             cls._raise_missing_compiled_file_error(expected_model_names)
+
+         ret_val = [
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["prefill"],
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             )
+         ]
+         if rbln_config.use_image_prefill:
+             ret_val.append(
+                 rebel.Runtime(
+                     compiled_models[1],
+                     tensor_type="pt",
+                     device=rbln_config.device_map["image_prefill"],
+                     activate_profiler=rbln_config.activate_profiler,
+                     timeout=rbln_config.timeout,
+                 ),
+             )
+
+         ret_val.extend(
+             [
+                 rebel.Runtime(
+                     compiled_models[i + rbln_config.decoder_runtime_idx],
+                     tensor_type="pt",
+                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                     activate_profiler=rbln_config.activate_profiler,
+                     timeout=rbln_config.timeout,
+                 )
+                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+             ]
+         )
+
+         return ret_val
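
The prefill/image_prefill/decoder split above is easiest to see end to end in use. Below is a minimal usage sketch, not part of the wheel: the checkpoint id `google/gemma-3-4b-it` and the omission of compile-time options (e.g. tensor parallelism) are illustrative assumptions, and it presumes the class is re-exported from the top-level `optimum.rbln` namespace like the other models in this package.

```python
# Hypothetical usage sketch for RBLNGemma3ForConditionalGeneration (assumptions noted above).
from transformers import AutoProcessor

from optimum.rbln import RBLNGemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)

# export=True takes the compile path shown above: get_compiled_model() builds the
# prefill, optional image_prefill, and per-batch-size decoder graphs, and
# _create_runtimes() then instantiates one rebel.Runtime per compiled graph.
model = RBLNGemma3ForConditionalGeneration.from_pretrained(model_id, export=True)

inputs = processor.apply_chat_template(
    [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
output_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```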
optimum/rbln/transformers/models/gpt2/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig, RBLNGPT2ModelConfig
+ from .modeling_gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2Model
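
This `__init__.py` makes the GPT2 classes importable from the subpackage path; the docstring example in the next file imports the same names from the top-level `optimum.rbln` namespace instead, which is assumed to re-export them. A sketch of the two import paths:

```python
# Subpackage import, directly from the module shown above.
from optimum.rbln.transformers.models.gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig

# Top-level import (assumed re-export; this is the form the docstring example below uses):
# from optimum.rbln import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
```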
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py
@@ -0,0 +1,50 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+     """
+     Configuration class for RBLN GPT2 language-modeling (LM head) models.
+
+     This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+     """
+
+
+ class RBLNGPT2ModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLN GPT2 models.
+
+     This class is an alias of RBLNDecoderOnlyModelConfig.
+
+     Example usage:
+     ```python
+     from optimum.rbln import RBLNGPT2Model, RBLNGPT2ModelConfig
+
+     # Create a configuration object
+     config = RBLNGPT2ModelConfig(
+         batch_size=1,
+         max_seq_len=1024,
+         tensor_parallel_size=4
+     )
+
+     # Use the configuration with from_pretrained
+     model = RBLNGPT2Model.from_pretrained(
+         "openai/gpt2",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """