transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
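
Among the additions in rc2 is the new `transformers/models/ernie4_5_vl_moe/` model family; the hunk below shows its new `video_processing_ernie4_5_vl_moe.py` (+594 lines). As a quick orientation, here is a minimal sketch of how the new video processor might be loaded and called. The checkpoint id is a placeholder, and `do_sample_frames`/`draw_on_frames` are turned off because plain frame arrays carry no fps metadata; verify the exact call signature against the release.

```python
# Minimal sketch, not an official example: the repo id below is hypothetical,
# and the accepted kwargs should be checked against the rc2 release.
import numpy as np
from transformers import AutoVideoProcessor

# Hypothetical repo id; substitute a real Ernie 4.5 VL MoE checkpoint.
processor = AutoVideoProcessor.from_pretrained("your-org/ernie-4.5-vl-moe")

# Eight dummy RGB frames standing in for a decoded video.
video = [np.zeros((448, 448, 3), dtype=np.uint8) for _ in range(8)]

# Disable frame sampling and timestamp drawing, since plain arrays have no
# fps/duration metadata (see `_prepare_input_videos` in the diff below).
outputs = processor(
    videos=[video],
    do_sample_frames=False,
    draw_on_frames=False,
    return_tensors="pt",
)
print(outputs["pixel_values_videos"].shape)  # flattened video patches
print(outputs["video_grid_thw"])             # temporal/height/width grid
```

Per the diff, the processor's `model_input_names` are `pixel_values_videos` and `video_grid_thw`, and it enforces `temporal_patch_size=2`.
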
@@ -0,0 +1,594 @@
+ # coding=utf-8
+ # Copyright 2025 Baidu and HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os.path
+ from functools import partial
+ from pathlib import Path
+ from shutil import SameFileError, copyfile
+ from typing import Any, Optional, Union
+
+ import numpy as np
+ import torch
+ from huggingface_hub import is_offline_mode
+ from huggingface_hub.dataclasses import validate_typed_dict
+ from PIL import ImageDraw, ImageFont
+ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
+
+ from ...image_processing_utils import BatchFeature
+ from ...image_utils import (
+     OPENAI_CLIP_MEAN,
+     OPENAI_CLIP_STD,
+     ChannelDimension,
+     PILImageResampling,
+     SizeDict,
+     get_image_size,
+     validate_kwargs,
+ )
+ from ...processing_utils import Unpack, VideosKwargs
+ from ...utils import (
+     IMAGE_PROCESSOR_NAME,
+     PROCESSOR_NAME,
+     VIDEO_PROCESSOR_NAME,
+     TensorType,
+     add_start_docstrings,
+     logging,
+     safe_load_json_file,
+ )
+ from ...utils.hub import cached_file
+ from ...utils.import_utils import is_tracing, requires
+ from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
+ from ...video_utils import (
+     VideoInput,
+     VideoMetadata,
+     group_videos_by_shape,
+     infer_channel_dimension_format,
+     reorder_videos,
+ )
+ from .image_processing_ernie4_5_vl_moe import smart_resize
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Ernie4_5_VL_MoeVideoProcessorInitKwargs(VideosKwargs, total=False):
+     patch_size: int
+     temporal_patch_size: int
+     merge_size: int
+     min_frames: int
+     max_frames: int
+     draw_on_frames: bool
+     font: str
+
+
+ @add_start_docstrings(
+     "Constructs a fast Ernie 4.5 VL video processor that dynamically resizes videos based on the original videos.",
+     BASE_VIDEO_PROCESSOR_DOCSTRING,
+     """
+         patch_size (`int`, *optional*, defaults to 14):
+             The spatial patch size of the vision encoder.
+         temporal_patch_size (`int`, *optional*, defaults to 2):
+             The temporal patch size of the vision encoder.
+         merge_size (`int`, *optional*, defaults to 2):
+             The merge size between the vision encoder and the LLM encoder.
+         min_frames (`int`, *optional*, defaults to 16):
+             The minimum number of frames that can be sampled.
+         max_frames (`int`, *optional*, defaults to 180):
+             The maximum number of frames that can be sampled.
+         draw_on_frames (`bool`, *optional*, defaults to `True`):
+             Whether to draw timestamps on each frame or not.
+             This does not work with `torch.compile` but reproduces
+             the behavior of the original model.
+         font (`str`, *optional*, defaults to "Roboto-Regular.ttf"):
+             The associated font name for drawing on frames.
+             Defaults to "Roboto-Regular.ttf" and is expected to be
+             saved alongside the processor as a separate file.
+     """,
+ )
+ @requires(backends=("torchvision",))
+ class Ernie4_5_VL_MoeVideoProcessor(BaseVideoProcessor):
+     resample = PILImageResampling.BICUBIC
+     size = {"shortest_edge": 299 * 28 * 28, "longest_edge": 1196 * 28 * 28}
+     image_mean = OPENAI_CLIP_MEAN
+     image_std = OPENAI_CLIP_STD
+     do_resize = True
+     do_rescale = True
+     do_normalize = True
+     do_convert_rgb = True
+     patch_size = 14
+     temporal_patch_size = 2
+     merge_size = 2
+     min_frames = 16
+     max_frames = 180
+     do_sample_frames = True
+     draw_on_frames = True
+     font = "Roboto-Regular.ttf"
+     valid_kwargs = Ernie4_5_VL_MoeVideoProcessorInitKwargs
+     model_input_names = ["pixel_values_videos", "video_grid_thw"]
+
+     def __init__(self, **kwargs: Unpack[Ernie4_5_VL_MoeVideoProcessorInitKwargs]):
+         temporal_patch_size = kwargs.get("temporal_patch_size", 2)
+         if temporal_patch_size is None or temporal_patch_size != 2:
+             raise ValueError("`Ernie 4.5 VL` only supports a temporal patch size of 2")
+
+         size = kwargs.pop("size", None)
+         size = self.size if size is None else size
+         if "shortest_edge" not in size or "longest_edge" not in size:
+             raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+
+         super().__init__(size=size, **kwargs)
+
+     @classmethod
+     def get_video_processor_dict(
+         cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Overridden to additionally load the font for drawing on frames."""
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         proxies = kwargs.pop("proxies", None)
+         token = kwargs.pop("token", None)
+         local_files_only = kwargs.pop("local_files_only", False)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", "")
+
+         from_pipeline = kwargs.pop("_from_pipeline", None)
+         from_auto_class = kwargs.pop("_from_auto", False)
+
+         user_agent = {"file_type": "video processor", "from_auto_class": from_auto_class}
+         if from_pipeline is not None:
+             user_agent["using_pipeline"] = from_pipeline
+
+         if is_offline_mode() and not local_files_only:
+             logger.info("Offline mode: forcing local_files_only=True")
+             local_files_only = True
+
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+         is_local = os.path.isdir(pretrained_model_name_or_path)
+         if os.path.isfile(pretrained_model_name_or_path):
+             resolved_video_processor_file = pretrained_model_name_or_path
+             resolved_processor_file = None
+             is_local = True
+         else:
+             video_processor_file = VIDEO_PROCESSOR_NAME
+             try:
+                 # Try to load with the new config name first and, if not successful, try the old file name
+                 # NOTE: we save all processor configs as a nested dict in PROCESSOR_NAME from v5, which is the standard
+                 resolved_processor_file = cached_file(
+                     pretrained_model_name_or_path,
+                     filename=PROCESSOR_NAME,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     local_files_only=local_files_only,
+                     token=token,
+                     user_agent=user_agent,
+                     revision=revision,
+                     subfolder=subfolder,
+                     _raise_exceptions_for_missing_entries=False,
+                 )
+                 resolved_video_processor_files = [
+                     resolved_file
+                     for filename in [video_processor_file, IMAGE_PROCESSOR_NAME]
+                     if (
+                         resolved_file := cached_file(
+                             pretrained_model_name_or_path,
+                             filename=filename,
+                             cache_dir=cache_dir,
+                             force_download=force_download,
+                             proxies=proxies,
+                             local_files_only=local_files_only,
+                             token=token,
+                             user_agent=user_agent,
+                             revision=revision,
+                             subfolder=subfolder,
+                             _raise_exceptions_for_missing_entries=False,
+                         )
+                     )
+                     is not None
+                 ]
+                 resolved_video_processor_file = (
+                     resolved_video_processor_files[0] if resolved_video_processor_files else None
+                 )
+             except OSError:
+                 # Re-raise any OS error raised by `cached_file`. It will have a helpful error message adapted to
+                 # the original exception.
+                 raise
+             except Exception:
+                 # For any other exception, we throw a generic error.
+                 raise OSError(
+                     f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+                     " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                     f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                     f" directory containing a {video_processor_file} file"
+                 )
+
+         # Load the video processor dict. Priority goes as (nested config if found -> video processor config -> image processor config)
+         # We are downloading both configs because almost all models have a `processor_config.json`, but
+         # not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
+         video_processor_dict = None
+         if resolved_processor_file is not None:
+             processor_dict = safe_load_json_file(resolved_processor_file)
+             if "video_processor" in processor_dict:
+                 video_processor_dict = processor_dict["video_processor"]
+
+         if resolved_video_processor_file is not None and video_processor_dict is None:
+             video_processor_dict = safe_load_json_file(resolved_video_processor_file)
+
+         if video_processor_dict is None:
+             raise OSError(
+                 f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+                 " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                 f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                 f" directory containing a {video_processor_file} file"
+             )
+
+         # Specific to Ernie 4.5 VL Moe, we load the font file along with the json (if we draw on frames)
+         draws_on_frames = video_processor_dict.get("draw_on_frames")
+         if (font_name := video_processor_dict.get("font")) is None and draws_on_frames:
+             raise AttributeError(
+                 "Expected a `font` to be saved when using `draw_on_frames` in Ernie 4.5 VL Moe; found nothing."
+             )
+         if font_name is not None and draws_on_frames:
+             video_processor_dict["font"] = cached_file(
+                 pretrained_model_name_or_path,
+                 filename=font_name,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 local_files_only=local_files_only,
+                 token=token,
+                 user_agent=user_agent,
+                 revision=revision,
+                 subfolder=subfolder,
+                 _raise_exceptions_for_missing_entries=False,
+             )
+             try:
+                 ImageFont.truetype(video_processor_dict["font"])
+             except (TypeError, OSError):
+                 raise OSError(
+                     f"Could not find an associated font file for {video_processor_dict['font']}. "
+                     "Make sure to save a font file along with the processor for Ernie 4.5 VL Moe."
+                 )
+
+         if is_local:
+             logger.info(f"loading configuration file {resolved_video_processor_file}")
+         else:
+             logger.info(
+                 f"loading configuration file {video_processor_file} from cache at {resolved_video_processor_file}"
+             )
+
+         return video_processor_dict, kwargs
+
+     def to_dict(self) -> dict[str, Any]:
+         """Overridden to strip the prefix of the full path for the font, e.g. `tmp/folder/font.ttf` -> `font.ttf`"""
+         output = super().to_dict()
+
+         if os.path.isfile(output.get("font")):
+             output["font"] = Path(output["font"]).name
+         elif output.get("draw_on_frames"):
+             raise ValueError(
+                 f"The video processor dict contains an invalid path to its font: {output['font']}. "
+                 "Please make sure it contains a valid path or disable `draw_on_frames`."
+             )
+
+         return output
+
+     def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+         """We additionally save a copy of the font to the `save_directory` (if we found a file there)"""
+         os.makedirs(save_directory, exist_ok=True)
+
+         if os.path.isfile(self.font):
+             try:
+                 copyfile(self.font, Path(save_directory, Path(self.font).name))
+             except SameFileError:  # already exists, which we allow (copy if needed)
+                 pass
+
+         return super().save_pretrained(save_directory, push_to_hub, **kwargs)
+
+     def _further_process_kwargs(
+         self,
+         size: Optional[SizeDict] = None,
+         **kwargs,
+     ) -> dict:
+         """
+         Update kwargs that need further processing before being validated.
+         Can be overridden by subclasses to customize the processing of kwargs.
+         """
+         if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+             raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+
+         return super()._further_process_kwargs(size=size, **kwargs)
+
+    def sample_frames(
+        self,
+        metadata: VideoMetadata,
+        min_frames: Optional[int] = None,
+        max_frames: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        fps: Optional[Union[int, float]] = None,
+        **kwargs,
+    ):
+        if fps is not None and num_frames is not None:
+            raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
+
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        min_frames = min_frames if min_frames is not None else self.min_frames
+        max_frames = max_frames if max_frames is not None else self.max_frames
+        total_num_frames = metadata.total_num_frames
+
+        if num_frames is not None:
+            if num_frames < min_frames or num_frames > max_frames:
+                raise ValueError(f"`num_frames` must be {min_frames} <= x <= {max_frames}. Got {num_frames} instead.")
+        else:
+            if fps is not None and (metadata is None or metadata.fps is None):
+                raise ValueError(
+                    "Asked to sample `fps` frames per second but no video metadata was provided, which is required when sampling with `fps`. "
+                    "Please pass in a `VideoMetadata` object or use a fixed `num_frames` per input video"
+                )
+            num_frames = total_num_frames / metadata.fps * fps if fps is not None else total_num_frames
+            num_frames = min(max(num_frames, min_frames), max_frames, total_num_frames)
+
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
+                "Decrease `num_frames` or `fps` for sampling."
+            )
+
+        indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
+
+        return indices
+
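A worked example of the sampling arithmetic above, with illustrative values: a 300-frame clip recorded at 30 fps, sampled at 2 fps, yields 300 / 30 * 2 = 20 frames, evenly spaced by `torch.arange`:

import torch

total_num_frames, metadata_fps, fps = 300, 30, 2
num_frames = total_num_frames / metadata_fps * fps  # 20.0
indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
print(indices)  # tensor([  0,  15,  30, ..., 285], dtype=torch.int32) -- 20 evenly spaced frame indices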
+    def _convert_timestamp(self, time_stamp_in_seconds):
+        """Convert to `time: hr:min:sec` format"""
+        hours = time_stamp_in_seconds // 3600
+        time_stamp_in_seconds = time_stamp_in_seconds % 3600
+        mins = time_stamp_in_seconds // 60
+        time_stamp_in_seconds = time_stamp_in_seconds % 60
+        return f"time: {int(hours):02d}:{int(mins):02d}:{time_stamp_in_seconds:05.02f}"
+
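For instance, 3725.5 seconds decomposes into 1 hour, 2 minutes, 5.5 seconds:

seconds = 3725.5
hours, rem = divmod(seconds, 3600)  # 1.0, 125.5
mins, secs = divmod(rem, 60)        # 2.0, 5.5
print(f"time: {int(hours):02d}:{int(mins):02d}:{secs:05.02f}")  # time: 01:02:05.50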
+    def _render_image_with_timestamp(self, image: torch.Tensor, timestamp: str, size_factor: float = 0.1):
+        """Draws a black timestamp with a white border on the corner of the frame"""
+        if self.font is None:
+            raise AttributeError("To draw on frames with Ernie 4.5 VL, you need an associated font; found nothing")
+
+        # FIXME: the `torch -> PIL -> torch` conversion is inefficient, ~6ms per frame;
+        # left as an optimization opportunity if anyone wants to pick it up.
+        #
+        # This can take up to ~1s in preprocessing (if default sampling is used):
+        # 180 (frames) x 6ms = 1080ms = ~1.1s
+        image = to_pil_image(image)
+
+        font_size = int(min(*image.size) * size_factor)
+        outline_size = int(font_size * size_factor)
+        font = ImageFont.truetype(self.font, font_size)
+
+        # Draw black text with a white border
+        draw = ImageDraw.Draw(image)
+        draw.text(
+            (0, 0),
+            timestamp,
+            font=font,
+            fill=(0, 0, 0),
+            stroke_width=outline_size,
+            stroke_fill=(255, 255, 255),
+        )
+        return pil_to_tensor(image)
+
+    def _prepare_input_videos(
+        self,
+        videos: VideoInput,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        device: Optional[str] = None,
+        video_metadata: Optional[list[VideoMetadata]] = None,
+        draw_on_frames: bool = True,
+    ) -> list["torch.Tensor"]:
+        """
+        Prepare the input videos for processing.
+        """
+        processed_videos = []
+        for video, metadata in zip(videos, video_metadata):
+            # Check for attributes that are necessary to draw timestamps on frames
+            if draw_on_frames:
+                if metadata is None:
+                    raise ValueError("Need video metadata to process videos in Ernie 4.5 VL using `draw_on_frames`")
+                elif metadata.fps is None:
+                    metadata.fps = 24
+                    logger.warning_once(
+                        "Could not infer the fps of a video due to the metadata not being available, "
+                        "defaulting to `24`. Please provide `video_metadata` for more accurate results."
+                    )
+
+            # `make_batched_videos` always returns a 4D array per video
+            if isinstance(video, np.ndarray):
+                # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
+                video = torch.from_numpy(video).contiguous()
+
+            # Infer the channel dimension format if not provided
+            if input_data_format is None:
+                input_data_format = infer_channel_dimension_format(video)
+
+            if input_data_format == ChannelDimension.LAST:
+                video = video.permute(0, 3, 1, 2).contiguous()
+
+            # specific to Ernie, draws timestamps on each frame (if enabled)
+            if draw_on_frames:
+                if is_tracing(video):
+                    raise RuntimeError(
+                        "Using `torch.compile` is not compatible with drawing on frames. "
+                        "Either don't use `torch.compile` or don't draw on frames via the kwarg `draw_on_frames=False`."
+                    )
+
+                for idx, frame in enumerate(video):
+                    video[idx] = self._render_image_with_timestamp(
+                        frame, self._convert_timestamp(metadata.timestamps[idx])
+                    )
+
+            # the last frame is copied if the frame count is odd (mitigating issues with the temporal patch size)
+            if video.shape[0] % 2 != 0:
+                video = torch.cat((video, video[-1].detach().clone()[None, ...]), dim=0)
+
+            if device is not None:
+                video = video.to(device)
+
+            processed_videos.append(video)
+        return processed_videos
+
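The odd-frame padding can be seen in isolation (shapes are illustrative):

import torch

video = torch.randn(5, 3, 224, 224)  # 5 frames -- odd, so one frame gets duplicated
if video.shape[0] % 2 != 0:
    video = torch.cat((video, video[-1].detach().clone()[None, ...]), dim=0)
print(video.shape)  # torch.Size([6, 3, 224, 224]); the last two frames are identical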
+    def _preprocess(
+        self,
+        videos: list[torch.Tensor],
+        do_convert_rgb: bool = True,
+        do_resize: bool = True,
+        size: Optional[SizeDict] = None,
+        interpolation: PILImageResampling = PILImageResampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: float = 1 / 255.0,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        patch_size: Optional[int] = None,
+        merge_size: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ):
+        # Group videos by size for batched resizing
+        grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
+        resized_videos_grouped = {}
+        for shape, stacked_videos in grouped_videos.items():
+            if do_convert_rgb:
+                stacked_videos = self.convert_to_rgb(stacked_videos)
+
+            height, width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
+            resized_height, resized_width = height, width
+            if do_resize:
+                resized_height, resized_width = smart_resize(
+                    height,
+                    width,
+                    factor=patch_size * merge_size,
+                    min_pixels=size["shortest_edge"],
+                    max_pixels=size["longest_edge"],
+                )
+                stacked_videos = self.resize(
+                    image=stacked_videos,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
+                )
+            resized_videos_grouped[shape] = stacked_videos
+        resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
+
+        # Group videos by size for further processing
+        # Needed in case do_resize is False, or resize returns videos with different sizes
+        grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
+        processed_videos_grouped = {}
+        processed_grids = {}
+        for shape, stacked_videos in grouped_videos.items():
+            resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
+
+            # Fused rescale and normalize
+            stacked_videos = self.rescale_and_normalize(
+                stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            patches = stacked_videos
+
+            batch_size, grid_t, channel = patches.shape[:3]
+            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+
+            patches = patches.view(
+                batch_size,
+                grid_t,
+                channel,
+                grid_h // merge_size,
+                merge_size,
+                patch_size,
+                grid_w // merge_size,
+                merge_size,
+                patch_size,
+            )
+            # Reorder dimensions to group grid and patch information for subsequent flattening.
+            # [batch, grid_t, grid_h/merge, grid_w/merge, merge, merge, channel, patch, patch]
+            patches = patches.permute(0, 1, 3, 6, 4, 7, 2, 5, 8)
+
+            flatten_patches = patches.reshape(
+                batch_size,
+                grid_t * grid_h * grid_w,
+                channel * patch_size * patch_size,
+            )
+
+            processed_videos_grouped[shape] = flatten_patches
+            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+        processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
+        processed_grids = reorder_videos(processed_grids, grouped_videos_index)
+        pixel_values_videos = torch.cat(processed_videos, dim=0)
+        video_grid_thw = torch.tensor(processed_grids)
+
+        return BatchFeature(
+            data={"pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw},
+            tensor_type=return_tensors,
+        )
+
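To make the view/permute/reshape pipeline concrete, here is a standalone sketch with illustrative dimensions (patch_size=14, merge_size=2, a 2-frame 224x224 RGB video):

import torch

batch, grid_t, channel = 1, 2, 3
patch_size, merge_size = 14, 2
height = width = 224
grid_h, grid_w = height // patch_size, width // patch_size  # 16, 16

patches = torch.randn(batch, grid_t, channel, height, width)
patches = patches.view(
    batch, grid_t, channel,
    grid_h // merge_size, merge_size, patch_size,
    grid_w // merge_size, merge_size, patch_size,
)
patches = patches.permute(0, 1, 3, 6, 4, 7, 2, 5, 8)
flat = patches.reshape(batch, grid_t * grid_h * grid_w, channel * patch_size * patch_size)
print(flat.shape)  # torch.Size([1, 512, 588]) -- 2*16*16 patches of 3*14*14 values each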
+    @add_start_docstrings(
+        BASE_VIDEO_PROCESSOR_DOCSTRING,
+    )
+    def preprocess(
+        self,
+        videos: VideoInput,
+        **kwargs: Unpack[VideosKwargs],
+    ) -> BatchFeature:
+        validate_kwargs(
+            captured_kwargs=kwargs.keys(),
+            valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
+        )
+
+        # Perform type validation on received kwargs
+        validate_typed_dict(self.valid_kwargs, kwargs)
+
+        # Set default kwargs from self. This ensures that if a kwarg is not provided
+        # by the user, it gets its default value from the instance, or is set to None.
+        for kwarg_name in self.valid_kwargs.__annotations__:
+            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
+
+        input_data_format = kwargs.pop("input_data_format")
+        do_sample_frames = kwargs.pop("do_sample_frames")
+        device = kwargs.pop("device")
+        video_metadata = kwargs.pop("video_metadata")
+        draw_on_frames = kwargs.pop("draw_on_frames")
+
+        sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
+        videos, video_metadata = self._decode_and_sample_videos(
+            videos,
+            video_metadata=video_metadata,
+            do_sample_frames=do_sample_frames,
+            sample_indices_fn=sample_indices_fn,
+        )
+        videos = self._prepare_input_videos(
+            videos=videos,
+            input_data_format=input_data_format,
+            device=device,
+            video_metadata=video_metadata,
+            draw_on_frames=draw_on_frames,
+        )
+
+        kwargs = self._further_process_kwargs(**kwargs)
+        self._validate_preprocess_kwargs(**kwargs)
+
+        # Pop kwargs that are not needed in _preprocess
+        kwargs.pop("data_format")
+        return_metadata = kwargs.pop("return_metadata")
+
+        preprocessed_videos = self._preprocess(videos=videos, **kwargs)
+        if return_metadata:
+            preprocessed_videos["video_metadata"] = video_metadata
+        return preprocessed_videos
+
+
+__all__ = ["Ernie4_5_VL_MoeVideoProcessor"]
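A hedged end-to-end usage sketch; the checkpoint id is a placeholder, and `draw_on_frames=False` sidesteps the font requirement:

import numpy as np
from transformers import AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("org/ernie-4.5-vl-moe-checkpoint")  # hypothetical repo id
video = np.random.randint(0, 256, size=(8, 224, 224, 3), dtype=np.uint8)  # 8 frames, (T, H, W, C)
out = processor(videos=[video], draw_on_frames=False, return_tensors="pt")
print(out["pixel_values_videos"].shape, out["video_grid_thw"])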
@@ -90,6 +90,7 @@ class RotaryEmbedding(torch.nn.Module):
 
     def __init__(self, dim: int):
         super().__init__()
+        self.dim = dim
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
@@ -558,6 +559,11 @@ class EsmPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, EsmLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, EsmEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+        elif isinstance(module, RotaryEmbedding):
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
 
     def get_output_embeddings(self):
         # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
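These `_init_weights` additions restore deterministic values for non-trainable buffers (which is why `self.dim` is now stored in the hunk above). A hedged standalone approximation of the pattern, assuming `init.copy_` behaves like an in-place copy under `torch.no_grad`:

import torch

def reinit_rotary_inv_freq(module):
    # Recompute the fixed rotary frequencies from the stored `dim` and copy them in place
    inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
    with torch.no_grad():
        module.inv_freq.copy_(inv_freq)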
@@ -912,7 +912,7 @@ class EsmFoldPreTrainedModel(EsmPreTrainedModel):
         elif module.init == "gating":
             init.zeros_(module.weight)
             if module.bias:
-                init.ones(module.bias)
+                init.ones_(module.bias)
         elif module.init == "normal":
             init.kaiming_normal_(module.weight, nonlinearity="linear")
         elif module.init == "final":
@@ -1979,6 +1979,11 @@ class EsmForProteinFolding(EsmPreTrainedModel):
 
     _can_record_outputs = None
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, EsmForProteinFolding):
+            init.copy_(module.af2_to_esm, module._af2_to_esm_from_vocab_list(module.config.vocab_list))
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -185,6 +185,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int):
         super().__init__()
+        self.dim = dim
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
@@ -518,12 +519,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
         ],
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, EvollaSaProtRotaryEmbedding):
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
     def __init__(self, config: SaProtConfig):
         super().__init__(config)
         self.embeddings = EvollaSaProtEmbeddings(config)
         self.encoder = EvollaSaProtEncoder(config)
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -980,7 +988,7 @@ class EvollaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
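The `original_inv_freq` change matters in two ways: registering it as a (non-persistent) buffer lets it follow `.to(device)`, and `clone()` breaks the aliasing with `inv_freq`. A small illustration of the aliasing half:

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.ones(4)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # A plain attribute would neither move with .to(device) nor survive in-place edits to inv_freq
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = Demo()
m.inv_freq.mul_(2.0)
print(m.original_inv_freq)  # tensor([1., 1., 1., 1.]) -- unchanged thanks to clone()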