transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/paddleocr_vl/modeling_paddleocr_vl.py
@@ -0,0 +1,1682 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_paddleocr_vl.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+
26
+ from collections.abc import Callable
27
+ from dataclasses import dataclass
28
+ from typing import Any, Optional, Union
29
+
30
+ import torch
31
+ from torch import nn
32
+
33
+ from ... import initialization as init
34
+ from ...activations import ACT2FN, GELUActivation
35
+ from ...cache_utils import Cache, DynamicCache
36
+ from ...generation import GenerationMixin
37
+ from ...integrations import use_kernel_forward_from_hub
38
+ from ...masking_utils import create_bidirectional_mask, create_causal_mask
39
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
40
+ from ...modeling_layers import GradientCheckpointingLayer
41
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
42
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
43
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
+ from ...processing_utils import Unpack
45
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
46
+ from ...utils.generic import check_model_inputs, maybe_autocast
47
+ from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ class PaddleOCRProjector(nn.Module):
54
+ def __init__(self, config: PaddleOCRVLConfig):
55
+ super().__init__()
56
+ self.merge_kernel_size = (config.vision_config.spatial_merge_size, config.vision_config.spatial_merge_size)
57
+
58
+ hidden_size = config.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1]
59
+
60
+ self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05)
61
+ self.linear_1 = nn.Linear(hidden_size, hidden_size, bias=True)
62
+ self.act = GELUActivation()
63
+ self.linear_2 = nn.Linear(hidden_size, config.text_config.hidden_size, bias=True)
64
+
65
+ def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
66
+ image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0)
67
+ m1, m2 = self.merge_kernel_size
68
+
69
+ processed_features = []
70
+ for image_feature, image_grid in zip(image_features_chunks, image_grid_thw):
71
+ image_feature = self.pre_norm(image_feature)
72
+ t, h, w = image_grid
73
+ d = image_feature.shape[-1]
74
+ h_block = h // m1
75
+ w_block = w // m2
76
+
77
+ image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d)
78
+ image_feature = image_feature.transpose(2, 3)
79
+ image_feature = image_feature.reshape(t * h_block * w_block, m1 * m2 * d)
80
+
81
+ hidden_states = self.linear_1(image_feature)
82
+ hidden_states = self.act(hidden_states)
83
+ hidden_states = self.linear_2(hidden_states)
84
+ processed_features.append(hidden_states)
85
+
86
+ return torch.cat(processed_features, dim=0)
87
+
88
+
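
The spatial-merge reshape in `PaddleOCRProjector.forward` above is easiest to follow with concrete shapes. A minimal standalone sketch (toy sizes assumed, not taken from the package): a 1x4x6 patch grid with hidden size 8 and a 2x2 merge kernel collapses into 6 merged tokens of width 32.

```python
import torch

# Toy shapes only: t=1, h=4, w=6 patches, hidden size d=8, 2x2 merge kernel.
t, h, w, d = 1, 4, 6, 8
m1 = m2 = 2
image_feature = torch.randn(t * h * w, d)

h_block, w_block = h // m1, w // m2
merged = (
    image_feature.reshape(t, h_block, m1, w_block, m2, d)
    .transpose(2, 3)  # bring the m1 x m2 window dims next to each other
    .reshape(t * h_block * w_block, m1 * m2 * d)
)
print(merged.shape)  # torch.Size([6, 32])
```

Each merged token concatenates one 2x2 window of patch features, which is why the projector's first linear layer takes `hidden_size * m1 * m2` inputs.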
89
+ class PaddleOCRVisionRotaryEmbedding(nn.Module):
90
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
91
+
92
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
93
+ super().__init__()
94
+ self.dim = dim
95
+ self.theta = theta
96
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
97
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
98
+
99
+ def forward(self, seqlen: int) -> torch.Tensor:
100
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
101
+ freqs = torch.outer(seq, self.inv_freq)
102
+ return freqs
103
+
104
+
105
+ class PaddleOCRRotaryEmbedding(nn.Module):
106
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
107
+
108
+ def __init__(self, config: PaddleOCRVLConfig, device=None):
109
+ super().__init__()
110
+ self.max_seq_len_cached = config.max_position_embeddings
111
+ self.original_max_seq_len = config.max_position_embeddings
112
+
113
+ self.config = config
114
+
115
+ self.rope_type = self.config.rope_parameters["rope_type"]
116
+ rope_init_fn: Callable = self.compute_default_rope_parameters
117
+ if self.rope_type != "default":
118
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
119
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
120
+
121
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
122
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
123
+
124
+ @staticmethod
125
+ def compute_default_rope_parameters(
126
+ config: Optional[PaddleOCRVLConfig] = None,
127
+ device: Optional["torch.device"] = None,
128
+ seq_len: Optional[int] = None,
129
+ ) -> tuple["torch.Tensor", float]:
130
+ """
131
+ Computes the inverse frequencies according to the original RoPE implementation
132
+ Args:
133
+ config ([`~transformers.PreTrainedConfig`]):
134
+ The model configuration.
135
+ device (`torch.device`):
136
+ The device to use for initialization of the inverse frequencies.
137
+ seq_len (`int`, *optional*):
138
+ The current sequence length. Unused for this type of RoPE.
139
+ Returns:
140
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
141
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
142
+ """
143
+ base = config.rope_parameters["rope_theta"]
144
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
145
+
146
+ attention_factor = 1.0 # Unused in this type of RoPE
147
+
148
+ # Compute the inverse frequencies
149
+ inv_freq = 1.0 / (
150
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
151
+ )
152
+ return inv_freq, attention_factor
153
+
154
+ # Ignore copy
155
+ def forward(self, x, position_ids):
156
+ # In contrast to other models, PaddleOCR has different position ids for the grids
157
+ # So we expand the inv_freq to shape (3, ...)
158
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
159
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
160
+
161
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
162
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
163
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
164
+ emb = torch.cat((freqs, freqs), dim=-1)
165
+ cos = emb.cos() * self.attention_scaling
166
+ sin = emb.sin() * self.attention_scaling
167
+
168
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
169
+
170
+
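
A minimal sketch (toy sizes assumed) of what `PaddleOCRRotaryEmbedding.forward` above produces: the same inverse frequencies are applied to three stacked sets of position ids (temporal, height, width), yielding cos/sin tensors of shape `(3, batch, seq_len, head_dim)`.

```python
import torch

# Assumed toy values: head_dim=8, theta=10000.0, batch of 2, 5 positions.
head_dim, theta, bs, seq_len = 8, 10000.0, 2, 5
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

position_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, bs, -1)    # (3, bs, seq)
inv_freq_expanded = inv_freq[None, None, :, None].expand(3, bs, -1, 1)   # (3, bs, dim/2, 1)
freqs = (inv_freq_expanded @ position_ids[:, :, None, :].float()).transpose(2, 3)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb.cos().shape, emb.sin().shape)  # torch.Size([3, 2, 5, 8]) twice
```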
171
+ class PaddleOCRMLP(nn.Module):
172
+ def __init__(self, config: PaddleOCRTextConfig):
173
+ super().__init__()
174
+ self.config = config
175
+ self.hidden_size = config.hidden_size
176
+ self.intermediate_size = config.intermediate_size
177
+
178
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
179
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
180
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
181
+ self.act_fn = ACT2FN[config.hidden_act]
182
+
183
+ def forward(self, x):
184
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
185
+ return down_proj
186
+
187
+
188
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
189
+ """
190
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
191
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
192
+ """
193
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
194
+ if n_rep == 1:
195
+ return hidden_states
196
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
197
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
198
+
199
+
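
A quick toy check (shapes assumed) of the equivalence the `repeat_kv` docstring above claims with `torch.repeat_interleave`:

```python
import torch

# (batch, num_key_value_heads, seq_len, head_dim) with each KV head repeated twice.
x = torch.randn(2, 3, 4, 5)
n_rep = 2
expanded = x[:, :, None, :, :].expand(2, 3, n_rep, 4, 5).reshape(2, 3 * n_rep, 4, 5)
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=n_rep, dim=1))
```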
200
+ def eager_attention_forward(
201
+ module: nn.Module,
202
+ query: torch.Tensor,
203
+ key: torch.Tensor,
204
+ value: torch.Tensor,
205
+ attention_mask: Optional[torch.Tensor],
206
+ scaling: float,
207
+ dropout: float = 0.0,
208
+ **kwargs,
209
+ ):
210
+ key_states = repeat_kv(key, module.num_key_value_groups)
211
+ value_states = repeat_kv(value, module.num_key_value_groups)
212
+
213
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
214
+ if attention_mask is not None:
215
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
216
+ attn_weights = attn_weights + causal_mask
217
+
218
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
219
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
220
+ attn_output = torch.matmul(attn_weights, value_states)
221
+ attn_output = attn_output.transpose(1, 2).contiguous()
222
+
223
+ return attn_output, attn_weights
224
+
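
A compact numerical sketch of the eager attention path above, with assumed toy shapes and no GQA or masking: scores are scaled dot products, the softmax runs in float32, and the output is a weighted sum over the values.

```python
import torch

bsz, heads, q_len, head_dim = 1, 2, 3, 4
q, k, v = (torch.randn(bsz, heads, q_len, head_dim) for _ in range(3))

scores = q @ k.transpose(2, 3) * head_dim**-0.5            # (bsz, heads, q_len, kv_len)
weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
out = (weights @ v).transpose(1, 2)                        # (bsz, q_len, heads, head_dim)
print(out.shape)  # torch.Size([1, 3, 2, 4])
```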
225
+
226
+ def rotate_half(x):
227
+ """Rotates half the hidden dims of the input."""
228
+ x1 = x[..., : x.shape[-1] // 2]
229
+ x2 = x[..., x.shape[-1] // 2 :]
230
+ return torch.cat((-x2, x1), dim=-1)
231
+
232
+
233
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
234
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
235
+
236
+ Explanation:
237
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
238
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
239
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
240
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
241
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
242
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
243
+ difference with modern LLMs.
244
+
245
+ Args:
246
+ q (`torch.Tensor`): The query tensor.
247
+ k (`torch.Tensor`): The key tensor.
248
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
249
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
250
+ position_ids (`torch.Tensor`):
251
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
252
+ used to pass offset position ids when working with a KV-cache.
253
+ mrope_section(`List(int)`):
254
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
255
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
256
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
257
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
258
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
259
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
260
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
261
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
262
+ Returns:
263
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
264
+ """
265
+ mrope_section = mrope_section * 2
266
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
267
+ unsqueeze_dim
268
+ )
269
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
270
+ unsqueeze_dim
271
+ )
272
+
273
+ q_embed = (q * cos) + (rotate_half(q) * sin)
274
+ k_embed = (k * cos) + (rotate_half(k) * sin)
275
+ return q_embed, k_embed
276
+
277
+
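
The `mrope_section` interleaving is the subtle part of `apply_multimodal_rotary_pos_emb` above. A toy illustration (sizes assumed) of how alternating temporal/height/width slices of `cos` are stitched back together along the channel dimension:

```python
import torch

bs, seq, head_dim = 2, 5, 8
mrope_section = [1, 1, 2]                    # sums to head_dim // 2
cos = torch.randn(3, bs, seq, head_dim)      # one plane each for t / h / w position ids

doubled = mrope_section * 2                  # [1, 1, 2, 1, 1, 2]
mixed = torch.cat(
    [chunk[i % 3] for i, chunk in enumerate(cos.split(doubled, dim=-1))], dim=-1
).unsqueeze(1)
# Channels 0 and 4 come from the temporal plane, 1 and 5 from height, 2-3 and 6-7 from width.
print(mixed.shape)  # torch.Size([2, 1, 5, 8])
```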
278
+ class PaddleOCRAttention(nn.Module):
279
+ """
280
+ Multi-headed attention from the 'Attention Is All You Need' paper, modified to support sliding window attention as in
281
+ Longformer and "Generating Long Sequences with Sparse Transformers".
282
+ """
283
+
284
+ def __init__(self, config: PaddleOCRVLConfig, layer_idx: Optional[int] = None):
285
+ super().__init__()
286
+ self.config = config
287
+ self.layer_idx = layer_idx
288
+ if layer_idx is None:
289
+ logger.warning_once(
290
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
291
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
292
+ "when creating this class."
293
+ )
294
+
295
+ self.hidden_size = config.hidden_size
296
+ self.num_heads = config.num_attention_heads
297
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
298
+ self.num_key_value_heads = config.num_key_value_heads
299
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
300
+ self.is_causal = True
301
+
302
+ self.attention_dropout = 0.0
303
+ self.rope_parameters = config.rope_parameters
304
+ self.scaling = self.head_dim**-0.5
305
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias)
306
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
307
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
308
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
309
+ self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
310
+ self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
311
+
312
+ def forward(
313
+ self,
314
+ hidden_states: torch.Tensor,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ position_ids: Optional[torch.LongTensor] = None,
317
+ past_key_values: Optional[Cache] = None,
318
+ output_attentions: bool = False,
319
+ use_cache: bool = False,
320
+ cache_position: Optional[torch.LongTensor] = None,
321
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
322
+ **kwargs: Unpack[FlashAttentionKwargs],
323
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
324
+ bsz, q_len, _ = hidden_states.size()
325
+
326
+ query_states = self.q_proj(hidden_states)
327
+ key_states = self.k_proj(hidden_states)
328
+ value_states = self.v_proj(hidden_states)
329
+
330
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
331
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
332
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
333
+
334
+ cos, sin = position_embeddings
335
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
336
+ query_states, key_states, cos, sin, self.config.rope_parameters["mrope_section"]
337
+ )
338
+
339
+ if past_key_values is not None:
340
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
341
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
342
+
343
+ attention_interface: Callable = eager_attention_forward
344
+ if self.config._attn_implementation != "eager":
345
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
346
+
347
+ attn_output, attn_weights = attention_interface(
348
+ self,
349
+ query_states,
350
+ key_states,
351
+ value_states,
352
+ attention_mask,
353
+ dropout=0.0 if not self.training else self.attention_dropout,
354
+ scaling=self.scaling,
355
+ sliding_window=self.sliding_window,
356
+ position_ids=position_ids, # pass positions for FA2
357
+ **kwargs,
358
+ )
359
+
360
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
361
+ attn_output = self.o_proj(attn_output)
362
+ return attn_output, attn_weights
363
+
364
+
365
+ @use_kernel_forward_from_hub("RMSNorm")
366
+ class PaddleOCRRMSNorm(nn.Module):
367
+ def __init__(self, hidden_size, eps=1e-6):
368
+ """
369
+ PaddleOCRRMSNorm is equivalent to T5LayerNorm
370
+ """
371
+ super().__init__()
372
+ self.weight = nn.Parameter(torch.ones(hidden_size))
373
+ self.variance_epsilon = eps
374
+
375
+ def forward(self, hidden_states):
376
+ input_dtype = hidden_states.dtype
377
+ hidden_states = hidden_states.to(torch.float32)
378
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
379
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
380
+ return self.weight * hidden_states.to(input_dtype)
381
+
382
+ def extra_repr(self):
383
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
384
+
385
+
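
A standalone sketch of the RMS normalization computed by `PaddleOCRRMSNorm` above (toy tensor, eps assumed to be 1e-6): each hidden vector is divided by its root mean square, then scaled by a learned per-channel weight.

```python
import torch

eps = 1e-6
x = torch.randn(2, 4, 8)
weight = torch.ones(8)                       # stands in for the learned parameter
variance = x.pow(2).mean(-1, keepdim=True)
normed = weight * (x * torch.rsqrt(variance + eps))
print(normed.pow(2).mean(-1))                # ~1 everywhere with a unit weight
```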
386
+ class PaddleOCRDecoderLayer(GradientCheckpointingLayer):
387
+ def __init__(self, config: PaddleOCRTextConfig, layer_idx: int):
388
+ super().__init__()
389
+ self.hidden_size = config.hidden_size
390
+
391
+ self.self_attn = PaddleOCRAttention(config=config, layer_idx=layer_idx)
392
+
393
+ self.mlp = PaddleOCRMLP(config)
394
+ self.input_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
395
+ self.post_attention_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
396
+
397
+ def forward(
398
+ self,
399
+ hidden_states: torch.Tensor,
400
+ attention_mask: Optional[torch.Tensor] = None,
401
+ position_ids: Optional[torch.LongTensor] = None,
402
+ past_key_values: Optional[Cache] = None,
403
+ use_cache: Optional[bool] = False,
404
+ cache_position: Optional[torch.LongTensor] = None,
405
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
406
+ **kwargs: Unpack[TransformersKwargs],
407
+ ) -> torch.Tensor:
408
+ residual = hidden_states
409
+ hidden_states = self.input_layernorm(hidden_states)
410
+ # Self Attention
411
+ hidden_states, _ = self.self_attn(
412
+ hidden_states=hidden_states,
413
+ attention_mask=attention_mask,
414
+ position_ids=position_ids,
415
+ past_key_values=past_key_values,
416
+ use_cache=use_cache,
417
+ cache_position=cache_position,
418
+ position_embeddings=position_embeddings,
419
+ **kwargs,
420
+ )
421
+ hidden_states = residual + hidden_states
422
+
423
+ # Fully Connected
424
+ residual = hidden_states
425
+ hidden_states = self.post_attention_layernorm(hidden_states)
426
+ hidden_states = self.mlp(hidden_states)
427
+ hidden_states = residual + hidden_states
428
+ return hidden_states
429
+
430
+
431
+ @auto_docstring
432
+ class PaddleOCRVLPreTrainedModel(PreTrainedModel):
433
+ config: PaddleOCRVLConfig
434
+ base_model_prefix = "model"
435
+ supports_gradient_checkpointing = True
436
+ _no_split_modules = ["PaddleOCRDecoderLayer"]
437
+ _skip_keys_device_placement = ["past_key_values"]
438
+ _supports_flash_attn = True
439
+ _supports_sdpa = True
440
+ _supports_flex_attn = True
441
+
442
+ _can_compile_fullgraph = True
443
+ _supports_attention_backend = True
444
+
445
+ _can_record_outputs = {
446
+ "hidden_states": PaddleOCRDecoderLayer,
447
+ "attentions": PaddleOCRAttention,
448
+ }
449
+
450
+ def _init_weights(self, module):
451
+ super()._init_weights(module)
452
+ if isinstance(module, PaddleOCRVisionEmbeddings):
453
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
454
+ elif isinstance(module, PaddleOCRVisionRotaryEmbedding):
455
+ inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
456
+ init.copy_(module.inv_freq, inv_freq)
457
+
458
+
459
+ @auto_docstring
460
+ class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel):
461
+ def __init__(self, config: PaddleOCRTextConfig):
462
+ super().__init__(config)
463
+ self.padding_idx = config.pad_token_id
464
+ self.vocab_size = config.vocab_size
465
+
466
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
467
+ self.layers = nn.ModuleList(
468
+ [PaddleOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
469
+ )
470
+ self.norm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
471
+ self.rotary_emb = PaddleOCRRotaryEmbedding(config=config)
472
+ self.gradient_checkpointing = False
473
+
474
+ # Initialize weights and apply final processing
475
+ self.post_init()
476
+
477
+ @check_model_inputs
478
+ @auto_docstring
479
+ def forward(
480
+ self,
481
+ input_ids: Optional[torch.LongTensor] = None,
482
+ attention_mask: Optional[torch.Tensor] = None,
483
+ position_ids: Optional[torch.LongTensor] = None,
484
+ past_key_values: Optional[Cache] = None,
485
+ inputs_embeds: Optional[torch.FloatTensor] = None,
486
+ cache_position: Optional[torch.LongTensor] = None,
487
+ use_cache: Optional[bool] = None,
488
+ **kwargs: Unpack[TransformersKwargs],
489
+ ) -> BaseModelOutputWithPast:
490
+ if (input_ids is None) ^ (inputs_embeds is not None):
491
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
492
+
493
+ if inputs_embeds is None:
494
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
495
+
496
+ if use_cache and past_key_values is None:
497
+ past_key_values = DynamicCache(config=self.config)
498
+
499
+ if cache_position is None:
500
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
501
+ cache_position: torch.Tensor = (
502
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
503
+ )
504
+
505
+ if position_ids is None:
506
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
507
+ elif position_ids.ndim == 2:
508
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
509
+
510
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
511
+ text_position_ids = position_ids[0]
512
+ position_ids = position_ids[1:]
513
+ else:
514
+ text_position_ids = None
515
+
516
+ causal_mask = create_causal_mask(
517
+ config=self.config,
518
+ input_embeds=inputs_embeds,
519
+ attention_mask=attention_mask,
520
+ cache_position=cache_position,
521
+ past_key_values=past_key_values,
522
+ position_ids=text_position_ids,
523
+ )
524
+
525
+ hidden_states = inputs_embeds
526
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
527
+
528
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
529
+ hidden_states = decoder_layer(
530
+ hidden_states,
531
+ attention_mask=causal_mask,
532
+ position_embeddings=position_embeddings,
533
+ position_ids=text_position_ids,
534
+ past_key_values=past_key_values,
535
+ use_cache=use_cache,
536
+ cache_position=cache_position,
537
+ **kwargs,
538
+ )
539
+
540
+ hidden_states = self.norm(hidden_states)
541
+ return BaseModelOutputWithPast(
542
+ last_hidden_state=hidden_states,
543
+ past_key_values=past_key_values,
544
+ )
545
+
546
+
547
+ class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel):
548
+ config: PaddleOCRVisionConfig
549
+ main_input_name = "pixel_values"
550
+ input_modalities = "image"
551
+
552
+ def __init__(self, config: PaddleOCRVisionConfig):
553
+ super().__init__(config)
554
+
555
+ self.vision_model = PaddleOCRVisionTransformer(config)
556
+
557
+ # Initialize weights and apply final processing
558
+ self.post_init()
559
+
560
+ def forward(
561
+ self,
562
+ pixel_values: torch.FloatTensor,
563
+ cu_seqlens: torch.Tensor,
564
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
565
+ **kwargs,
566
+ ) -> BaseModelOutputWithPooling:
567
+ """
568
+ Args:
569
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
570
+ The tensors corresponding to the input images.
571
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
572
+ The cumulative sequence lengths of each image or video feature.
573
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
574
+ The temporal, height and width of feature shape of each image in LLM.
575
+ """
576
+ return self.vision_model(
577
+ pixel_values=pixel_values,
578
+ cu_seqlens=cu_seqlens,
579
+ image_grid_thw=image_grid_thw,
580
+ )
581
+
582
+
583
+ class PaddleOCRVisionEmbeddings(nn.Module):
584
+ def __init__(self, config: PaddleOCRVisionConfig):
585
+ super().__init__()
586
+ self.config = config
587
+ self.embed_dim = config.hidden_size
588
+ self.image_size = config.image_size
589
+ self.patch_size = config.patch_size
590
+
591
+ self.patch_embedding = nn.Conv2d(
592
+ in_channels=config.num_channels,
593
+ out_channels=self.embed_dim,
594
+ kernel_size=self.patch_size,
595
+ stride=self.patch_size,
596
+ padding="valid",
597
+ )
598
+
599
+ self.num_patches = (self.image_size // self.patch_size) ** 2
600
+ self.num_positions = self.num_patches
601
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
602
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
603
+
604
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
605
+ """
606
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
607
+ images. It is also adapted to support torch.jit tracing and no class embeddings.
608
+
609
+ Adapted from:
610
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
611
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
612
+ """
613
+ num_positions = self.position_embedding.weight.shape[0]
614
+
615
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
616
+
617
+ dim = embeddings.shape[-1]
618
+
619
+ sqrt_num_positions = torch_int(num_positions**0.5)
620
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
621
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
622
+
623
+ patch_pos_embed = nn.functional.interpolate(
624
+ patch_pos_embed,
625
+ size=(height, width),
626
+ mode="bilinear",
627
+ align_corners=False,
628
+ )
629
+
630
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
631
+ return patch_pos_embed
632
+
633
+ def forward(
634
+ self,
635
+ pixel_values: torch.FloatTensor,
636
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
637
+ ) -> torch.Tensor:
638
+ """
639
+ Args:
640
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
641
+ The tensors corresponding to the input images.
642
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
643
+ The temporal, height and width of feature shape of each image in LLM.
644
+ """
645
+ batch_size, sequence_len, channel, height, width = pixel_values.shape
646
+ target_dtype = self.patch_embedding.weight.dtype
647
+ pixel_values = pixel_values.reshape(batch_size * sequence_len, channel, height, width)
648
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
649
+ embeddings = patch_embeds.flatten(-2).squeeze(-1)
650
+ embeddings = embeddings.reshape(batch_size, sequence_len, -1)
651
+
652
+ start = 0
653
+ embeddings = embeddings.squeeze(0)
654
+ tmp_embeddings = []
655
+ for image_grid in image_grid_thw:
656
+ t, h, w = image_grid
657
+ end = start + t * h * w
658
+ image_embeddings = embeddings[start:end, :]
659
+ position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1)
660
+ image_embeddings = image_embeddings + position_embedding
661
+ tmp_embeddings.append(image_embeddings)
662
+ start = end
663
+ embeddings = torch.concat(tmp_embeddings, dim=0)
664
+
665
+ return embeddings
666
+
667
+
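
A toy sketch (sizes assumed) of the interpolation performed by `interpolate_pos_encoding` above: the pretrained square grid of position embeddings is bilinearly resized to the image's actual patch grid before being added to the patch embeddings.

```python
import torch
from torch import nn

num_positions, dim = 16 * 16, 32             # assumed pretrained 16x16 grid
pos_embed = torch.randn(1, num_positions, dim)
height, width = 7, 5                         # target patch grid of one image

side = int(num_positions**0.5)
resized = nn.functional.interpolate(
    pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
    size=(height, width), mode="bilinear", align_corners=False,
)
resized = resized.permute(0, 2, 3, 1).view(1, -1, dim)
print(resized.shape)  # torch.Size([1, 35, 32])
```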
668
+ def apply_rotary_pos_emb_vision(
669
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
670
+ ) -> tuple[torch.Tensor, torch.Tensor]:
671
+ orig_q_dtype = q.dtype
672
+ orig_k_dtype = k.dtype
673
+ q, k = q.float(), k.float()
674
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
675
+ q_embed = (q * cos) + (rotate_half(q) * sin)
676
+ k_embed = (k * cos) + (rotate_half(k) * sin)
677
+ q_embed = q_embed.to(orig_q_dtype)
678
+ k_embed = k_embed.to(orig_k_dtype)
679
+ return q_embed, k_embed
680
+
681
+
682
+ class PaddleOCRVisionAttention(nn.Module):
683
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
684
+
685
+ def __init__(self, config: PaddleOCRVisionConfig):
686
+ super().__init__()
687
+ self.config = config
688
+ self.embed_dim = config.hidden_size
689
+ self.num_heads = config.num_attention_heads
690
+ self.head_dim = self.embed_dim // self.num_heads
691
+ if self.head_dim * self.num_heads != self.embed_dim:
692
+ raise ValueError(
693
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
694
+ f" {self.num_heads})."
695
+ )
696
+ self.is_causal = False
697
+
698
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
699
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
700
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
701
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
702
+ self.num_key_value_groups = 1
703
+ self.scaling = self.head_dim**-0.5
704
+ self.attention_dropout = config.attention_dropout
705
+
706
+ def forward(
707
+ self,
708
+ hidden_states: torch.Tensor,
709
+ cu_seqlens: torch.Tensor,
710
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
711
+ **kwargs: Unpack[TransformersKwargs],
712
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
713
+ """
714
+ Args:
715
+ hidden_states (`torch.Tensor`):
716
+ Input to the layer of shape `(seq_len, embed_dim)`.
717
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
718
+ The cumulative sequence lengths of each image or video feature.
719
+ position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`):
720
+ The cosine and sine position embeddings for vision attention.
721
+ """
722
+ seq_length = hidden_states.shape[0]
723
+ query_states = self.q_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
724
+ key_states = self.k_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
725
+ value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
726
+
727
+ cos, sin = position_embeddings
728
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
729
+
730
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
731
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
732
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
733
+
734
+ attention_interface: Callable = eager_attention_forward
735
+ if self.config._attn_implementation != "eager":
736
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
737
+
738
+ if self.config._attn_implementation == "flash_attention_2":
739
+ # Flash Attention 2: Use cu_seqlens for variable length attention
740
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
741
+ attn_output, attn_weights = attention_interface(
742
+ self,
743
+ query_states,
744
+ key_states,
745
+ value_states,
746
+ attention_mask=None,
747
+ scaling=self.scaling,
748
+ dropout=0.0 if not self.training else self.attention_dropout,
749
+ cu_seq_lens_q=cu_seqlens,
750
+ cu_seq_lens_k=cu_seqlens,
751
+ max_length_q=max_seqlen,
752
+ max_length_k=max_seqlen,
753
+ is_causal=False,
754
+ **kwargs,
755
+ )
756
+ else:
757
+ # Other implementations: Process each chunk separately
758
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
759
+ splits = [
760
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
761
+ ]
762
+
763
+ attn_outputs, attn_weights = [], []
764
+ for q, k, v in zip(*splits):
765
+ attn_output, attn_weight = attention_interface(
766
+ self,
767
+ q,
768
+ k,
769
+ v,
770
+ attention_mask=None,
771
+ scaling=self.scaling,
772
+ dropout=0.0 if not self.training else self.attention_dropout,
773
+ is_causal=False,
774
+ **kwargs,
775
+ )
776
+ attn_outputs.append(attn_output)
777
+ attn_weights.append(attn_weight)
778
+
779
+ attn_output = torch.cat(attn_outputs, dim=1)
780
+
781
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
782
+ attn_output = self.out_proj(attn_output)
783
+
784
+ return attn_output, attn_weights
785
+
786
+
787
+ class PaddleOCRVisionMLP(nn.Module):
788
+ def __init__(self, config: PaddleOCRVisionConfig):
789
+ super().__init__()
790
+ self.config = config
791
+ self.activation_fn = ACT2FN[config.hidden_act]
792
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
793
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
794
+
795
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
796
+ hidden_states = self.fc1(hidden_states)
797
+ hidden_states = self.activation_fn(hidden_states)
798
+ hidden_states = self.fc2(hidden_states)
799
+ return hidden_states
800
+
801
+
802
+ class PaddleOCRVisionEncoderLayer(GradientCheckpointingLayer):
803
+ def __init__(self, config: PaddleOCRVisionConfig):
804
+ super().__init__()
805
+ self.embed_dim = config.hidden_size
806
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
807
+ self.self_attn = PaddleOCRVisionAttention(config=config)
808
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
809
+ self.mlp = PaddleOCRVisionMLP(config=config)
810
+
811
+ @auto_docstring
812
+ def forward(
813
+ self,
814
+ hidden_states: torch.Tensor,
815
+ cu_seqlens: torch.Tensor,
816
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
817
+ **kwargs: Unpack[TransformersKwargs],
818
+ ) -> torch.Tensor:
819
+ r"""
820
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
821
+ The cumulative sequence lengths of each image or video feature.
822
+ position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`):
823
+ The cosine and sine position embeddings for vision attention.
824
+ """
825
+ residual = hidden_states
826
+
827
+ hidden_states = self.layer_norm1(hidden_states)
828
+ hidden_states, _ = self.self_attn(
829
+ hidden_states,
830
+ cu_seqlens=cu_seqlens,
831
+ position_embeddings=position_embeddings,
832
+ **kwargs,
833
+ )
834
+ hidden_states = residual + hidden_states
835
+
836
+ residual = hidden_states
837
+ hidden_states = self.layer_norm2(hidden_states)
838
+ hidden_states = self.mlp(hidden_states)
839
+ hidden_states = residual + hidden_states
840
+
841
+ return hidden_states
842
+
843
+
844
+ class PaddleOCRVisionEncoder(nn.Module):
845
+ """
846
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
847
+ [`PaddleOCRVisionEncoderLayer`].
848
+
849
+ Args:
850
+ config: PaddleOCRVisionConfig
851
+ """
852
+
853
+ def __init__(self, config: PaddleOCRVisionConfig):
854
+ super().__init__()
855
+ self.config = config
856
+ self.layers = nn.ModuleList([PaddleOCRVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
857
+ self.gradient_checkpointing = False
858
+ embed_dim = config.hidden_size
859
+ num_heads = config.num_attention_heads
860
+ head_dim = embed_dim // num_heads
861
+ self.rotary_pos_emb = PaddleOCRVisionRotaryEmbedding(head_dim // 2)
862
+
863
+ # Ignore copy
864
+ @can_return_tuple
865
+ @auto_docstring
866
+ def forward(
867
+ self,
868
+ inputs_embeds: torch.FloatTensor,
869
+ cu_seqlens: torch.Tensor,
870
+ attention_mask: Optional[torch.Tensor] = None,
871
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
872
+ ) -> BaseModelOutput:
873
+ r"""
874
+ inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`, *optional*):
875
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
876
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
877
+ than the model's internal embedding lookup matrix.
878
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
879
+ The cumulative sequence lengths of each image or video feature.
880
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
881
+ The attention mask of shape `(batch_size, sequence_length)` used in the forward function, if not `None`.
882
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
883
+ The temporal, height and width of feature shape of each image in LLM.
884
+ """
885
+ device = inputs_embeds.device
886
+ hidden_states = inputs_embeds
887
+ attention_mask = create_bidirectional_mask(
888
+ config=self.config,
889
+ input_embeds=inputs_embeds,
890
+ attention_mask=attention_mask,
891
+ )
892
+ split_hids = []
893
+ split_wids = []
894
+ for t, h, w in image_grid_thw:
895
+ image_pids = torch.arange(t * h * w, device=device) % (h * w)
896
+ sample_hids = image_pids // w
897
+ sample_wids = image_pids % w
898
+ split_hids.append(sample_hids)
899
+ split_wids.append(sample_wids)
900
+ width_position_ids = torch.concat(split_wids, dim=0)
901
+ height_position_ids = torch.concat(split_hids, dim=0)
902
+
903
+ pids = torch.stack([height_position_ids, width_position_ids], dim=-1)
904
+ max_grid_size = pids.max() + 1
905
+ rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size)
906
+ rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1)
907
+ rotary_embeddings = rotary_embeddings.repeat(1, 2)
908
+ position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin())
909
+
910
+ for encoder_layer in self.layers:
911
+ hidden_states = encoder_layer(
912
+ hidden_states,
913
+ cu_seqlens=cu_seqlens,
914
+ position_embeddings=position_embeddings,
915
+ )
916
+
917
+ return BaseModelOutput(
918
+ last_hidden_state=hidden_states,
919
+ )
920
+
921
+
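
A minimal sketch (toy grid assumed) of the 2D rotary table built in `PaddleOCRVisionEncoder.forward` above: every patch gets a (row, column) pair, per-axis frequencies are looked up, then flattened and duplicated so cos/sin cover the full head dimension.

```python
import torch

dim = 8                                             # plays the role of head_dim // 2
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

t, h, w = 1, 2, 3
patch_ids = torch.arange(t * h * w) % (h * w)
hids, wids = patch_ids // w, patch_ids % w          # per-patch row / column indices

freqs = torch.outer(torch.arange(max(h, w), dtype=torch.float), inv_freq)   # (max_grid, dim/2)
rotary = freqs[torch.stack([hids, wids], dim=-1)].flatten(1).repeat(1, 2)   # (h*w, 2*dim)
print(rotary.cos().shape)  # torch.Size([6, 16])
```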
922
+ class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel):
923
+ def __init__(self, config: PaddleOCRVisionConfig):
924
+ super().__init__(config)
925
+ self.config = config
926
+ embed_dim = config.hidden_size
927
+
928
+ self.embeddings = PaddleOCRVisionEmbeddings(config)
929
+ self.encoder = PaddleOCRVisionEncoder(config)
930
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
931
+
932
+ self.post_init()
933
+
934
+ def forward(
935
+ self,
936
+ pixel_values: torch.FloatTensor,
937
+ cu_seqlens: torch.Tensor,
938
+ attention_mask: Optional[torch.Tensor] = None,
939
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
940
+ **kwargs,
941
+ ) -> BaseModelOutputWithPooling:
942
+ """
943
+ Args:
944
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`):
945
+ The tensors corresponding to the input images.
946
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
947
+ The cumulative sequence lengths of each image or video feature.
948
+ attention_mask (`torch.Tensor`, *optional*):
949
+ The attention mask of shape `(batch_size, sequence_length)` used in the forward function, if not `None`.
950
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
951
+ The temporal, height and width of feature shape of each image in LLM.
952
+ """
953
+ hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw)
954
+
955
+ encoder_outputs: BaseModelOutput = self.encoder(
956
+ inputs_embeds=hidden_states,
957
+ cu_seqlens=cu_seqlens,
958
+ attention_mask=attention_mask,
959
+ image_grid_thw=image_grid_thw,
960
+ )
961
+
962
+ last_hidden_state = encoder_outputs.last_hidden_state
963
+ last_hidden_state = self.post_layernorm(last_hidden_state)
964
+
965
+ return BaseModelOutputWithPooling(
966
+ last_hidden_state=last_hidden_state,
967
+ pooler_output=None,
968
+ hidden_states=encoder_outputs.hidden_states,
969
+ attentions=encoder_outputs.attentions,
970
+ )
971
+
972
+
973
+ @dataclass
974
+ @auto_docstring(
975
+ custom_intro="""
976
+ Base class for PaddleOCRVL outputs, with hidden states and attentions.
977
+ """
978
+ )
979
+ class PaddleOCRVLModelOutputWithPast(ModelOutput):
980
+ r"""
981
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
982
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
983
+
984
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
985
+ `past_key_values` input) to speed up sequential decoding.
986
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
987
+ The rope index difference between sequence length and multimodal rope.
988
+ """
989
+
990
+ last_hidden_state: Optional[torch.FloatTensor] = None
991
+ past_key_values: Optional[Cache] = None
992
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
993
+ attentions: Optional[tuple[torch.FloatTensor]] = None
994
+ rope_deltas: Optional[torch.LongTensor] = None
995
+
996
+
997
+ @dataclass
998
+ @auto_docstring(
999
+ custom_intro="""
1000
+ Base class for PaddleOCRVL causal language model (or autoregressive) outputs.
1001
+ """
1002
+ )
1003
+ class PaddleOCRVLCausalLMOutputWithPast(ModelOutput):
1004
+ r"""
1005
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1006
+ Language modeling loss (for next-token prediction).
1007
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1008
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1009
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1010
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
1011
+
1012
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1013
+ `past_key_values` input) to speed up sequential decoding.
1014
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1015
+ The rope index difference between sequence length and multimodal rope.
1016
+ """
1017
+
1018
+ loss: Optional[torch.FloatTensor] = None
1019
+ logits: Optional[torch.FloatTensor] = None
1020
+ past_key_values: Optional[Cache] = None
1021
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
1022
+ attentions: Optional[tuple[torch.FloatTensor]] = None
1023
+ rope_deltas: Optional[torch.LongTensor] = None
1024
+
1025
+
1026
+ @auto_docstring
1027
+ class PaddleOCRVLModel(PaddleOCRVLPreTrainedModel):
1028
+ base_model_prefix = "model"
1029
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
1030
+ # Reference: fix gemma3 grad acc #37208
1031
+ accepts_loss_kwargs = False
1032
+ _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
1033
+
1034
+ def __init__(self, config: PaddleOCRVLConfig):
1035
+ super().__init__(config)
1036
+ self.visual = PaddleOCRVisionModel._from_config(config.vision_config)
1037
+ self.language_model = PaddleOCRTextModel._from_config(config.text_config)
1038
+ self.rope_deltas = None
1039
+ self.projector = PaddleOCRProjector(config)
1040
+
1041
+ # Initialize weights and apply final processing
1042
+ self.post_init()
1043
+
1044
+ def get_input_embeddings(self):
1045
+ return self.language_model.embed_tokens
1046
+
1047
+ def set_input_embeddings(self, value):
1048
+ self.language_model.embed_tokens = value
1049
+
1050
+ def get_rope_index(
1051
+ self,
1052
+ input_ids: Optional[torch.LongTensor] = None,
1053
+ image_grid_thw: Optional[torch.LongTensor] = None,
1054
+ video_grid_thw: Optional[torch.LongTensor] = None,
1055
+ attention_mask: Optional[torch.Tensor] = None,
1056
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1057
+ """
1058
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
1059
+
1060
+ Explanation:
1061
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
1062
+
1063
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
1064
+ Examples:
1065
+ input_ids: [T T T T T], here T is for text.
1066
+ temporal position_ids: [0, 1, 2, 3, 4]
1067
+ height position_ids: [0, 1, 2, 3, 4]
1068
+ width position_ids: [0, 1, 2, 3, 4]
1069
+
1070
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
1071
+ and 1D rotary position embedding for text part.
1072
+ Examples:
1073
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
1074
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
1075
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
1076
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
1077
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
1078
+ text temporal position_ids: [3, 4, 5, 6, 7]
1079
+ text height position_ids: [3, 4, 5, 6, 7]
1080
+ text width position_ids: [3, 4, 5, 6, 7]
1081
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
1082
+
1083
+ Args:
1084
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1085
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1086
+ it.
1087
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1088
+ The temporal, height and width of feature shape of each image in LLM.
1089
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1090
+ The temporal, height and width of feature shape of each video in LLM.
1091
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1092
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1093
+
1094
+ - 1 for tokens that are **not masked**,
1095
+ - 0 for tokens that are **masked**.
1096
+
1097
+ Returns:
1098
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
1099
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
1100
+ """
1101
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
1102
+ image_token_id = self.config.image_token_id
1103
+ video_token_id = self.config.video_token_id
1104
+ vision_start_token_id = self.config.vision_start_token_id
1105
+ mrope_position_deltas = []
1106
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
1107
+ total_input_ids = input_ids
1108
+ if attention_mask is None:
1109
+ attention_mask = torch.ones_like(total_input_ids)
1110
+ position_ids = torch.ones(
1111
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1112
+ )
1113
+ image_index, video_index = 0, 0
1114
+ for i, input_ids in enumerate(total_input_ids):
1115
+ input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
1116
+ image_nums, video_nums = 0, 0
1117
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
1118
+ vision_tokens = input_ids[vision_start_indices + 1]
1119
+ image_nums = (vision_tokens == image_token_id).sum()
1120
+ video_nums = (vision_tokens == video_token_id).sum()
1121
+ input_tokens = input_ids.tolist()
1122
+ llm_pos_ids_list: list = []
1123
+ st = 0
1124
+ remain_images, remain_videos = image_nums, video_nums
1125
+ for _ in range(image_nums + video_nums):
1126
+ if image_token_id in input_tokens and remain_images > 0:
1127
+ ed_image = input_tokens.index(image_token_id, st)
1128
+ else:
1129
+ ed_image = len(input_tokens) + 1
1130
+ if video_token_id in input_tokens and remain_videos > 0:
1131
+ ed_video = input_tokens.index(video_token_id, st)
1132
+ else:
1133
+ ed_video = len(input_tokens) + 1
1134
+ if ed_image < ed_video:
1135
+ t, h, w = (
1136
+ image_grid_thw[image_index][0],
1137
+ image_grid_thw[image_index][1],
1138
+ image_grid_thw[image_index][2],
1139
+ )
1140
+ image_index += 1
1141
+ remain_images -= 1
1142
+ ed = ed_image
1143
+ else:
1144
+ t, h, w = (
1145
+ video_grid_thw[video_index][0],
1146
+ video_grid_thw[video_index][1],
1147
+ video_grid_thw[video_index][2],
1148
+ )
1149
+ video_index += 1
1150
+ remain_videos -= 1
1151
+ ed = ed_video
1152
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1153
+ t.item(),
1154
+ h.item() // spatial_merge_size,
1155
+ w.item() // spatial_merge_size,
1156
+ )
1157
+ text_len = ed - st
1158
+
1159
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1160
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1161
+
1162
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1163
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1164
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1165
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1166
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1167
+
1168
+ if st < len(input_tokens):
1169
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1170
+ text_len = len(input_tokens) - st
1171
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1172
+
1173
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1174
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1175
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
1176
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1177
+ return position_ids, mrope_position_deltas
1178
+ else:
1179
+ if attention_mask is not None:
1180
+ position_ids = attention_mask.long().cumsum(-1) - 1
1181
+ position_ids.masked_fill_(attention_mask == 0, 1)
1182
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1183
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1184
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1185
+ else:
1186
+ position_ids = (
1187
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1188
+ .view(1, 1, -1)
1189
+ .expand(3, input_ids.shape[0], -1)
1190
+ )
1191
+ mrope_position_deltas = torch.zeros(
1192
+ [input_ids.shape[0], 1],
1193
+ device=input_ids.device,
1194
+ dtype=input_ids.dtype,
1195
+ )
1196
+
1197
+ return position_ids, mrope_position_deltas
1198
+
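
A standalone sketch reproducing the docstring example in `get_rope_index` above: a single 3x2x2 vision block followed by 5 text tokens, no padding, and the spatial merge assumed to be already applied.

```python
import torch

t, h, w, text_len = 3, 2, 2, 5

t_index = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
h_index = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
w_index = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
vision_pos = torch.stack([t_index, h_index, w_index])           # (3, 12)

text_start = vision_pos.max() + 1                               # text resumes after the max vision id
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1) + text_start
position_ids = torch.cat([vision_pos, text_pos], dim=1)         # (3, 17)
print(position_ids[0].tolist())  # temporal axis: [0,0,0,0,1,1,1,1,2,2,2,2,3,4,5,6,7]
```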
1199
+ def get_video_features(
1200
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
1201
+ ):
1202
+ """
1203
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
1204
+
1205
+ Args:
1206
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1207
+ The tensors corresponding to the input videos.
1208
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1209
+ The temporal, height and width of feature shape of each video in LLM.
1210
+ """
1211
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
1212
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
1213
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1214
+ video_embeds = torch.split(video_embeds, split_sizes)
1215
+ return video_embeds
1216
+
1217
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
1218
+ """
1219
+ Encodes images into continuous embeddings that can be forwarded to the language model.
1220
+
1221
+ Args:
1222
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1223
+ The tensors corresponding to the input images.
1224
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1225
+ The temporal, height and width of feature shape of each image in LLM.
1226
+ """
1227
+ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0)
1228
+ cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
1229
+ dim=0,
1230
+ # Select dtype based on the following factors:
1231
+ # - FA2 requires that cu_seqlens_q must have dtype int32
1232
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
1233
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
1234
+ dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
1235
+ )
1236
+ cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
1237
+ vision_outputs = self.visual(
1238
+ pixel_values=pixel_values,
1239
+ image_grid_thw=image_grid_thw,
1240
+ cu_seqlens=cu_seqlens,
1241
+ )
1242
+ image_embeds = vision_outputs.last_hidden_state
1243
+ image_embeds = self.projector(image_embeds, image_grid_thw)
1244
+ return image_embeds
1245
+
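
A toy sketch (grid values assumed) of the `cu_seqlens` computed in `get_image_features` above: per-frame patch counts are cumulatively summed and left-padded with 0 so the vision attention can be restricted to each image's own patches.

```python
import torch

image_grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])           # (num_images, 3) as (t, h, w)
cu_seqlens = torch.repeat_interleave(
    image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens.tolist())  # [0, 24, 28, 32]
```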
1246
+ def get_placeholder_mask(
1247
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
1248
+ ):
1249
+ """
1250
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1251
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
1252
+ """
1253
+ if input_ids is None:
1254
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
1255
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
1256
+ )
1257
+ special_image_mask = special_image_mask.all(-1)
1258
+ else:
1259
+ special_image_mask = input_ids == self.config.image_token_id
1260
+
1261
+ n_image_tokens = special_image_mask.sum()
1262
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1263
+ n_image_features = image_features.shape[0] * image_features.shape[1]
1264
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
1265
+ raise ValueError(
1266
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
1267
+ )
1268
+ return special_image_mask
1269
+
1270
+ @can_return_tuple
1271
+ def forward(
1272
+ self,
1273
+ input_ids: torch.LongTensor = None,
1274
+ attention_mask: Optional[torch.Tensor] = None,
1275
+ position_ids: Optional[torch.LongTensor] = None,
1276
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
1277
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1278
+ use_cache: Optional[bool] = None,
1279
+ pixel_values: Optional[torch.Tensor] = None,
1280
+ image_grid_thw: Optional[torch.LongTensor] = None,
1281
+ rope_deltas: Optional[torch.LongTensor] = None,
1282
+ cache_position: Optional[torch.LongTensor] = None,
1283
+ **kwargs,
1284
+ ) -> Union[tuple, PaddleOCRVLModelOutputWithPast]:
1285
+ r"""
1286
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1287
+ The temporal, height and width of feature shape of each image in LLM.
1288
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1289
+ The rope index difference between sequence length and multimodal rope.
1290
+ """
1291
+ if inputs_embeds is None:
1292
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
1293
+
1294
+ if pixel_values is not None:
1295
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw).to(
1296
+ inputs_embeds.device, inputs_embeds.dtype
1297
+ )
1298
+ image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
1299
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1300
+
1301
+ if position_ids is None:
1302
+ past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1303
+ if self.rope_deltas is None or past_key_values_length == 0:
1304
+ position_ids, rope_deltas = self.get_rope_index(
1305
+ input_ids=input_ids,
1306
+ image_grid_thw=image_grid_thw,
1307
+ attention_mask=attention_mask,
1308
+ )
1309
+ self.rope_deltas = rope_deltas
1310
+ # then use the previously pre-calculated rope deltas to get the correct position ids
1311
+ else:
1312
+ batch_size, seq_length, _ = inputs_embeds.shape
1313
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1314
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1315
+ delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
1316
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1317
+ position_ids = position_ids + delta.to(position_ids.device)
1318
+
1319
+ outputs = self.language_model(
1320
+ input_ids=None,
1321
+ position_ids=position_ids,
1322
+ attention_mask=attention_mask,
1323
+ past_key_values=past_key_values,
1324
+ inputs_embeds=inputs_embeds,
1325
+ use_cache=use_cache,
1326
+ cache_position=cache_position,
1327
+ **kwargs,
1328
+ )
1329
+
1330
+ output = PaddleOCRVLModelOutputWithPast(
1331
+ last_hidden_state=outputs.last_hidden_state,
1332
+ past_key_values=outputs.past_key_values,
1333
+ hidden_states=outputs.hidden_states,
1334
+ attentions=outputs.attentions,
1335
+ rope_deltas=self.rope_deltas,
1336
+ )
1337
+
1338
+ return output
1339
+
1340
+
+ class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, GenerationMixin):
+    _checkpoint_conversion_mapping = {
+        "^visual": "model.visual",
+        "^mlp_AR": "model.projector",
+        r"^model(?!(\.visual|\.projector|\.language_model))": "model.language_model",
+    }
+    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+    _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = PaddleOCRVLModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    def get_video_features(
+        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+    ):
+        return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+        return self.model.get_image_features(pixel_values, image_grid_thw)
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of the feature shape of each image in the LLM.
+        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+            The rope index difference between sequence length and multimodal rope.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, PaddleOCRVLForConditionalGeneration
+
+        >>> model = PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+        >>> processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
+
+        >>> messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image",
+                            "image": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg",
+                        },
+                        {"type": "text", "text": "OCR:"},
+                    ],
+                }
+        ]
+
+        >>> inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(model.device)
+
+        >>> # Generate
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
+        >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+        >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        >>> print(output_text)
+        ```
+        """
+        outputs: PaddleOCRVLModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            image_grid_thw=image_grid_thw,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            pixel_values=pixel_values,
+            rope_deltas=rope_deltas,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = outputs.last_hidden_state
+
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+            )
+
+        return PaddleOCRVLCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=outputs.rope_deltas,
+        )
+
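For context on the `logits_to_keep` slicing in `forward` above, here is a small standalone sketch with toy tensors; the names and sizes are invented and no real model is involved:

```python
# Standalone sketch of the `logits_to_keep` slicing used in forward():
# an int keeps only the last N positions, a tensor selects explicit indices.
import torch

batch, seq_len, hidden, vocab = 2, 6, 8, 10
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

logits_to_keep = 1  # typical decoding case: only the final position is needed
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)  # torch.Size([2, 1, 10])
```

With the default `logits_to_keep=0`, `slice(-0, None)` is `slice(0, None)`, so every position is projected; passing `1` during generation avoids materializing the full `(batch, seq, vocab)` logits tensor.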
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        is_first_iteration=False,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
+            **kwargs,
+        )
+
+        # Qwen2-VL position_ids are prepared with rope_deltas in forward
+        if position_ids is None:
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values, so we check only the input length.
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+                vision_positions, rope_deltas = self.model.get_rope_index(
+                    model_inputs.get("input_ids", None),
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    attention_mask=attention_mask,
+                )
+                self.model.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            elif "position_ids" in model_inputs:
+                batch_size, seq_length = model_inputs["position_ids"].shape
+                device = model_inputs["position_ids"].device
+                position_ids = torch.arange(seq_length, device=device)
+                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                delta = cache_position[0] + self.model.rope_deltas
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            text_positions = model_inputs["position_ids"][None, ...]
+            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
+
+        if not is_first_iteration and use_cache:
+            model_inputs["pixel_values"] = None
+            model_inputs["pixel_values_videos"] = None
+
+        return model_inputs
+
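A minimal numeric sketch of the decode-step branch above, assuming the `(batch, 1)` shape that `get_rope_index` gives `rope_deltas`; all concrete values are made up:

```python
# Toy illustration of the decode-step position arithmetic in
# prepare_inputs_for_generation: after prefill, each new token's 3D (t/h/w)
# rope position is just a running offset, so only the per-sample scalar
# rope_deltas computed once at prefill is needed.
import torch

batch_size, seq_length = 2, 1           # one new token per sequence while decoding
cache_position = torch.tensor([7])      # 7 tokens already in the KV cache (made up)
rope_deltas = torch.tensor([[3], [5]])  # per-sample delta from get_rope_index (made up)

position_ids = torch.arange(seq_length).view(1, 1, -1).expand(3, batch_size, -1)
delta = cache_position[0] + rope_deltas                       # (batch, 1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
vision_positions = position_ids + delta.expand_as(position_ids)
print(vision_positions[:, 0].flatten().tolist(), vision_positions[:, 1].flatten().tolist())  # [10, 10, 10] [12, 12, 12]

# The text positions are then stacked on top, giving the [4, bs, seq-len] layout
# mentioned in the comment above.
text_positions = torch.full((1, batch_size, seq_length), int(cache_position[0]))
merged = torch.cat([text_positions, vision_positions], dim=0)
print(merged.shape)  # torch.Size([4, 2, 1])
```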
+    def _get_image_nums_and_video_nums(
+        self,
+        input_ids: Optional[torch.LongTensor],
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+        Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        if inputs_embeds is not None:
+            vision_start_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            image_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            video_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+        else:
+            vision_start_mask = input_ids == vision_start_token_id
+            image_mask = input_ids == image_token_id
+            video_mask = input_ids == video_token_id
+
+        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+        return image_nums, video_nums
+
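The counting trick in `_get_image_nums_and_video_nums` relies on every image or video placeholder block starting directly after a vision-start token. A toy check of that logic (all token ids below are invented):

```python
# Toy check of the counting logic above: an image/video placeholder is counted
# only when it immediately follows a vision-start token.
import torch

vision_start_token_id, image_token_id, video_token_id = 10, 11, 12
input_ids = torch.tensor([[1, 10, 11, 11, 2, 10, 12, 3]])  # one image block, one video block

vision_start_mask = input_ids == vision_start_token_id
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id

# Shift the start mask right by one: True exactly at the token *after* a start token.
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
print(image_nums, video_nums)  # tensor([1]) tensor([1])
```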
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ) -> tuple[torch.LongTensor, dict[str, Any]]:
+        # Overwritten -- Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # pixel_values.shape[0] is sum(seqlen_images for samples)
+        # image_grid_thw.shape[0] is sum(num_images for samples)
+
+        if expand_size == 1:
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+        def _expand_dict_for_generation_visual(dict_to_expand):
+            image_grid_thw = model_kwargs.get("image_grid_thw", None)
+            video_grid_thw = model_kwargs.get("video_grid_thw", None)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(
+                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+            )
+
+            def _repeat_interleave_samples(x, lengths, repeat_times):
+                samples = torch.split(x, lengths)
+                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                return result
+
+            for key in dict_to_expand:
+                if key == "pixel_values":
+                    # split images into samples
+                    samples = torch.split(image_grid_thw, list(image_nums))
+                    # compute the sequence length of images for each sample
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "image_grid_thw":
+                    # get the num of images for each sample
+                    lengths = list(image_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "pixel_values_videos":
+                    samples = torch.split(video_grid_thw, list(video_nums))
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "video_grid_thw":
+                    lengths = list(video_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "second_per_grid_ts":
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+                    )
+            return dict_to_expand
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if (
+                    key != "cache_position"
+                    and dict_to_expand[key] is not None
+                    and isinstance(dict_to_expand[key], torch.Tensor)
+                    and key not in visual_keys
+                ):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
+        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+        if is_encoder_decoder:
+            if model_kwargs.get("encoder_outputs") is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+        return input_ids, model_kwargs
+
+
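The per-sample expansion above exists because `pixel_values` and the `*_grid_thw` tensors are flattened over all images of the batch rather than carrying a batch dimension. Below is a small sketch of the same split-then-repeat idea on invented shapes; the helper is re-implemented here purely for illustration:

```python
# Toy illustration of why visual tensors need per-sample expansion:
# pixel_values is flattened over all images in the batch, so it must be split
# by per-sample patch counts before repeating. Shapes below are invented.
import torch

expand_size = 2
image_grid_thw = torch.tensor([[1, 2, 2], [1, 4, 2]])  # sample 0: one 1x2x2 grid, sample 1: one 1x4x2 grid
image_nums = [1, 1]                                    # one image per sample

samples = torch.split(image_grid_thw, image_nums)                  # per-sample grids
lengths = [int(torch.prod(s, dim=1).sum()) for s in samples]       # patch counts: [4, 8]
pixel_values = torch.randn(sum(lengths), 3)                        # flattened patches (12, 3)

def _repeat_interleave_samples(x, lengths, repeat_times):
    # Repeat each sample's slice contiguously, like the inner helper above.
    chunks = torch.split(x, lengths)
    return torch.cat([c.repeat(repeat_times, *[1] * (x.dim() - 1)) for c in chunks], dim=0)

expanded = _repeat_interleave_samples(pixel_values, lengths, expand_size)
print(expanded.shape)  # torch.Size([24, 3]): sample 0's 4 patches twice, then sample 1's 8 patches twice
```

A plain `repeat_interleave(expand_size, dim=0)` would interleave patches belonging to different samples, which is why the visual keys are excluded from `_expand_dict_for_generation` and handled by this split-aware path instead.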
+ __all__ = [
+    "PaddleOCRVLForConditionalGeneration",
+    "PaddleOCRVLModel",
+    "PaddleOCRVLPreTrainedModel",
+    "PaddleOCRVisionTransformer",
+    "PaddleOCRTextModel",
+    "PaddleOCRVisionModel",
+ ]