optimum-rbln 0.9.3.post1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -0,0 +1,527 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Seq2SeqWrapper:
+     """A wrapper class for Seq2Seq models to support RBLN-specific optimizations.
+
+     This wrapper divides the Seq2Seq model into separate encoder and decoder wrappers,
+     enabling specific optimizations such as custom cache handling and attention mechanisms.
+
+     Args:
+         model (nn.Module): The Seq2Seq model to wrap.
+         enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
+         kwargs: Additional arguments to pass to the decoder wrapper.
+     """
+
+     def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
+         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
+         self.decoder = Seq2SeqDecoderWrapper(model, **kwargs)
+
+
+ class Seq2SeqEncoderWrapper(nn.Module):
+     """A wrapper for the encoder component of a Seq2Seq model, designed for RBLN optimization.
+
+     This wrapper modifies the standard encoder-decoder architecture of Seq2Seq models to optimize
+     memory usage and attention mechanisms, particularly in cross-attention layers. It supports custom
+     cache handling to improve performance during decoding.
+
+     Args:
+         model (nn.Module): The Seq2Seq model containing the encoder.
+         enc_max_seq_len (int): Maximum sequence length for encoder embeddings and cache sizes.
+     """
+
+     def __init__(self, model: nn.Module, enc_max_seq_len: int):
+         super().__init__()
+         self.config = model.config
+         self.encoder = model.get_encoder()
+         self.encoder_max_length = enc_max_seq_len
+         self.__post_init__(model)
+
+     def __post_init__(self, model: nn.Module):
+         """
+         Post-initialization to extract and configure encoder-related attributes.
+
+         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+         by subclasses to modify or add custom attributes as necessary.
+         """
+         self.n_layer = getattr(self.config, "decoder_layers", None)
+         self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+         self.num_heads = self.config.decoder_attention_heads
+         self.d_kv = self.config.d_model // self.num_heads
+
+     def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+         """
+         Extract cross-attention key and value projection layers from the decoder.
+         """
+         return (
+             nn.ModuleList(decoder_layers[i].encoder_attn.k_proj for i in range(self.n_layer)),
+             nn.ModuleList(decoder_layers[i].encoder_attn.v_proj for i in range(self.n_layer)),
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         b_idx: torch.Tensor,
+         *cross_key_values: Tuple[torch.Tensor],
+     ) -> Tuple[torch.Tensor]:
+         # 1. Get the encoder's last hidden states.
+         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_states = encoder_outputs[0]
+
+         # 2. Pre-compute the cross-attention past_key_value used during the decoding phase.
+         cross_kv = []
+         for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+             past_k = (
+                 k_proj(last_hidden_states).view(1, self.encoder_max_length, self.num_heads, self.d_kv).transpose(1, 2)
+             )
+             past_v = (
+                 v_proj(last_hidden_states).view(1, self.encoder_max_length, self.num_heads, self.d_kv).transpose(1, 2)
+             )
+
+             cross_kv.append(past_k)
+             cross_kv.append(past_v)
+
+         # 3. Write the cross-attention past_key_value directly to device DRAM as an optimization.
+         batch_axis = torch.tensor(0, dtype=torch.int16)
+         cross_key_values = list(cross_key_values)
+         for i in range(self.n_layer * 2):
+             cross_key_values[i] = torch.ops.rbln_custom_ops.rbln_cache_update(
+                 cross_key_values[i], cross_kv[i], b_idx[0], batch_axis
+             )
+
+         return cross_key_values
+
+
+ class Seq2SeqDecoderWrapper(nn.Module):
+     """
+     A wrapper for the decoder component of a Seq2Seq model, designed for RBLN optimization.
+
+     This wrapper handles tasks such as:
+     1. Converting decoder components to support RBLN-specific conditional generation.
+     2. Customizing attention mechanisms, including self-attention and cross-attention.
+     3. Managing the decoder's key-value caches for both self- and cross-attention.
+
+     Args:
+         model (nn.Module): The Seq2Seq model containing the decoder.
+         kwargs: Additional arguments for decoder configuration.
+     """
+
+     def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
+         super().__init__()
+         self.config = model.config
+         self.use_attention_mask = use_attention_mask
+         self.__post_init__(model, **kwargs)
+
+     def __post_init__(self, model: nn.Module, **kwargs):
+         """
+         Post-initialization to extract and configure decoder-related attributes.
+
+         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+         by subclasses to modify or add custom attributes as necessary.
+         """
+         self.num_layers = self.config.decoder_layers
+         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
+
+     def convert_to_rbln_conditional_generation(self, model: nn.Module):
+         new_layers = []
+         for layer in model.get_decoder().layers:
+             self_attn = Seq2SeqSelfAttention(layer.self_attn)
+             cross_attn = Seq2SeqCrossAttention(layer.encoder_attn)
+             new_layers.append(Seq2SeqDecoderLayer(layer, self_attn, cross_attn))
+
+         decoder_model = Seq2SeqDecoder(model.get_decoder(), new_layers)
+         new_model = Seq2SeqForConditionalGeneration(model, decoder_model)
+
+         return new_model
+
+     def forward(
+         self,
+         *args,
+     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+         if self.use_attention_mask:
+             (
+                 input_ids,
+                 attention_mask,
+                 encoder_attention_mask,
+                 cache_position,
+                 block_tables,
+                 *kv_cache,
+             ) = args
+
+         else:
+             attention_mask = None
+             (input_ids, encoder_attention_mask, cache_position, block_tables, *kv_cache) = args
+
+         self_past_key_values = ()
+         cross_past_key_values = ()
+         self_kv_cache = kv_cache[self.num_layers * 2 :]
+         cross_kv_cache = kv_cache[: self.num_layers * 2]
+         for i in range(0, self.num_layers * 2, 2):
+             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+         # Decode and project to vocabulary logits.
+         lm_logits = self.conditional_generation(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+             block_tables=block_tables,
+         )
+
+         return lm_logits
+
+
+ class Seq2SeqForConditionalGeneration(nn.Module):
+     """
+     A wrapper for Seq2Seq models supporting RBLN-specific optimizations for conditional generation.
+
+     This class adapts a Seq2Seq model for tasks like machine translation, summarization, or text generation
+     by:
+     1. Wrapping and customizing the decoder component to support key RBLN features.
+     2. Managing rescaling and output processing, if enabled.
+     3. Aligning model behavior with RBLN's static and efficient execution requirements.
+
+     Attributes:
+         has_rescaling (bool): Indicates if output rescaling is applied.
+         config (PretrainedConfig): Configuration from the original Seq2Seq model.
+         lm_head (nn.Linear): The language modeling head for output logits.
+         decoder (nn.Module): The wrapped decoder model.
+     """
+
+     has_rescaling = False
+
+     def __init__(self, model, decoder_model):
+         super().__init__()
+         self.config = model.config
+         self.lm_head = model.lm_head
+         self.decoder = decoder_model
+         self.__post_init__()
+
+     def __post_init__(self):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask,
+         encoder_attention_mask,
+         self_past_key_values,
+         cross_past_key_values,
+         cache_position,
+         block_tables: Optional[torch.Tensor] = None,
+     ):
+         hidden_states = self.decoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+             block_tables=block_tables,
+         )
+
+         if self.has_rescaling and self.config.tie_word_embeddings:
+             hidden_states = hidden_states * self.scaling
+
+         lm_logits = self.lm_head(hidden_states)
+
+         return lm_logits
+
+
+ class Seq2SeqDecoder(torch.nn.Module):
+     """A modified Seq2SeqDecoder implementation optimized for RBLN compilation.
+
+     Args:
+         model: Original Hugging Face model to adapt
+         layers (List[Seq2SeqDecoderLayer]): Modified transformer layers optimized for RBLN
+     """
+
+     has_pos_emb = True
+
+     def __init__(self, model, layers, **kwargs):
+         super().__init__()
+         self._original_mod = model
+         self.layers = nn.ModuleList(layers)
+         self.embed_tokens = model.embed_tokens
+         self.final_layer_norm = getattr(model, "final_layer_norm", None)
+         self.__post_init__(**kwargs)
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def get_embedding(self):
+         return self.embed_tokens
+
+     def prepare_attn_mask(self, *args, **kwargs):
+         raise NotImplementedError(
+             "The 'prepare_attn_mask' method is not implemented. Please define this method in a subclass."
+         )
+
+     def apply_position_embedding(self, *args, **kwargs):
+         raise NotImplementedError(
+             "The 'apply_position_embedding' method is not implemented. Please define this method in a subclass."
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         self_past_key_values: torch.Tensor,
+         cross_past_key_values: torch.Tensor,
+         cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
+     ):
+         # Token embedding
+         hidden_states = self.get_embedding()(input_ids)
+         attention_mask, encoder_attention_mask = self.prepare_attn_mask(
+             attention_mask, encoder_attention_mask, cache_position=cache_position
+         )
+
+         if self.has_pos_emb:
+             hidden_states = self.apply_position_embedding(hidden_states, cache_position)
+
+         # Iterate over the decoder layers
+         for decoder_layer, self_past_key_value, cross_past_key_value in zip(
+             self.layers, self_past_key_values, cross_past_key_values
+         ):
+             hidden_states = decoder_layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 encoder_attention_mask=encoder_attention_mask,
+                 self_past_key_value=self_past_key_value,
+                 cross_past_key_value=cross_past_key_value,
+                 cache_position=cache_position,
+                 block_tables=block_tables,
+             )
+
+         if self.final_layer_norm is not None:
+             hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states
+
+
+ class Seq2SeqDecoderLayer(torch.nn.Module):
+     """A modified Seq2Seq decoder layer optimized for RBLN compilation.
+
+     Args:
+         decoder_layer: Original Hugging Face decoder layer to adapt
+         self_attn (Seq2SeqSelfAttention): Modified self-attention layer optimized for RBLN
+         cross_attn (Seq2SeqCrossAttention): Modified cross-attention layer optimized for RBLN
+     """
+
+     def __init__(self, decoder_layer, self_attn, cross_attn):
+         super().__init__()
+         self._original_mod = decoder_layer
+         self.self_attn = self_attn
+         self.cross_attn = cross_attn
+         self.__post_init__()
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def pre_self_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'pre_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def post_self_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'post_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def pre_cross_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'pre_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def post_cross_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'post_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         self_past_key_value: Tuple[torch.Tensor],
+         cross_past_key_value: Tuple[torch.Tensor],
+         cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])
+
+         # Self-Attention Block
+         residual = hidden_states
+         hidden_states = self.pre_self_attn_layer_norm(hidden_states)
+         hidden_states = self.self_attn(
+             hidden_states=hidden_states,
+             past_key_value=self_past_key_value,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             block_tables=block_tables,
+         )
+         hidden_states = residual + hidden_states
+         hidden_states = self.post_self_attn_layer_norm(hidden_states)
+
+         # Cross-Attention Block
+         residual = hidden_states
+         hidden_states = self.pre_cross_attn_layer_norm(hidden_states)
+
+         cross_attn_output = self.cross_attn(
+             hidden_states=hidden_states,
+             past_key_value=cross_past_key_value,
+             attention_mask=encoder_attention_mask,
+             key_value_states=dummy_encoder_hidden_states,
+         )
+         hidden_states = residual + cross_attn_output[0]
+         hidden_states = self.post_cross_attn_layer_norm(hidden_states)
+
+         # Feed-Forward Block
+         hidden_states = self.ff_layer(hidden_states)
+
+         return hidden_states
+
+
+ class Seq2SeqSelfAttention(nn.Module):
+     def __init__(self, attn, **kwargs):
+         super().__init__()
+         self._original_mod = attn
+         self.__post_init__(**kwargs)
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+         return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)
+
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Projects input hidden states into query, key, and value representations.
+
+         Args:
+             hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+
+         Returns:
+             Tuple of (query_states, key_states, value_states)
+         """
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+         return query_states, key_states, value_states
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Tuple[torch.Tensor],
+         attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         bsz, tgt_len, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+         query_states = self._shape(query_states, tgt_len, bsz)
+         key_states = self._shape(key_states, -1, bsz)
+         value_states = self._shape(value_states, -1, bsz)
+
+         block_size = past_key_value[0].shape[-2]
+         args = [
+             query_states,
+             key_states,
+             value_states,
+             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             cache_position,
+             torch.tensor(1.0, dtype=torch.float32),  # scale
+             block_tables,
+             block_size,
+         ]
+         if attention_mask is not None:
+             args.insert(3, attention_mask.unsqueeze(2))
+         else:
+             args.append(None)
+
+         attn_output = self.attn_decode(*args)
+
+         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
+         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class Seq2SeqCrossAttention(nn.Module):
+     def __init__(self, attn, **kwargs):
+         super().__init__()
+         self._original_mod = attn
+         self.__post_init__(**kwargs)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_value_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[object] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ):
+         bsz, tgt_len, _ = hidden_states.size()
+         query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+         is_cross_attention = key_value_states is not None
+         if is_cross_attention:
+             key_states = past_key_value[0]
+             value_states = past_key_value[1]
+
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             attn_mask=attention_mask,
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output, None, past_key_value
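
The flat *kv_cache calling convention is the main trap when reusing these wrappers: Seq2SeqDecoderWrapper.forward expects the cross-attention caches first and the self-attention caches second, two tensors (key, value) per decoder layer. A minimal sketch, not part of the package, that mirrors that unpacking with dummy tensors:

import torch

# Illustrative shapes only; the real caches are allocated by the RBLN runtime.
num_layers, num_heads, head_dim = 2, 4, 8
bsz, enc_len, dec_len = 1, 16, 32

cross_kv = [torch.zeros(bsz, num_heads, enc_len, head_dim) for _ in range(num_layers * 2)]
self_kv = [torch.zeros(bsz, num_heads, dec_len, head_dim) for _ in range(num_layers * 2)]
kv_cache = (*cross_kv, *self_kv)  # cross caches first, then self caches

# Mirrors the slicing in Seq2SeqDecoderWrapper.forward:
self_kv_cache = kv_cache[num_layers * 2 :]
cross_kv_cache = kv_cache[: num_layers * 2]
self_past_key_values = tuple((self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2))
cross_past_key_values = tuple((cross_kv_cache[i], cross_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2))

assert len(self_past_key_values) == len(cross_past_key_values) == num_layers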
optimum/rbln/transformers/models/siglip/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_siglip import RBLNSiglipVisionModelConfig
+ from .modeling_siglip import RBLNSiglipVisionModel
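
This __init__ exposes the package's two public SigLIP symbols. A hypothetical usage sketch, assuming RBLNSiglipVisionModel follows the usual optimum-rbln from_pretrained(..., export=True) compile-and-load pattern; the checkpoint id is illustrative and not taken from this diff:

from optimum.rbln import RBLNSiglipVisionModel, RBLNSiglipVisionModelConfig

# Compile the SigLIP vision tower for an RBLN NPU on first load (hypothetical example).
model = RBLNSiglipVisionModel.from_pretrained(
    "google/siglip-base-patch16-224",  # illustrative checkpoint id
    export=True,
    rbln_config=RBLNSiglipVisionModelConfig(batch_size=1, image_size=224),
)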
optimum/rbln/transformers/models/siglip/configuration_siglip.py
@@ -0,0 +1,76 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from ....configuration_utils import RBLNModelConfig
+
+
+ class RBLNSiglipVisionModelConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLNSiglipVisionModel.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized SigLIP vision models for image encoding in multimodal tasks.
+     """
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         image_size: Optional[int] = None,
+         interpolate_pos_encoding: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+             image_size (Optional[int]): The size of input images. Can be an integer for square images,
+                 a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
+             interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
+             output_hidden_states (Optional[bool]): Whether to return hidden states.
+             output_attentions (Optional[bool]): Whether to return attentions.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.image_size = image_size
+         self.interpolate_pos_encoding = interpolate_pos_encoding or False
+         self.output_hidden_states = output_hidden_states
+         self.output_attentions = output_attentions
+
+     @property
+     def image_width(self):
+         if isinstance(self.image_size, int):
+             return self.image_size
+         elif isinstance(self.image_size, (list, tuple)):
+             return self.image_size[1]
+         else:
+             return self.image_size["width"]
+
+     @property
+     def image_height(self):
+         if isinstance(self.image_size, int):
+             return self.image_size
+         elif isinstance(self.image_size, (list, tuple)):
+             return self.image_size[0]
+         else:
+             return self.image_size["height"]
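
Because image_size may arrive as an int, a (height, width) sequence, or a height/width dict, the two properties above normalize all three forms. A short sketch, assuming the config can be instantiated directly with default arguments:

from optimum.rbln import RBLNSiglipVisionModelConfig

square = RBLNSiglipVisionModelConfig(image_size=224)
pair = RBLNSiglipVisionModelConfig(image_size=(336, 224))  # (height, width)
mapping = RBLNSiglipVisionModelConfig(image_size={"height": 384, "width": 512})

assert (square.image_height, square.image_width) == (224, 224)
assert (pair.image_height, pair.image_width) == (336, 224)
assert (mapping.image_height, mapping.image_width) == (384, 512)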