transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
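The largest single addition in this diff is the new paddleocr_vl model family (entries 330-336 above). As a hedged sketch only, assuming PaddleOCR-VL is registered with the image-text-to-text Auto classes touched elsewhere in this release, and using a placeholder checkpoint id that this diff does not confirm, loading the model would look roughly like:

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

# Placeholder checkpoint id; substitute a real PaddleOCR-VL checkpoint.
checkpoint = "PaddlePaddle/PaddleOCR-VL"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForImageTextToText.from_pretrained(checkpoint, dtype=torch.bfloat16, device_map="auto")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "document.png"},
            {"type": "text", "text": "Transcribe the text on this page."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)
generated = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])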
transformers/models/paddleocr_vl/modeling_paddleocr_vl.py (added)
@@ -0,0 +1,1668 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_paddleocr_vl.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any, Optional, Union
+
+ import torch
+ from torch import nn
+
+ from ...activations import ACT2FN, GELUActivation
+ from ...cache_utils import Cache, DynamicCache
+ from ...generation import GenerationMixin
+ from ...integrations import use_kernel_forward_from_hub
+ from ...masking_utils import create_bidirectional_mask, create_causal_mask
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
+ from ...modeling_layers import GradientCheckpointingLayer
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
+ from ...utils.generic import check_model_inputs, maybe_autocast
+ from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class PaddleOCRProjector(nn.Module):
+     def __init__(self, config: PaddleOCRVLConfig):
+         super().__init__()
+         self.merge_kernel_size = (config.vision_config.spatial_merge_size, config.vision_config.spatial_merge_size)
+
+         hidden_size = config.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1]
+
+         self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05)
+         self.linear_1 = nn.Linear(hidden_size, hidden_size, bias=True)
+         self.act = GELUActivation()
+         self.linear_2 = nn.Linear(hidden_size, config.text_config.hidden_size, bias=True)
+
+     def forward(self, image_features: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
+         image_features_chunks = image_features.split(image_grid_thw.prod(dim=1).tolist(), dim=0)
+         m1, m2 = self.merge_kernel_size
+
+         processed_features = []
+         for image_feature, image_grid in zip(image_features_chunks, image_grid_thw):
+             image_feature = self.pre_norm(image_feature)
+             t, h, w = image_grid
+             d = image_feature.shape[-1]
+             h_block = h // m1
+             w_block = w // m2
+
+             image_feature = image_feature.reshape(t, h_block, m1, w_block, m2, d)
+             image_feature = image_feature.transpose(2, 3)
+             image_feature = image_feature.reshape(t * h_block * w_block, m1 * m2 * d)
+
+             hidden_states = self.linear_1(image_feature)
+             hidden_states = self.act(hidden_states)
+             hidden_states = self.linear_2(hidden_states)
+             processed_features.append(hidden_states)
+
+         return torch.cat(processed_features, dim=0)
+
+
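For intuition, the merge step in PaddleOCRProjector.forward can be reproduced on dummy tensors. The sizes below are arbitrary illustrations (not taken from the PaddleOCR-VL config); the point is only that each spatial_merge_size x spatial_merge_size block of patches is folded into a single, proportionally wider feature vector:

import torch

t, h, w, d, m = 1, 4, 6, 8, 2               # toy grid 4x6, hidden size 8, merge size 2
feats = torch.randn(t * h * w, d)           # flattened patch features for one image

x = feats.reshape(t, h // m, m, w // m, m, d)
x = x.transpose(2, 3)                       # group each m x m spatial neighbourhood together
x = x.reshape(t * (h // m) * (w // m), m * m * d)
print(x.shape)                              # torch.Size([6, 32]): 4x fewer tokens, 4x wider features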
+ class PaddleOCRVisionRotaryEmbedding(nn.Module):
+     inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     def forward(self, seqlen: int) -> torch.Tensor:
+         seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(seq, self.inv_freq)
+         return freqs
+
+
+ class PaddleOCRRotaryEmbedding(nn.Module):
+     inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+     def __init__(self, config: PaddleOCRVLConfig, device=None):
+         super().__init__()
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+
+         self.rope_type = self.config.rope_parameters["rope_type"]
+         rope_init_fn: Callable = self.compute_default_rope_parameters
+         if self.rope_type != "default":
+             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = inv_freq
+
+     @staticmethod
+     def compute_default_rope_parameters(
+         config: Optional[PaddleOCRVLConfig] = None,
+         device: Optional["torch.device"] = None,
+         seq_len: Optional[int] = None,
+     ) -> tuple["torch.Tensor", float]:
+         """
+         Computes the inverse frequencies according to the original RoPE implementation
+         Args:
+             config ([`~transformers.PreTrainedConfig`]):
+                 The model configuration.
+             device (`torch.device`):
+                 The device to use for initialization of the inverse frequencies.
+             seq_len (`int`, *optional*):
+                 The current sequence length. Unused for this type of RoPE.
+         Returns:
+             Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+             post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+         """
+         base = config.rope_parameters["rope_theta"]
+         dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+         attention_factor = 1.0  # Unused in this type of RoPE
+
+         # Compute the inverse frequencies
+         inv_freq = 1.0 / (
+             base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+         )
+         return inv_freq, attention_factor
+
+     # Ignore copy
+     def forward(self, x, position_ids):
+         # In contrast to other models, PaddleOCR has different position ids for the grids
+         # So we expand the inv_freq to shape (3, ...)
+         inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ class PaddleOCRMLP(nn.Module):
+     def __init__(self, config: PaddleOCRTextConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+     Explanation:
+         Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+         sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+         vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+         Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+         For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+         height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+         difference with modern LLMs.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`):
+             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+             used to pass offsetted position ids when working with a KV-cache.
+         mrope_section(`List(int)`):
+             Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     mrope_section = mrope_section * 2
+     cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+         unsqueeze_dim
+     )
+     sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+         unsqueeze_dim
+     )
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
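The channel split used by apply_multimodal_rotary_pos_emb above can be checked in isolation. The sizes below (mrope_section = [16, 24, 24], head_dim = 128, sequence length 10) are illustrative assumptions, not values taken from this diff; they only show how the doubled section list cycles through the temporal/height/width position grids and collapses the leading axis dimension:

import torch

mrope_section = [16, 24, 24]                 # assumed; must sum to head_dim // 2
bs, seq_len, head_dim = 1, 10, 128
cos = torch.randn(3, bs, seq_len, head_dim)  # one cos table per axis: temporal, height, width

sections = mrope_section * 2                 # [16, 24, 24, 16, 24, 24] sums to head_dim
chunks = cos.split(sections, dim=-1)
merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)
print(merged.shape)                          # torch.Size([1, 10, 128]); ready to broadcast against q and k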
+ class PaddleOCRAttention(nn.Module):
+     """
+     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+     and "Generating Long Sequences with Sparse Transformers".
+     """
+
+     def __init__(self, config: PaddleOCRVLConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         if layer_idx is None:
+             logger.warning_once(
+                 f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                 "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
289
+ "when creating this class."
290
+ )
291
+
292
+ self.hidden_size = config.hidden_size
293
+ self.num_heads = config.num_attention_heads
294
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
295
+ self.num_key_value_heads = config.num_key_value_heads
296
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
297
+ self.is_causal = True
298
+
299
+ self.attention_dropout = 0.0
300
+ self.rope_parameters = config.rope_parameters
301
+ self.scaling = self.head_dim**-0.5
302
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias)
303
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
304
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_bias)
305
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
306
+ self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
307
+ self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
308
+
309
+ def forward(
310
+ self,
311
+ hidden_states: torch.Tensor,
312
+ attention_mask: Optional[torch.Tensor] = None,
313
+ position_ids: Optional[torch.LongTensor] = None,
314
+ past_key_values: Optional[Cache] = None,
315
+ output_attentions: bool = False,
316
+ use_cache: bool = False,
317
+ cache_position: Optional[torch.LongTensor] = None,
318
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
319
+ **kwargs: Unpack[FlashAttentionKwargs],
320
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
321
+ bsz, q_len, _ = hidden_states.size()
322
+
323
+ query_states = self.q_proj(hidden_states)
324
+ key_states = self.k_proj(hidden_states)
325
+ value_states = self.v_proj(hidden_states)
326
+
327
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
328
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
329
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
330
+
331
+ cos, sin = position_embeddings
332
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
333
+ query_states, key_states, cos, sin, self.config.rope_parameters["mrope_section"]
334
+ )
335
+
336
+ if past_key_values is not None:
337
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
338
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
339
+
340
+ attention_interface: Callable = eager_attention_forward
341
+ if self.config._attn_implementation != "eager":
342
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
343
+
344
+ attn_output, attn_weights = attention_interface(
345
+ self,
346
+ query_states,
347
+ key_states,
348
+ value_states,
349
+ attention_mask,
350
+ dropout=0.0 if not self.training else self.attention_dropout,
351
+ scaling=self.scaling,
352
+ sliding_window=self.sliding_window,
353
+ position_ids=position_ids, # pass positions for FA2
354
+ **kwargs,
355
+ )
356
+
357
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
358
+ attn_output = self.o_proj(attn_output)
359
+ return attn_output, attn_weights
360
+
361
+
362
+ @use_kernel_forward_from_hub("RMSNorm")
363
+ class PaddleOCRRMSNorm(nn.Module):
364
+ def __init__(self, hidden_size, eps=1e-6):
365
+ """
366
+ PaddleOCRRMSNorm is equivalent to T5LayerNorm
367
+ """
368
+ super().__init__()
369
+ self.weight = nn.Parameter(torch.ones(hidden_size))
370
+ self.variance_epsilon = eps
371
+
372
+ def forward(self, hidden_states):
373
+ input_dtype = hidden_states.dtype
374
+ hidden_states = hidden_states.to(torch.float32)
375
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
376
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
377
+ return self.weight * hidden_states.to(input_dtype)
378
+
379
+ def extra_repr(self):
380
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
381
+
382
+
383
+ class PaddleOCRDecoderLayer(GradientCheckpointingLayer):
384
+ def __init__(self, config: PaddleOCRTextConfig, layer_idx: int):
385
+ super().__init__()
386
+ self.hidden_size = config.hidden_size
387
+
388
+ self.self_attn = PaddleOCRAttention(config=config, layer_idx=layer_idx)
389
+
390
+ self.mlp = PaddleOCRMLP(config)
391
+ self.input_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
392
+ self.post_attention_layernorm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
393
+
394
+ def forward(
395
+ self,
396
+ hidden_states: torch.Tensor,
397
+ attention_mask: Optional[torch.Tensor] = None,
398
+ position_ids: Optional[torch.LongTensor] = None,
399
+ past_key_values: Optional[Cache] = None,
400
+ use_cache: Optional[bool] = False,
401
+ cache_position: Optional[torch.LongTensor] = None,
402
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
403
+ **kwargs: Unpack[TransformersKwargs],
404
+ ) -> torch.Tensor:
405
+ residual = hidden_states
406
+ hidden_states = self.input_layernorm(hidden_states)
407
+ # Self Attention
408
+ hidden_states, _ = self.self_attn(
409
+ hidden_states=hidden_states,
410
+ attention_mask=attention_mask,
411
+ position_ids=position_ids,
412
+ past_key_values=past_key_values,
413
+ use_cache=use_cache,
414
+ cache_position=cache_position,
415
+ position_embeddings=position_embeddings,
416
+ **kwargs,
417
+ )
418
+ hidden_states = residual + hidden_states
419
+
420
+ # Fully Connected
421
+ residual = hidden_states
422
+ hidden_states = self.post_attention_layernorm(hidden_states)
423
+ hidden_states = self.mlp(hidden_states)
424
+ hidden_states = residual + hidden_states
425
+ return hidden_states
426
+
427
+
428
+ @auto_docstring
429
+ class PaddleOCRVLPreTrainedModel(PreTrainedModel):
430
+ config: PaddleOCRVLConfig
431
+ base_model_prefix = "model"
432
+ supports_gradient_checkpointing = True
433
+ _no_split_modules = ["PaddleOCRDecoderLayer"]
434
+ _skip_keys_device_placement = ["past_key_values"]
435
+ _supports_flash_attn = True
436
+ _supports_sdpa = True
437
+ _supports_flex_attn = True
438
+
439
+ _can_compile_fullgraph = True
440
+ _supports_attention_backend = True
441
+
442
+ _can_record_outputs = {
443
+ "hidden_states": PaddleOCRDecoderLayer,
444
+ "attentions": PaddleOCRAttention,
445
+ }
446
+
447
+
448
+ @auto_docstring
449
+ class PaddleOCRTextModel(PaddleOCRVLPreTrainedModel):
450
+ def __init__(self, config: PaddleOCRTextConfig):
451
+ super().__init__(config)
452
+ self.padding_idx = config.pad_token_id
453
+ self.vocab_size = config.vocab_size
454
+
455
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
456
+ self.layers = nn.ModuleList(
457
+ [PaddleOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
458
+ )
459
+ self.norm = PaddleOCRRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
460
+ self.rotary_emb = PaddleOCRRotaryEmbedding(config=config)
461
+ self.gradient_checkpointing = False
462
+
463
+ # Initialize weights and apply final processing
464
+ self.post_init()
465
+
466
+ @check_model_inputs
467
+ @auto_docstring
468
+ def forward(
469
+ self,
470
+ input_ids: Optional[torch.LongTensor] = None,
471
+ attention_mask: Optional[torch.Tensor] = None,
472
+ position_ids: Optional[torch.LongTensor] = None,
473
+ past_key_values: Optional[Cache] = None,
474
+ inputs_embeds: Optional[torch.FloatTensor] = None,
475
+ cache_position: Optional[torch.LongTensor] = None,
476
+ use_cache: Optional[bool] = None,
477
+ **kwargs: Unpack[TransformersKwargs],
478
+ ) -> BaseModelOutputWithPast:
479
+ if (input_ids is None) ^ (inputs_embeds is not None):
480
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
481
+
482
+ if inputs_embeds is None:
483
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
484
+
485
+ if use_cache and past_key_values is None:
486
+ past_key_values = DynamicCache(config=self.config)
487
+
488
+ if cache_position is None:
489
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
490
+ cache_position: torch.Tensor = (
491
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
492
+ )
493
+
494
+ if position_ids is None:
495
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
496
+ elif position_ids.ndim == 2:
497
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
498
+
499
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
500
+ text_position_ids = position_ids[0]
501
+ position_ids = position_ids[1:]
502
+ else:
503
+ text_position_ids = None
504
+
505
+ causal_mask = create_causal_mask(
506
+ config=self.config,
507
+ input_embeds=inputs_embeds,
508
+ attention_mask=attention_mask,
509
+ cache_position=cache_position,
510
+ past_key_values=past_key_values,
511
+ position_ids=text_position_ids,
512
+ )
513
+
514
+ hidden_states = inputs_embeds
515
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
516
+
517
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
518
+ hidden_states = decoder_layer(
519
+ hidden_states,
520
+ attention_mask=causal_mask,
521
+ position_embeddings=position_embeddings,
522
+ position_ids=text_position_ids,
523
+ past_key_values=past_key_values,
524
+ use_cache=use_cache,
525
+ cache_position=cache_position,
526
+ **kwargs,
527
+ )
528
+
529
+ hidden_states = self.norm(hidden_states)
530
+ return BaseModelOutputWithPast(
531
+ last_hidden_state=hidden_states,
532
+ past_key_values=past_key_values,
533
+ )
534
+
535
+
536
+ class PaddleOCRVisionModel(PaddleOCRVLPreTrainedModel):
537
+ config: PaddleOCRVisionConfig
538
+ main_input_name = "pixel_values"
539
+ input_modalities = "image"
540
+
541
+ def __init__(self, config: PaddleOCRVisionConfig):
542
+ super().__init__(config)
543
+
544
+ self.vision_model = PaddleOCRVisionTransformer(config)
545
+
546
+ # Initialize weights and apply final processing
547
+ self.post_init()
548
+
549
+ def forward(
550
+ self,
551
+ pixel_values: torch.FloatTensor,
552
+ cu_seqlens: torch.Tensor,
553
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
554
+ **kwargs,
555
+ ) -> BaseModelOutputWithPooling:
556
+ """
557
+ Args:
558
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
559
+ The tensors corresponding to the input images.
560
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
561
+ The cumulative sequence lengths of each image or video feature.
562
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
563
+ The temporal, height and width of feature shape of each image in LLM.
564
+ """
565
+ return self.vision_model(
566
+ pixel_values=pixel_values,
567
+ cu_seqlens=cu_seqlens,
568
+ image_grid_thw=image_grid_thw,
569
+ )
570
+
571
+
572
+ class PaddleOCRVisionEmbeddings(nn.Module):
573
+ def __init__(self, config: PaddleOCRVisionConfig):
574
+ super().__init__()
575
+ self.config = config
576
+ self.embed_dim = config.hidden_size
577
+ self.image_size = config.image_size
578
+ self.patch_size = config.patch_size
579
+
580
+ self.patch_embedding = nn.Conv2d(
581
+ in_channels=config.num_channels,
582
+ out_channels=self.embed_dim,
583
+ kernel_size=self.patch_size,
584
+ stride=self.patch_size,
585
+ padding="valid",
586
+ )
587
+
588
+ self.num_patches = (self.image_size // self.patch_size) ** 2
589
+ self.num_positions = self.num_patches
590
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
591
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
592
+
593
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
594
+ """
595
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher-resolution
596
+ images. This method is also adapted to support torch.jit tracing and no class embeddings.
597
+
598
+ Adapted from:
599
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
600
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
601
+ """
602
+ num_positions = self.position_embedding.weight.shape[0]
603
+
604
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
605
+
606
+ dim = embeddings.shape[-1]
607
+
608
+ sqrt_num_positions = torch_int(num_positions**0.5)
609
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
610
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
611
+
612
+ patch_pos_embed = nn.functional.interpolate(
613
+ patch_pos_embed,
614
+ size=(height, width),
615
+ mode="bilinear",
616
+ align_corners=False,
617
+ )
618
+
619
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
620
+ return patch_pos_embed
621
+
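For reference, the bilinear resizing performed by `interpolate_pos_encoding` can be reproduced in isolation. The sketch below is a minimal stand-alone version with a toy 16-position table and made-up target grid sizes, not the model's actual configuration.

```python
import torch
import torch.nn as nn

# Toy setup: 16 pre-trained positions arranged as a 4x4 grid, embedding dim 8.
num_positions, dim = 16, 8
pos_embed = nn.Embedding(num_positions, dim)

# Reshape the flat table into (1, grid, grid, dim), move channels first,
# then bilinearly resize to the target (height, width) patch grid.
grid = int(num_positions**0.5)
patch_pos = pos_embed.weight.view(1, grid, grid, dim).permute(0, 3, 1, 2)
resized = nn.functional.interpolate(patch_pos, size=(6, 3), mode="bilinear", align_corners=False)

# Back to a flat (1, height * width, dim) table that can be added to patch embeddings.
resized = resized.permute(0, 2, 3, 1).reshape(1, -1, dim)
print(resized.shape)  # torch.Size([1, 18, 8])
```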
622
+ def forward(
623
+ self,
624
+ pixel_values: torch.FloatTensor,
625
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
626
+ ) -> torch.Tensor:
627
+ """
628
+ Args:
629
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
630
+ The tensors corresponding to the input images.
631
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
632
+ The temporal, height and width of feature shape of each image in LLM.
633
+ """
634
+ batch_size, sequence_len, channel, height, width = pixel_values.shape
635
+ target_dtype = self.patch_embedding.weight.dtype
636
+ pixel_values = pixel_values.reshape(batch_size * sequence_len, channel, height, width)
637
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
638
+ embeddings = patch_embeds.flatten(-2).squeeze(-1)
639
+ embeddings = embeddings.reshape(batch_size, sequence_len, -1)
640
+
641
+ start = 0
642
+ embeddings = embeddings.squeeze(0)
643
+ tmp_embeddings = []
644
+ for image_grid in image_grid_thw:
645
+ t, h, w = image_grid
646
+ end = start + t * h * w
647
+ image_embeddings = embeddings[start:end, :]
648
+ position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1)
649
+ image_embeddings = image_embeddings + position_embedding
650
+ tmp_embeddings.append(image_embeddings)
651
+ start = end
652
+ embeddings = torch.concat(tmp_embeddings, dim=0)
653
+
654
+ return embeddings
655
+
656
+
657
+ def apply_rotary_pos_emb_vision(
658
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
659
+ ) -> tuple[torch.Tensor, torch.Tensor]:
660
+ orig_q_dtype = q.dtype
661
+ orig_k_dtype = k.dtype
662
+ q, k = q.float(), k.float()
663
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
664
+ q_embed = (q * cos) + (rotate_half(q) * sin)
665
+ k_embed = (k * cos) + (rotate_half(k) * sin)
666
+ q_embed = q_embed.to(orig_q_dtype)
667
+ k_embed = k_embed.to(orig_k_dtype)
668
+ return q_embed, k_embed
669
+
670
+
671
+ class PaddleOCRVisionAttention(nn.Module):
672
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
673
+
674
+ def __init__(self, config: PaddleOCRVisionConfig):
675
+ super().__init__()
676
+ self.config = config
677
+ self.embed_dim = config.hidden_size
678
+ self.num_heads = config.num_attention_heads
679
+ self.head_dim = self.embed_dim // self.num_heads
680
+ if self.head_dim * self.num_heads != self.embed_dim:
681
+ raise ValueError(
682
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
683
+ f" {self.num_heads})."
684
+ )
685
+ self.is_causal = False
686
+
687
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
688
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
689
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
690
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
691
+ self.num_key_value_groups = 1
692
+ self.scaling = self.head_dim**-0.5
693
+ self.attention_dropout = config.attention_dropout
694
+
695
+ def forward(
696
+ self,
697
+ hidden_states: torch.Tensor,
698
+ cu_seqlens: torch.Tensor,
699
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
700
+ **kwargs: Unpack[TransformersKwargs],
701
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
702
+ """
703
+ Args:
704
+ hidden_states (`torch.Tensor`):
705
+ Input to the layer of shape `(seq_len, embed_dim)`.
706
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
707
+ The cumulative sequence lengths of each image or video feature.
708
+ position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`):
709
+ The cosine and sine position embeddings for vision attention.
710
+ """
711
+ seq_length = hidden_states.shape[0]
712
+ query_states = self.q_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
713
+ key_states = self.k_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
714
+ value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
715
+
716
+ cos, sin = position_embeddings
717
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
718
+
719
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
720
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
721
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
722
+
723
+ attention_interface: Callable = eager_attention_forward
724
+ if self.config._attn_implementation != "eager":
725
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
726
+
727
+ if self.config._attn_implementation == "flash_attention_2":
728
+ # Flash Attention 2: Use cu_seqlens for variable length attention
729
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
730
+ attn_output, attn_weights = attention_interface(
731
+ self,
732
+ query_states,
733
+ key_states,
734
+ value_states,
735
+ attention_mask=None,
736
+ scaling=self.scaling,
737
+ dropout=0.0 if not self.training else self.attention_dropout,
738
+ cu_seq_lens_q=cu_seqlens,
739
+ cu_seq_lens_k=cu_seqlens,
740
+ max_length_q=max_seqlen,
741
+ max_length_k=max_seqlen,
742
+ is_causal=False,
743
+ **kwargs,
744
+ )
745
+ else:
746
+ # Other implementations: Process each chunk separately
747
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
748
+ splits = [
749
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
750
+ ]
751
+
752
+ attn_outputs, attn_weights = [], []
753
+ for q, k, v in zip(*splits):
754
+ attn_output, attn_weight = attention_interface(
755
+ self,
756
+ q,
757
+ k,
758
+ v,
759
+ attention_mask=None,
760
+ scaling=self.scaling,
761
+ dropout=0.0 if not self.training else self.attention_dropout,
762
+ is_causal=False,
763
+ **kwargs,
764
+ )
765
+ attn_outputs.append(attn_output)
766
+ attn_weights.append(attn_weight)
767
+
768
+ attn_output = torch.cat(attn_outputs, dim=1)
769
+
770
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
771
+ attn_output = self.out_proj(attn_output)
772
+
773
+ return attn_output, attn_weights
774
+
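The non-flash branch above splits the packed patch sequence back into per-image chunks before attending. Below is a minimal sketch of that `cu_seqlens` bookkeeping, using made-up lengths and plain scaled-dot-product attention instead of the configurable attention backend.

```python
import torch
import torch.nn.functional as F

# Packed sequence of 3 + 5 patches, 2 heads, head_dim 4; cu_seqlens marks the boundaries.
cu_seqlens = torch.tensor([0, 3, 8])
q = k = v = torch.randn(1, 2, 8, 4)  # (batch=1, heads, total_patches, head_dim)

lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
outputs = []
for q_c, k_c, v_c in zip(*(torch.split(t, lengths, dim=2) for t in (q, k, v))):
    # Bidirectional attention restricted to one image's patches.
    outputs.append(F.scaled_dot_product_attention(q_c, k_c, v_c, is_causal=False))
attn_output = torch.cat(outputs, dim=2)
print(attn_output.shape)  # torch.Size([1, 2, 8, 4])
```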
775
+
776
+ class PaddleOCRVisionMLP(nn.Module):
777
+ def __init__(self, config: PaddleOCRVisionConfig):
778
+ super().__init__()
779
+ self.config = config
780
+ self.activation_fn = ACT2FN[config.hidden_act]
781
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
782
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
783
+
784
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
785
+ hidden_states = self.fc1(hidden_states)
786
+ hidden_states = self.activation_fn(hidden_states)
787
+ hidden_states = self.fc2(hidden_states)
788
+ return hidden_states
789
+
790
+
791
+ class PaddleOCRVisionEncoderLayer(GradientCheckpointingLayer):
792
+ def __init__(self, config: PaddleOCRVisionConfig):
793
+ super().__init__()
794
+ self.embed_dim = config.hidden_size
795
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
796
+ self.self_attn = PaddleOCRVisionAttention(config=config)
797
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
798
+ self.mlp = PaddleOCRVisionMLP(config=config)
799
+
800
+ @auto_docstring
801
+ def forward(
802
+ self,
803
+ hidden_states: torch.Tensor,
804
+ cu_seqlens: torch.Tensor,
805
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
806
+ **kwargs: Unpack[TransformersKwargs],
807
+ ) -> torch.Tensor:
808
+ r"""
809
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
810
+ The cumulative sequence lengths of each image or video feature.
811
+ position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`):
812
+ The cosine and sine position embeddings for vision attention.
813
+ """
814
+ residual = hidden_states
815
+
816
+ hidden_states = self.layer_norm1(hidden_states)
817
+ hidden_states, _ = self.self_attn(
818
+ hidden_states,
819
+ cu_seqlens=cu_seqlens,
820
+ position_embeddings=position_embeddings,
821
+ **kwargs,
822
+ )
823
+ hidden_states = residual + hidden_states
824
+
825
+ residual = hidden_states
826
+ hidden_states = self.layer_norm2(hidden_states)
827
+ hidden_states = self.mlp(hidden_states)
828
+ hidden_states = residual + hidden_states
829
+
830
+ return hidden_states
831
+
832
+
833
+ class PaddleOCRVisionEncoder(nn.Module):
834
+ """
835
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
836
+ [`PaddleOCRVisionEncoderLayer`].
837
+
838
+ Args:
839
+ config: PaddleOCRVisionConfig
840
+ """
841
+
842
+ def __init__(self, config: PaddleOCRVisionConfig):
843
+ super().__init__()
844
+ self.config = config
845
+ self.layers = nn.ModuleList([PaddleOCRVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
846
+ self.gradient_checkpointing = False
847
+ embed_dim = config.hidden_size
848
+ num_heads = config.num_attention_heads
849
+ head_dim = embed_dim // num_heads
850
+ self.rotary_pos_emb = PaddleOCRVisionRotaryEmbedding(head_dim // 2)
851
+
852
+ # Ignore copy
853
+ @can_return_tuple
854
+ @auto_docstring
855
+ def forward(
856
+ self,
857
+ inputs_embeds: torch.FloatTensor,
858
+ cu_seqlens: torch.Tensor,
859
+ attention_mask: Optional[torch.Tensor] = None,
860
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
861
+ ) -> BaseModelOutput:
862
+ """
863
+ Args:
864
+ inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`, *optional*):
865
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
866
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
867
+ than the model's internal embedding lookup matrix.
868
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
869
+ The cumulative sequence lengths of each image or video feature.
870
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
871
+ Mask to avoid performing attention on padding token indices.
872
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
873
+ The temporal, height and width of feature shape of each image in LLM.
874
+ """
875
+ device = inputs_embeds.device
876
+ hidden_states = inputs_embeds
877
+ attention_mask = create_bidirectional_mask(
878
+ config=self.config,
879
+ input_embeds=inputs_embeds,
880
+ attention_mask=attention_mask,
881
+ )
882
+ split_hids = []
883
+ split_wids = []
884
+ for t, h, w in image_grid_thw:
885
+ image_pids = torch.arange(t * h * w, device=device) % (h * w)
886
+ sample_hids = image_pids // w
887
+ sample_wids = image_pids % w
888
+ split_hids.append(sample_hids)
889
+ split_wids.append(sample_wids)
890
+ width_position_ids = torch.concat(split_wids, dim=0)
891
+ height_position_ids = torch.concat(split_hids, dim=0)
892
+
893
+ pids = torch.stack([height_position_ids, width_position_ids], dim=-1)
894
+ max_grid_size = pids.max() + 1
895
+ rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size)
896
+ rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1)
897
+ rotary_embeddings = rotary_embeddings.repeat(1, 2)
898
+ position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin())
899
+
900
+ for encoder_layer in self.layers:
901
+ hidden_states = encoder_layer(
902
+ hidden_states,
903
+ cu_seqlens=cu_seqlens,
904
+ position_embeddings=position_embeddings,
905
+ )
906
+
907
+ return BaseModelOutput(
908
+ last_hidden_state=hidden_states,
909
+ )
910
+
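As a stand-alone illustration of the height/width position ids built in the loop above, this hedged sketch reproduces the per-image computation for two toy patch grids (the grid sizes are invented).

```python
import torch

# Two toy images: (t, h, w) patch grids of (1, 2, 3) and (1, 2, 2).
image_grid_thw = [(1, 2, 3), (1, 2, 2)]

height_ids, width_ids = [], []
for t, h, w in image_grid_thw:
    # Flat patch index within one frame, repeated over t frames.
    pids = torch.arange(t * h * w) % (h * w)
    height_ids.append(pids // w)   # row index of each patch
    width_ids.append(pids % w)     # column index of each patch

print(torch.cat(height_ids).tolist())  # [0, 0, 0, 1, 1, 1, 0, 0, 1, 1]
print(torch.cat(width_ids).tolist())   # [0, 1, 2, 0, 1, 2, 0, 1, 0, 1]
```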
911
+
912
+ class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel):
913
+ def __init__(self, config: PaddleOCRVisionConfig):
914
+ super().__init__(config)
915
+ self.config = config
916
+ embed_dim = config.hidden_size
917
+
918
+ self.embeddings = PaddleOCRVisionEmbeddings(config)
919
+ self.encoder = PaddleOCRVisionEncoder(config)
920
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
921
+
922
+ def forward(
923
+ self,
924
+ pixel_values: torch.FloatTensor,
925
+ cu_seqlens: torch.Tensor,
926
+ attention_mask: Optional[torch.Tensor] = None,
927
+ image_grid_thw: Optional[list[Union[tuple[int, int, int], list[tuple[int, int, int]]]]] = None,
928
+ **kwargs,
929
+ ) -> BaseModelOutputWithPooling:
930
+ """
931
+ Args:
932
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`):
933
+ The tensors corresponding to the input images.
934
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`):
935
+ The cumulative sequence lengths of each image or video feature.
936
+ attention_mask (`torch.Tensor`, *optional*):
937
+ Mask of shape `(batch_size, sequence_length)` to avoid performing attention on padding token indices.
938
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
939
+ The temporal, height and width of feature shape of each image in LLM.
940
+ """
941
+ hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw)
942
+
943
+ encoder_outputs: BaseModelOutput = self.encoder(
944
+ inputs_embeds=hidden_states,
945
+ cu_seqlens=cu_seqlens,
946
+ attention_mask=attention_mask,
947
+ image_grid_thw=image_grid_thw,
948
+ )
949
+
950
+ last_hidden_state = encoder_outputs.last_hidden_state
951
+ last_hidden_state = self.post_layernorm(last_hidden_state)
952
+
953
+ return BaseModelOutputWithPooling(
954
+ last_hidden_state=last_hidden_state,
955
+ pooler_output=None,
956
+ hidden_states=encoder_outputs.hidden_states,
957
+ attentions=encoder_outputs.attentions,
958
+ )
959
+
960
+
961
+ @dataclass
962
+ @auto_docstring(
963
+ custom_intro="""
964
+ Base class for PaddleOCRVL model outputs, with hidden states and attentions.
965
+ """
966
+ )
967
+ class PaddleOCRVLModelOutputWithPast(ModelOutput):
968
+ r"""
969
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
970
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
971
+
972
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
973
+ `past_key_values` input) to speed up sequential decoding.
974
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
975
+ The rope index difference between sequence length and multimodal rope.
976
+ """
977
+
978
+ last_hidden_state: Optional[torch.FloatTensor] = None
979
+ past_key_values: Optional[Cache] = None
980
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
981
+ attentions: Optional[tuple[torch.FloatTensor]] = None
982
+ rope_deltas: Optional[torch.LongTensor] = None
983
+
984
+
985
+ @dataclass
986
+ @auto_docstring(
987
+ custom_intro="""
988
+ Base class for PaddleOCRVL causal language model (or autoregressive) outputs.
989
+ """
990
+ )
991
+ class PaddleOCRVLCausalLMOutputWithPast(ModelOutput):
992
+ r"""
993
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
994
+ Language modeling loss (for next-token prediction).
995
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
996
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
997
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
998
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
999
+
1000
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1001
+ `past_key_values` input) to speed up sequential decoding.
1002
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1003
+ The rope index difference between sequence length and multimodal rope.
1004
+ """
1005
+
1006
+ loss: Optional[torch.FloatTensor] = None
1007
+ logits: Optional[torch.FloatTensor] = None
1008
+ past_key_values: Optional[Cache] = None
1009
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
1010
+ attentions: Optional[tuple[torch.FloatTensor]] = None
1011
+ rope_deltas: Optional[torch.LongTensor] = None
1012
+
1013
+
1014
+ @auto_docstring
1015
+ class PaddleOCRVLModel(PaddleOCRVLPreTrainedModel):
1016
+ base_model_prefix = "model"
1017
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
1018
+ # Reference: fix gemma3 grad acc #37208
1019
+ accepts_loss_kwargs = False
1020
+ _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
1021
+
1022
+ def __init__(self, config: PaddleOCRVLConfig):
1023
+ super().__init__(config)
1024
+ self.visual = PaddleOCRVisionModel._from_config(config.vision_config)
1025
+ self.language_model = PaddleOCRTextModel._from_config(config.text_config)
1026
+ self.rope_deltas = None
1027
+ self.projector = PaddleOCRProjector(config)
1028
+
1029
+ # Initialize weights and apply final processing
1030
+ self.post_init()
1031
+
1032
+ def get_input_embeddings(self):
1033
+ return self.language_model.embed_tokens
1034
+
1035
+ def set_input_embeddings(self, value):
1036
+ self.language_model.embed_tokens = value
1037
+
1038
+ def get_rope_index(
1039
+ self,
1040
+ input_ids: Optional[torch.LongTensor] = None,
1041
+ image_grid_thw: Optional[torch.LongTensor] = None,
1042
+ video_grid_thw: Optional[torch.LongTensor] = None,
1043
+ attention_mask: Optional[torch.Tensor] = None,
1044
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1045
+ """
1046
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
1047
+
1048
+ Explanation:
1049
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
1050
+
1051
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
1052
+ Examples:
1053
+ input_ids: [T T T T T], here T is for text.
1054
+ temporal position_ids: [0, 1, 2, 3, 4]
1055
+ height position_ids: [0, 1, 2, 3, 4]
1056
+ width position_ids: [0, 1, 2, 3, 4]
1057
+
1058
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
1059
+ and 1D rotary position embedding for text part.
1060
+ Examples:
1061
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
1062
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
1063
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
1064
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
1065
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
1066
+ text temporal position_ids: [3, 4, 5, 6, 7]
1067
+ text height position_ids: [3, 4, 5, 6, 7]
1068
+ text width position_ids: [3, 4, 5, 6, 7]
1069
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
1070
+
1071
+ Args:
1072
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1073
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1074
+ it.
1075
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1076
+ The temporal, height and width of feature shape of each image in LLM.
1077
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1078
+ The temporal, height and width of feature shape of each video in LLM.
1079
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1080
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1081
+
1082
+ - 1 for tokens that are **not masked**,
1083
+ - 0 for tokens that are **masked**.
1084
+
1085
+ Returns:
1086
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
1087
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
1088
+ """
1089
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
1090
+ image_token_id = self.config.image_token_id
1091
+ video_token_id = self.config.video_token_id
1092
+ vision_start_token_id = self.config.vision_start_token_id
1093
+ mrope_position_deltas = []
1094
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
1095
+ total_input_ids = input_ids
1096
+ if attention_mask is None:
1097
+ attention_mask = torch.ones_like(total_input_ids)
1098
+ position_ids = torch.ones(
1099
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1100
+ )
1101
+ image_index, video_index = 0, 0
1102
+ for i, input_ids in enumerate(total_input_ids):
1103
+ input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
1104
+ image_nums, video_nums = 0, 0
1105
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
1106
+ vision_tokens = input_ids[vision_start_indices + 1]
1107
+ image_nums = (vision_tokens == image_token_id).sum()
1108
+ video_nums = (vision_tokens == video_token_id).sum()
1109
+ input_tokens = input_ids.tolist()
1110
+ llm_pos_ids_list: list = []
1111
+ st = 0
1112
+ remain_images, remain_videos = image_nums, video_nums
1113
+ for _ in range(image_nums + video_nums):
1114
+ if image_token_id in input_tokens and remain_images > 0:
1115
+ ed_image = input_tokens.index(image_token_id, st)
1116
+ else:
1117
+ ed_image = len(input_tokens) + 1
1118
+ if video_token_id in input_tokens and remain_videos > 0:
1119
+ ed_video = input_tokens.index(video_token_id, st)
1120
+ else:
1121
+ ed_video = len(input_tokens) + 1
1122
+ if ed_image < ed_video:
1123
+ t, h, w = (
1124
+ image_grid_thw[image_index][0],
1125
+ image_grid_thw[image_index][1],
1126
+ image_grid_thw[image_index][2],
1127
+ )
1128
+ image_index += 1
1129
+ remain_images -= 1
1130
+ ed = ed_image
1131
+ else:
1132
+ t, h, w = (
1133
+ video_grid_thw[video_index][0],
1134
+ video_grid_thw[video_index][1],
1135
+ video_grid_thw[video_index][2],
1136
+ )
1137
+ video_index += 1
1138
+ remain_videos -= 1
1139
+ ed = ed_video
1140
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1141
+ t.item(),
1142
+ h.item() // spatial_merge_size,
1143
+ w.item() // spatial_merge_size,
1144
+ )
1145
+ text_len = ed - st
1146
+
1147
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1148
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1149
+
1150
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1151
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1152
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1153
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1154
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1155
+
1156
+ if st < len(input_tokens):
1157
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1158
+ text_len = len(input_tokens) - st
1159
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1160
+
1161
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1162
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1163
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
1164
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1165
+ return position_ids, mrope_position_deltas
1166
+ else:
1167
+ if attention_mask is not None:
1168
+ position_ids = attention_mask.long().cumsum(-1) - 1
1169
+ position_ids.masked_fill_(attention_mask == 0, 1)
1170
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1171
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1172
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1173
+ else:
1174
+ position_ids = (
1175
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1176
+ .view(1, 1, -1)
1177
+ .expand(3, input_ids.shape[0], -1)
1178
+ )
1179
+ mrope_position_deltas = torch.zeros(
1180
+ [input_ids.shape[0], 1],
1181
+ device=input_ids.device,
1182
+ dtype=input_ids.dtype,
1183
+ )
1184
+
1185
+ return position_ids, mrope_position_deltas
1186
+
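To make the docstring's worked example concrete, here is a hedged, self-contained sketch of the 3D position ids for one vision block followed by text. It mirrors the index construction in `get_rope_index` for a single sample, using a toy grid and no spatial merging.

```python
import torch

# One image occupying a (t, h, w) = (1, 2, 2) patch grid, followed by 3 text tokens.
t, h, w = 1, 2, 2
text_len = 3

# Vision part: temporal / height / width indices for every patch.
t_idx = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
h_idx = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
w_idx = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
vision_pos = torch.stack([t_idx, h_idx, w_idx])

# Text part: plain 1D positions starting after the largest vision position.
start = vision_pos.max() + 1
text_pos = (torch.arange(text_len) + start).view(1, -1).expand(3, -1)

position_ids = torch.cat([vision_pos, text_pos], dim=1)
print(position_ids)
# tensor([[0, 0, 0, 0, 2, 3, 4],
#         [0, 0, 1, 1, 2, 3, 4],
#         [0, 1, 0, 1, 2, 3, 4]])
```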
1187
+ def get_video_features(
1188
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
1189
+ ):
1190
+ """
1191
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
1192
+
1193
+ Args:
1194
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1195
+ The tensors corresponding to the input videos.
1196
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1197
+ The temporal, height and width of feature shape of each video in LLM.
1198
+ """
1199
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
1200
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
1201
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1202
+ video_embeds = torch.split(video_embeds, split_sizes)
1203
+ return video_embeds
1204
+
1205
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
1206
+ """
1207
+ Encodes images into continuous embeddings that can be forwarded to the language model.
1208
+
1209
+ Args:
1210
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1211
+ The tensors corresponding to the input images.
1212
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1213
+ The temporal, height and width of feature shape of each image in LLM.
1214
+ """
1215
+ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0)
1216
+ cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
1217
+ dim=0,
1218
+ # Select dtype based on the following factors:
1219
+ # - FA2 requires that cu_seqlens_q must have dtype int32
1220
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
1221
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
1222
+ dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
1223
+ )
1224
+ cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
1225
+ vision_outputs = self.visual(
1226
+ pixel_values=pixel_values,
1227
+ image_grid_thw=image_grid_thw,
1228
+ cu_seqlens=cu_seqlens,
1229
+ )
1230
+ image_embeds = vision_outputs.last_hidden_state
1231
+ image_embeds = self.projector(image_embeds, image_grid_thw)
1232
+ return image_embeds
1233
+
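The cumulative-sequence-length tensor built in `get_image_features` can be checked in isolation. A small sketch with made-up grids, no spatial merging, and the int32 dtype that flash attention expects:

```python
import torch

# Two images with (t, h, w) patch grids of (1, 4, 6) and (1, 2, 2).
image_grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])

# One entry of h * w per temporal slice, accumulated and left-padded with 0.
cu_seqlens = torch.repeat_interleave(
    image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 28], dtype=torch.int32)
```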
1234
+ def get_placeholder_mask(
1235
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
1236
+ ):
1237
+ """
1238
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1239
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
1240
+ """
1241
+ if input_ids is None:
1242
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
1243
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
1244
+ )
1245
+ special_image_mask = special_image_mask.all(-1)
1246
+ else:
1247
+ special_image_mask = input_ids == self.config.image_token_id
1248
+
1249
+ n_image_tokens = special_image_mask.sum()
1250
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1251
+ n_image_features = image_features.shape[0] * image_features.shape[1]
1252
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
1253
+ raise ValueError(
1254
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
1255
+ )
1256
+ return special_image_mask
1257
+
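The mask returned above is consumed by `masked_scatter` in the model forward. A toy demonstration of that scatter, using a hypothetical placeholder id and a 2-dimensional hidden size:

```python
import torch

IMAGE_TOKEN_ID = 99  # hypothetical placeholder id, not the model's real one
input_ids = torch.tensor([[1, 99, 99, 2]])
inputs_embeds = torch.zeros(1, 4, 2)
image_features = torch.tensor([[1.0, 1.0], [2.0, 2.0]])  # one row per image token

# Expand the token-level mask to the hidden dimension, then scatter the features in order.
mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(mask, image_features)
print(merged)  # positions 1 and 2 now hold the image features, the rest stay zero
```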
1258
+ @can_return_tuple
1259
+ def forward(
1260
+ self,
1261
+ input_ids: torch.LongTensor = None,
1262
+ attention_mask: Optional[torch.Tensor] = None,
1263
+ position_ids: Optional[torch.LongTensor] = None,
1264
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
1265
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1266
+ use_cache: Optional[bool] = None,
1267
+ pixel_values: Optional[torch.Tensor] = None,
1268
+ image_grid_thw: Optional[torch.LongTensor] = None,
1269
+ rope_deltas: Optional[torch.LongTensor] = None,
1270
+ cache_position: Optional[torch.LongTensor] = None,
1271
+ **kwargs,
1272
+ ) -> Union[tuple, PaddleOCRVLModelOutputWithPast]:
1273
+ r"""
1274
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1275
+ The temporal, height and width of feature shape of each image in LLM.
1276
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1277
+ The rope index difference between sequence length and multimodal rope.
1278
+ """
1279
+ if inputs_embeds is None:
1280
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
1281
+
1282
+ if pixel_values is not None:
1283
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw).to(
1284
+ inputs_embeds.device, inputs_embeds.dtype
1285
+ )
1286
+ image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
1287
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1288
+
1289
+ if position_ids is None:
1290
+ past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1291
+ if self.rope_deltas is None or past_key_values_length == 0:
1292
+ position_ids, rope_deltas = self.get_rope_index(
1293
+ input_ids=input_ids,
1294
+ image_grid_thw=image_grid_thw,
1295
+ attention_mask=attention_mask,
1296
+ )
1297
+ self.rope_deltas = rope_deltas
1298
+ # otherwise, reuse the previously computed rope_deltas to get the correct position ids
1299
+ else:
1300
+ batch_size, seq_length, _ = inputs_embeds.shape
1301
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1302
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1303
+ delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
1304
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1305
+ position_ids = position_ids + delta.to(position_ids.device)
1306
+
1307
+ outputs = self.language_model(
1308
+ input_ids=None,
1309
+ position_ids=position_ids,
1310
+ attention_mask=attention_mask,
1311
+ past_key_values=past_key_values,
1312
+ inputs_embeds=inputs_embeds,
1313
+ use_cache=use_cache,
1314
+ cache_position=cache_position,
1315
+ **kwargs,
1316
+ )
1317
+
1318
+ output = PaddleOCRVLModelOutputWithPast(
1319
+ last_hidden_state=outputs.last_hidden_state,
1320
+ past_key_values=outputs.past_key_values,
1321
+ hidden_states=outputs.hidden_states,
1322
+ attentions=outputs.attentions,
1323
+ rope_deltas=self.rope_deltas,
1324
+ )
1325
+
1326
+ return output
1327
+
1328
+
1329
+ class PaddleOCRVLForConditionalGeneration(PaddleOCRVLPreTrainedModel, GenerationMixin):
1330
+ _checkpoint_conversion_mapping = {
1331
+ "^visual": "model.visual",
1332
+ "^mlp_AR": "model.projector",
1333
+ r"^model(?!(\.visual|\.projector|\.language_model))": "model.language_model",
1334
+ }
1335
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
1336
+ _keys_to_ignore_on_load_unexpected = ["packing_position_embedding", "vision_model.head"]
1337
+
1338
+ def __init__(self, config):
1339
+ super().__init__(config)
1340
+ self.model = PaddleOCRVLModel(config)
1341
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1342
+
1343
+ self.post_init()
1344
+
1345
+ def get_input_embeddings(self):
1346
+ return self.model.get_input_embeddings()
1347
+
1348
+ def set_input_embeddings(self, value):
1349
+ self.model.set_input_embeddings(value)
1350
+
1351
+ def get_video_features(
1352
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
1353
+ ):
1354
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
1355
+
1356
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
1357
+ return self.model.get_image_features(pixel_values, image_grid_thw)
1358
+
1359
+ @can_return_tuple
1360
+ @auto_docstring
1361
+ def forward(
1362
+ self,
1363
+ input_ids: Optional[torch.LongTensor] = None,
1364
+ attention_mask: Optional[torch.Tensor] = None,
1365
+ position_ids: Optional[torch.LongTensor] = None,
1366
+ past_key_values: Optional[Cache] = None,
1367
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1368
+ labels: Optional[torch.LongTensor] = None,
1369
+ use_cache: Optional[bool] = None,
1370
+ pixel_values: Optional[torch.Tensor] = None,
1371
+ image_grid_thw: Optional[torch.LongTensor] = None,
1372
+ rope_deltas: Optional[torch.LongTensor] = None,
1373
+ cache_position: Optional[torch.LongTensor] = None,
1374
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1375
+ **kwargs: Unpack[TransformersKwargs],
1376
+ ) -> Union[tuple, PaddleOCRVLCausalLMOutputWithPast]:
1377
+ r"""
1378
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1379
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1380
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1381
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1382
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1383
+ The temporal, height and width of feature shape of each image in LLM.
1384
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1385
+ The rope index difference between sequence length and multimodal rope.
1386
+
1387
+ Example:
1388
+
1389
+ ```python
1390
+ >>> from transformers import AutoProcessor, PaddleOCRVLForConditionalGeneration
1391
+
1392
+ >>> model = PaddleOCRVLForConditionalGeneration.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
1393
+ >>> processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
1394
+
1395
+ >>> messages = [
1396
+ {
1397
+ "role": "user",
1398
+ "content": [
1399
+ {
1400
+ "type": "image",
1401
+ "image": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg",
1402
+ },
1403
+ {"type": "text", "text": "OCR:"},
1404
+ ],
1405
+ }
1406
+ ]
1407
+
1408
+ >>> inputs = processor.apply_chat_template(
1409
+ messages,
1410
+ tokenize=True,
1411
+ add_generation_prompt=True,
1412
+ return_dict=True,
1413
+ return_tensors="pt"
1414
+ ).to(model.device)
1415
+
1416
+ >>> # Generate
1417
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
1418
+ >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
1419
+ >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1420
+ >>> print(output_text)
1421
+ ```
1422
+ """
1423
+ outputs: PaddleOCRVLModelOutputWithPast = self.model(
1424
+ input_ids=input_ids,
1425
+ attention_mask=attention_mask,
1426
+ position_ids=position_ids,
1427
+ image_grid_thw=image_grid_thw,
1428
+ past_key_values=past_key_values,
1429
+ inputs_embeds=inputs_embeds,
1430
+ use_cache=use_cache,
1431
+ pixel_values=pixel_values,
1432
+ rope_deltas=rope_deltas,
1433
+ cache_position=cache_position,
1434
+ **kwargs,
1435
+ )
1436
+ hidden_states = outputs.last_hidden_state
1437
+
1438
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1439
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1440
+
1441
+ loss = None
1442
+ if labels is not None:
1443
+ loss = self.loss_function(
1444
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
1445
+ )
1446
+
1447
+ return PaddleOCRVLCausalLMOutputWithPast(
1448
+ loss=loss,
1449
+ logits=logits,
1450
+ past_key_values=outputs.past_key_values,
1451
+ hidden_states=outputs.hidden_states,
1452
+ attentions=outputs.attentions,
1453
+ rope_deltas=outputs.rope_deltas,
1454
+ )
1455
+
1456
+ def prepare_inputs_for_generation(
1457
+ self,
1458
+ input_ids,
1459
+ past_key_values=None,
1460
+ attention_mask=None,
1461
+ inputs_embeds=None,
1462
+ cache_position=None,
1463
+ position_ids=None,
1464
+ use_cache=True,
1465
+ pixel_values=None,
1466
+ pixel_values_videos=None,
1467
+ image_grid_thw=None,
1468
+ video_grid_thw=None,
1469
+ **kwargs,
1470
+ ):
1471
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1472
+
1473
+ model_inputs = super().prepare_inputs_for_generation(
1474
+ input_ids,
1475
+ past_key_values=past_key_values,
1476
+ attention_mask=attention_mask,
1477
+ inputs_embeds=inputs_embeds,
1478
+ cache_position=cache_position,
1479
+ position_ids=position_ids,
1480
+ pixel_values=pixel_values,
1481
+ pixel_values_videos=pixel_values_videos,
1482
+ image_grid_thw=image_grid_thw,
1483
+ video_grid_thw=video_grid_thw,
1484
+ use_cache=use_cache,
1485
+ **kwargs,
1486
+ )
1487
+
1488
+ # Qwen2-VL position_ids are prepared with rope_deltas in forward
1489
+ if position_ids is None:
1490
+ # Calculate RoPE index once per generation in the pre-fill stage only.
1491
+ # When compiling, we can't check tensor values thus we check only input length
1492
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
1493
+ # models currently cannot do assisted decoding
1494
+ if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
1495
+ vision_positions, rope_deltas = self.model.get_rope_index(
1496
+ model_inputs.get("input_ids", None),
1497
+ image_grid_thw=image_grid_thw,
1498
+ video_grid_thw=video_grid_thw,
1499
+ attention_mask=attention_mask,
1500
+ )
1501
+ self.model.rope_deltas = rope_deltas
1502
+ # otherwise, reuse the previously computed rope_deltas to get the correct position ids
1503
+ elif "position_ids" in model_inputs:
1504
+ batch_size, seq_length = model_inputs["position_ids"].shape
1505
+ device = model_inputs["position_ids"].device
1506
+ position_ids = torch.arange(seq_length, device=device)
1507
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1508
+ delta = cache_position[0] + self.model.rope_deltas
1509
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1510
+ vision_positions = position_ids + delta.expand_as(position_ids)
1511
+
1512
+ # Concatenate "text + vision" positions into [4, bs, seq-len]
1513
+ text_positions = model_inputs["position_ids"][None, ...]
1514
+ model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
1515
+
1516
+ if model_inputs["cache_position"][0] != 0:
1517
+ model_inputs["pixel_values"] = None
1518
+ model_inputs["pixel_values_videos"] = None
1519
+
1520
+ return model_inputs
1521
+
1522
+ def _get_image_nums_and_video_nums(
1523
+ self,
1524
+ input_ids: Optional[torch.LongTensor],
1525
+ inputs_embeds: Optional[torch.Tensor] = None,
1526
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1527
+ """
1528
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
1529
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
1530
+
1531
+ Args:
1532
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1533
+ Indices of input sequence tokens in the vocabulary.
1534
+
1535
+ Returns:
1536
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
1537
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
1538
+ """
1539
+ image_token_id = self.config.image_token_id
1540
+ video_token_id = self.config.video_token_id
1541
+ vision_start_token_id = self.config.vision_start_token_id
1542
+
1543
+ if inputs_embeds is not None:
1544
+ vision_start_mask = (
1545
+ inputs_embeds
1546
+ == self.get_input_embeddings()(
1547
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
1548
+ )
1549
+ )[..., 0]
1550
+ image_mask = (
1551
+ inputs_embeds
1552
+ == self.get_input_embeddings()(
1553
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
1554
+ )
1555
+ )[..., 0]
1556
+ video_mask = (
1557
+ inputs_embeds
1558
+ == self.get_input_embeddings()(
1559
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
1560
+ )
1561
+ )[..., 0]
1562
+ else:
1563
+ vision_start_mask = input_ids == vision_start_token_id
1564
+ image_mask = input_ids == image_token_id
1565
+ video_mask = input_ids == video_token_id
1566
+
1567
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
1568
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
1569
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
1570
+
1571
+ return image_nums, video_nums
1572
+
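A hedged toy check of the counting trick above: a vision-start token is immediately followed by either an image or a video placeholder, so rolling the start mask right by one aligns it with the placeholder it introduces. All token ids below are made up.

```python
import torch

VISION_START, IMAGE, VIDEO = 10, 11, 12  # hypothetical special token ids
input_ids = torch.tensor([[1, 10, 11, 11, 2, 10, 12, 12, 3]])

vision_start_mask = input_ids == VISION_START
image_mask = input_ids == IMAGE
video_mask = input_ids == VIDEO

# Shift the start mask right by one so it lines up with the first placeholder token.
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
print(torch.sum(vision_first_mask & image_mask, dim=1))  # tensor([1])
print(torch.sum(vision_first_mask & video_mask, dim=1))  # tensor([1])
```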
1573
+ def _expand_inputs_for_generation(
1574
+ self,
1575
+ expand_size: int = 1,
1576
+ is_encoder_decoder: bool = False,
1577
+ input_ids: Optional[torch.LongTensor] = None,
1578
+ **model_kwargs,
1579
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
1580
+ # Overwritten -- Support for expanding tensors without a batch size dimension
1581
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
1582
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
1583
+ # image_grid_thw.shape[0] is sum(num_images for samples)
1584
+
1585
+ if expand_size == 1:
1586
+ return input_ids, model_kwargs
1587
+
1588
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
1589
+
1590
+ def _expand_dict_for_generation_visual(dict_to_expand):
1591
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
1592
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
1593
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
1594
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
1595
+ )
1596
+
1597
+ def _repeat_interleave_samples(x, lengths, repeat_times):
1598
+ samples = torch.split(x, lengths)
1599
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1600
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1601
+ return result
1602
+
1603
+ for key in dict_to_expand:
1604
+ if key == "pixel_values":
1605
+ # split images into samples
1606
+ samples = torch.split(image_grid_thw, list(image_nums))
1607
+ # compute the sequence length of images for each sample
1608
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1609
+ dict_to_expand[key] = _repeat_interleave_samples(
1610
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1611
+ )
1612
+ elif key == "image_grid_thw":
1613
+ # get the num of images for each sample
1614
+ lengths = list(image_nums)
1615
+ dict_to_expand[key] = _repeat_interleave_samples(
1616
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1617
+ )
1618
+ elif key == "pixel_values_videos":
1619
+ samples = torch.split(video_grid_thw, list(video_nums))
1620
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1621
+ dict_to_expand[key] = _repeat_interleave_samples(
1622
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1623
+ )
1624
+ elif key == "video_grid_thw":
1625
+ lengths = list(video_nums)
1626
+ dict_to_expand[key] = _repeat_interleave_samples(
1627
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1628
+ )
1629
+ elif key == "second_per_grid_ts":
1630
+ dict_to_expand[key] = _repeat_interleave_samples(
1631
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1632
+ )
1633
+ return dict_to_expand
1634
+
1635
+ def _expand_dict_for_generation(dict_to_expand):
1636
+ for key in dict_to_expand:
1637
+ if (
1638
+ key != "cache_position"
1639
+ and dict_to_expand[key] is not None
1640
+ and isinstance(dict_to_expand[key], torch.Tensor)
1641
+ and key not in visual_keys
1642
+ ):
1643
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1644
+ return dict_to_expand
1645
+
1646
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1647
+
1648
+ if input_ids is not None:
1649
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1650
+
1651
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
1652
+
1653
+ if is_encoder_decoder:
1654
+ if model_kwargs.get("encoder_outputs") is None:
1655
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1656
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1657
+
1658
+ return input_ids, model_kwargs
1659
+
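Finally, the per-sample expansion used for pixel tensors that lack a batch dimension can be demonstrated on its own. The lengths below are invented, and the helper simply mirrors the nested `_repeat_interleave_samples` above.

```python
import torch

def repeat_interleave_samples(x, lengths, repeat_times):
    # Split the packed tensor into per-sample chunks, repeat each chunk,
    # and re-pack, so duplicates stay adjacent to their original sample.
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

# image_grid_thw-like tensor for two samples holding 1 and 2 images respectively.
grid_thw = torch.tensor([[1, 2, 2], [1, 4, 4], [1, 2, 6]])
expanded = repeat_interleave_samples(grid_thw, lengths=[1, 2], repeat_times=2)
print(expanded.shape)  # torch.Size([6, 3])
```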
1660
+
1661
+ __all__ = [
1662
+ "PaddleOCRVLForConditionalGeneration",
1663
+ "PaddleOCRVLModel",
1664
+ "PaddleOCRVLPreTrainedModel",
1665
+ "PaddleOCRVisionTransformer",
1666
+ "PaddleOCRTextModel",
1667
+ "PaddleOCRVisionModel",
1668
+ ]