transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/layoutlmv3/modeling_layoutlmv3.py

@@ -212,6 +212,10 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel):
             if self.config.visual_embed:
                 init.zeros_(module.cls_token)
                 init.zeros_(module.pos_embed)
+                if hasattr(module, "visual_bbox"):
+                    init.copy_(module.visual_bbox, module.create_visual_bbox(image_size=(module.size, module.size)))
+        elif isinstance(module, LayoutLMv3TextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 class LayoutLMv3SelfAttention(nn.Module):
@@ -576,16 +580,18 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
             # when the input_size is larger in fine-tuning, we will interpolate the position embeddings in forward
             self.patch_embed = LayoutLMv3PatchEmbeddings(config)

-            size = int(config.input_size / config.patch_size)
+            self.size = int(config.input_size / config.patch_size)
             self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
-            self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, config.hidden_size))
+            self.pos_embed = nn.Parameter(torch.zeros(1, self.size * self.size + 1, config.hidden_size))
             self.pos_drop = nn.Dropout(p=0.0)

             self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
             self.dropout = nn.Dropout(config.hidden_dropout_prob)

             if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
-                self.init_visual_bbox(image_size=(size, size))
+                self.register_buffer(
+                    "visual_bbox", self.create_visual_bbox(image_size=(self.size, self.size)), persistent=False
+                )

             self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)

@@ -599,7 +605,7 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value

-    def init_visual_bbox(self, image_size=(14, 14), max_len=1000):
+    def create_visual_bbox(self, image_size=(14, 14), max_len=1000):
         """
         Create the bounding boxes for the visual (patch) tokens.
         """
@@ -620,7 +626,7 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
         ).view(-1, 4)

         cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
-        self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)
+        return torch.cat([cls_token_box, visual_bbox], dim=0)

     def calculate_visual_bbox(self, device, dtype, batch_size):
         visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1)
@@ -884,6 +890,12 @@ class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):

         self.post_init()

+    def get_input_embeddings(self):
+        return self.layoutlmv3.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.layoutlmv3.set_input_embeddings(value)
+
     @auto_docstring
     def forward(
         self,
@@ -984,6 +996,12 @@ class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):

         self.post_init()

+    def get_input_embeddings(self):
+        return self.layoutlmv3.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.layoutlmv3.set_input_embeddings(value)
+
     @auto_docstring
     def forward(
         self,
@@ -1104,6 +1122,12 @@ class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):

         self.post_init()

+    def get_input_embeddings(self):
+        return self.layoutlmv3.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.layoutlmv3.set_input_embeddings(value)
+
     @auto_docstring
     def forward(
         self,
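A note on the `visual_bbox` change above: it moves from a plain tensor attribute filled in by `init_visual_bbox` to a non-persistent buffer built by `create_visual_bbox` and re-created in `_init_weights`. The sketch below shows the observable effect; it assumes a default `LayoutLMv3Config` keeps `visual_embed` enabled, and it sets `has_spatial_attention_bias` explicitly so the buffer path is active regardless of the defaults.

    # Sketch only: relies on standard register_buffer(..., persistent=False) semantics.
    from transformers import LayoutLMv3Config, LayoutLMv3Model

    config = LayoutLMv3Config(has_spatial_attention_bias=True)  # keep the visual_bbox path active
    model = LayoutLMv3Model(config)

    # The buffer now follows model.to(device) automatically ...
    assert model.visual_bbox.shape[-1] == 4  # one (x1, y1, x2, y2) box per visual token, plus CLS
    # ... but, being non-persistent, it is rebuilt at init time rather than serialized.
    assert "visual_bbox" not in model.state_dict()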
transformers/models/led/modeling_led.py

@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -1077,6 +1078,11 @@ class LEDPreTrainedModel(PreTrainedModel):
         }
         return dummy_inputs

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, LEDForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+

 @dataclass
 @auto_docstring(
transformers/models/levit/modeling_levit.py

@@ -21,6 +21,7 @@ from typing import Optional, Union
 import torch
 from torch import nn

+from ... import initialization as init
 from ...modeling_outputs import (
     BaseModelOutputWithNoAttention,
     BaseModelOutputWithPoolingAndNoAttention,
@@ -165,6 +166,7 @@ class LevitAttention(nn.Module):

         points = list(itertools.product(range(resolution), range(resolution)))
         len_points = len(points)
+        self.len_points = len_points
         attention_offsets, indices = {}, []
         for p1 in points:
             for p2 in points:
@@ -172,6 +174,7 @@
                 if offset not in attention_offsets:
                     attention_offsets[offset] = len(attention_offsets)
                 indices.append(attention_offsets[offset])
+        self.indices = indices

         self.attention_bias_cache = {}
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
@@ -243,6 +246,8 @@ class LevitAttentionSubsample(nn.Module):
         points = list(itertools.product(range(resolution_in), range(resolution_in)))
         points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
         len_points, len_points_ = len(points), len(points_)
+        self.len_points_ = len_points_
+        self.len_points = len_points
         attention_offsets, indices = {}, []
         for p1 in points_:
             for p2 in points:
@@ -251,6 +256,7 @@
                 if offset not in attention_offsets:
                     attention_offsets[offset] = len(attention_offsets)
                 indices.append(attention_offsets[offset])
+        self.indices = indices

         self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
         self.register_buffer(
@@ -472,6 +478,18 @@ class LevitPreTrainedModel(PreTrainedModel):
     input_modalities = ("image",)
     _no_split_modules = ["LevitResidualLayer"]

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, LevitAttention):
+            init.copy_(
+                module.attention_bias_idxs, torch.LongTensor(module.indices).view(module.len_points, module.len_points)
+            )
+        elif isinstance(module, LevitAttentionSubsample):
+            init.copy_(
+                module.attention_bias_idxs,
+                torch.LongTensor(module.indices).view(module.len_points_, module.len_points),
+            )
+

 @auto_docstring
 class LevitModel(LevitPreTrainedModel):
transformers/models/lfm2/modeling_lfm2.py

@@ -83,7 +83,7 @@ class Lfm2RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
transformers/models/lfm2_moe/modeling_lfm2_moe.py

@@ -27,7 +27,12 @@ from torch import nn
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast
@@ -84,7 +89,7 @@ class Lfm2MoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -145,6 +150,7 @@ class Lfm2MoeMLP(nn.Module):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))


+@use_experts_implementation
 class Lfm2MoeExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -155,6 +161,7 @@ class Lfm2MoeExperts(nn.Module):
         self.intermediate_dim = config.moe_intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+        self.act_fn = F.silu

     def forward(
         self,
@@ -175,7 +182,7 @@
             top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
-            current_hidden_states = F.silu(gate) * up
+            current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
             current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
@@ -671,7 +678,7 @@ class Lfm2MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class Lfm2MoeHybridConvCache
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": Lfm2MoeDecoderLayer,
@@ -684,6 +691,9 @@
         if isinstance(module, Lfm2MoeExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, Lfm2MoeSparseMoeBlock):
+            if module.use_expert_bias:
+                init.zeros_(module.expert_bias)


 @auto_docstring
transformers/models/lfm2_moe/modular_lfm2_moe.py

@@ -72,33 +72,7 @@ class Lfm2MoeMLP(Lfm2MLP):
 class Lfm2MoeExperts(Qwen2MoeExperts):
     def __init__(self, config):
         super().__init__(config)
-        del self.act_fn
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        top_k_index: torch.Tensor,
-        top_k_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        final_hidden_states = torch.zeros_like(hidden_states)
-        with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
-            expert_mask = expert_mask.permute(2, 1, 0)
-            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-        for expert_idx in expert_hit:
-            expert_idx = expert_idx[0]
-            if expert_idx == self.num_experts:
-                continue
-            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states[token_idx]
-            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
-            current_hidden_states = F.silu(gate) * up
-            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
-            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
-
-        return final_hidden_states
+        self.act_fn = F.silu


 class Lfm2MoeSparseMoeBlock(nn.Module):
@@ -160,7 +134,7 @@ class Lfm2MoeDecoderLayer(Lfm2DecoderLayer):


 class Lfm2MoePreTrainedModel(LlamaPreTrainedModel):
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class Lfm2MoeHybridConvCache

     @torch.no_grad()
     def _init_weights(self, module):
@@ -168,6 +142,9 @@ class Lfm2MoePreTrainedModel(LlamaPreTrainedModel):
         if isinstance(module, Lfm2MoeExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, Lfm2MoeSparseMoeBlock):
+            if module.use_expert_bias:
+                init.zeros_(module.expert_bias)


 class Lfm2MoeModel(MixtralModel):
@@ -46,6 +46,8 @@ class Lfm2VlConfig(PreTrainedConfig):
  The hidden size of the multimodal projector.
  projector_bias (`bool`, *optional*, defaults to `True`):
  Whether to use bias in the multimodal projector.
+ projector_use_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether to use layernorm in the multimodal projector.
  downsample_factor (`int`, *optional*, defaults to 2):
  The downsample factor of the vision backbone.
  """
@@ -61,6 +63,7 @@ class Lfm2VlConfig(PreTrainedConfig):
  projector_hidden_act="gelu",
  projector_hidden_size=2560,
  projector_bias=True,
+ projector_use_layernorm=True,
  downsample_factor=2,
  **kwargs,
  ):
@@ -68,6 +71,7 @@ class Lfm2VlConfig(PreTrainedConfig):
  self.projector_hidden_act = projector_hidden_act
  self.projector_hidden_size = projector_hidden_size
  self.projector_bias = projector_bias
+ self.projector_use_layernorm = projector_use_layernorm
  self.downsample_factor = downsample_factor

  if isinstance(vision_config, dict):
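Illustrative usage sketch (not part of the diff), assuming Lfm2VlConfig is exported at the package top level like other model configs: the new flag is just another constructor argument.

from transformers import Lfm2VlConfig

config = Lfm2VlConfig(projector_use_layernorm=False)  # drop the projector LayerNorm
assert config.projector_use_layernorm is False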
@@ -41,7 +41,8 @@ class Lfm2VlMultiModalProjector(nn.Module):
  super().__init__()
  in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
  self.factor = config.downsample_factor
- self.layer_norm = nn.LayerNorm(in_channels)
+ self.use_layer_norm = config.projector_use_layernorm
+ self.layer_norm = nn.LayerNorm(in_channels) if config.projector_use_layernorm else None
  self.linear_1 = nn.Linear(
  in_channels,
  config.projector_hidden_size,
@@ -56,7 +57,8 @@ class Lfm2VlMultiModalProjector(nn.Module):

  def forward(self, image_features: torch.Tensor):
  image_features = self.pixel_unshuffle(image_features)
- image_features = self.layer_norm(image_features)
+ if self.use_layer_norm:
+ image_features = self.layer_norm(image_features)
  hidden_states = self.linear_1(image_features)
  hidden_states = self.act(hidden_states)
  hidden_states = self.linear_2(hidden_states)
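The pattern above — instantiate the norm only when the config asks for it and guard the call in forward — keeps checkpoints free of unused LayerNorm weights. A minimal standalone sketch (toy module, not the Lfm2Vl code):

import torch
from torch import nn

class OptionalNormProjector(nn.Module):
    def __init__(self, in_channels: int, use_layer_norm: bool = True):
        super().__init__()
        self.use_layer_norm = use_layer_norm
        self.layer_norm = nn.LayerNorm(in_channels) if use_layer_norm else None
        self.linear = nn.Linear(in_channels, in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_layer_norm:
            x = self.layer_norm(x)
        return self.linear(x)

x = torch.randn(2, 4, 16)
print(OptionalNormProjector(16, use_layer_norm=False)(x).shape)  # torch.Size([2, 4, 16])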
@@ -448,6 +450,7 @@ class Lfm2VlForConditionalGeneration(Lfm2VlPreTrainedModel, GenerationMixin):
  attention_mask=None,
  cache_position=None,
  logits_to_keep=None,
+ is_first_iteration=False,
  **kwargs,
  ):
  # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -459,12 +462,15 @@ class Lfm2VlForConditionalGeneration(Lfm2VlPreTrainedModel, GenerationMixin):
  attention_mask=attention_mask,
  cache_position=cache_position,
  logits_to_keep=logits_to_keep,
+ is_first_iteration=is_first_iteration,
  **kwargs,
  )

- if cache_position[0] == 0:
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
- # Otherwise we need pixel values to be passed to model
+ if is_first_iteration or not kwargs.get("use_cache", True):
+ # Pixel values are used only in the first iteration if available
+ # In subsequent iterations, they are already merged with text and cached
+ # NOTE: first iteration doesn't have to be prefill, it can be the first
+ # iteration with a question and cached system prompt (continue generating from cache)
  model_inputs["pixel_values"] = pixel_values

  return model_inputs
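Sketch of the new gating (not part of the diff): pixel_values are forwarded only on the first generation step, or whenever caching is off and the full prompt is re-run each step. A hypothetical helper showing the same condition:

def select_pixel_values(pixel_values, is_first_iteration: bool, use_cache: bool = True):
    # image features are merged into the cache on the first pass; re-feeding them
    # later would duplicate them, unless there is no cache and the prompt is re-run
    if is_first_iteration or not use_cache:
        return pixel_values
    return None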
@@ -41,7 +41,8 @@ class Lfm2VlMultiModalProjector(nn.Module):
  super().__init__()
  in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
  self.factor = config.downsample_factor
- self.layer_norm = nn.LayerNorm(in_channels)
+ self.use_layer_norm = config.projector_use_layernorm
+ self.layer_norm = nn.LayerNorm(in_channels) if config.projector_use_layernorm else None
  self.linear_1 = nn.Linear(
  in_channels,
  config.projector_hidden_size,
@@ -56,7 +57,8 @@ class Lfm2VlMultiModalProjector(nn.Module):

  def forward(self, image_features: torch.Tensor):
  image_features = self.pixel_unshuffle(image_features)
- image_features = self.layer_norm(image_features)
+ if self.use_layer_norm:
+ image_features = self.layer_norm(image_features)
  hidden_states = self.linear_1(image_features)
  hidden_states = self.act(hidden_states)
  hidden_states = self.linear_2(hidden_states)
@@ -165,63 +165,103 @@ class Lfm2VlProcessor(ProcessorMixin):
  image_sizes: list[list[int]],
  use_image_special_tokens: bool,
  **images_kwargs,
- ):
- prompt_strings = []
+ ) -> list[str]:
+ use_thumbnail = images_kwargs.get("use_thumbnail", self.image_processor.use_thumbnail)
+ image_data = iter(zip(image_rows, image_cols, image_sizes))

- image_data = iter(zip(*[image_rows, image_cols, image_sizes]))
+ prompt_strings = []
  for sample_text, sample_images in zip(text, images):
- split_sample = sample_text.split(self.image_token)
- sample_text_with_image_tokens = ""
- for i, image in enumerate(sample_images):
- sample_text_with_image_tokens += split_sample[i]
- if use_image_special_tokens:
- sample_text_with_image_tokens += self.image_start_token
+ text_parts = sample_text.split(self.image_token)
+ result_parts = []
+
+ for i, _ in enumerate(sample_images):
+ result_parts.append(text_parts[i])

  rows, cols, image_size = next(image_data)
- num_thumbnail_tokens, num_tokens_per_tile = self._get_image_num_tokens(image_size, **images_kwargs)
-
- if rows > 1 or cols > 1:
- for row in range(rows):
- for col in range(cols):
- if use_image_special_tokens:
- sample_text_with_image_tokens += f"<|img_row_{row + 1}_col_{col + 1}|>"
- sample_text_with_image_tokens += self.image_token * num_tokens_per_tile
-
- if num_thumbnail_tokens > 0:
- if use_image_special_tokens:
- sample_text_with_image_tokens += self.image_thumbnail_token
- sample_text_with_image_tokens += self.image_token * num_thumbnail_tokens
- else:
- sample_text_with_image_tokens += self.image_token * num_thumbnail_tokens
+ tokens_per_tile, tokens_for_image = self._get_image_num_tokens(image_size, **images_kwargs)
+ image_tokens = self._build_image_tokens(
+ rows,
+ cols,
+ tokens_per_tile,
+ tokens_for_image,
+ use_thumbnail,
+ use_image_special_tokens,
+ )
+ result_parts.append(image_tokens)

- if use_image_special_tokens:
- sample_text_with_image_tokens += self.image_end_token
+ # Add remaining text after the last image
+ if len(sample_images) < len(text_parts):
+ result_parts.append(text_parts[-1])

- sample_text_with_image_tokens += split_sample[i + 1]
- prompt_strings.append(sample_text_with_image_tokens)
+ prompt_strings.append("".join(result_parts))

  return prompt_strings

+ def _build_image_tokens(
+ self,
+ rows: int,
+ cols: int,
+ tokens_per_tile: int,
+ tokens_for_image: int,
+ use_thumbnail: bool,
+ use_image_special_tokens: bool,
+ ) -> str:
+ """Build the expanded token string for a single image."""
+ parts = []
+
+ if use_image_special_tokens:
+ parts.append(self.image_start_token)
+
+ is_multi_tile = rows > 1 or cols > 1
+ if is_multi_tile:
+ for row in range(rows):
+ for col in range(cols):
+ if use_image_special_tokens:
+ parts.append(f"<|img_row_{row + 1}_col_{col + 1}|>")
+ parts.append(self.image_token * tokens_per_tile)
+
+ if use_thumbnail:
+ if use_image_special_tokens:
+ parts.append(self.image_thumbnail_token)
+ parts.append(self.image_token * tokens_for_image)
+ else:
+ parts.append(self.image_token * tokens_for_image)
+
+ if use_image_special_tokens:
+ parts.append(self.image_end_token)
+
+ return "".join(parts)
+
+ def _compute_tokens_per_tile(self, tile_size: int, encoder_patch_size: int, downsample_factor: int) -> int:
+ """Compute the number of tokens for a single tile."""
+ num_patches = tile_size // encoder_patch_size
+ downsampled_patches = math.ceil(num_patches / downsample_factor)
+ return downsampled_patches * downsampled_patches
+
+ def _compute_tokens_for_image(self, image_size: list[int], encoder_patch_size: int, downsample_factor: int) -> int:
+ """Compute the number of tokens for a resized image (used for single-tile or thumbnail)."""
+ image_height, image_width = image_size
+ patches_h = math.ceil((image_height // encoder_patch_size) / downsample_factor)
+ patches_w = math.ceil((image_width // encoder_patch_size) / downsample_factor)
+ return patches_h * patches_w
+
  def _get_image_num_tokens(self, image_size: list[int], **images_kwargs) -> tuple[int, int]:
+ """
+ Compute token counts for image processing.
+
+ Returns:
+ tuple[int, int]: (tokens_per_tile, tokens_for_image)
+ - tokens_per_tile: tokens for each tile in multi-tile mode
+ - tokens_for_image: tokens for the resized image (single-tile) or thumbnail (multi-tile)
+ """
  tile_size = images_kwargs.get("tile_size", self.image_processor.tile_size)
  downsample_factor = images_kwargs.get("downsample_factor", self.image_processor.downsample_factor)
  encoder_patch_size = images_kwargs.get("encoder_patch_size", self.image_processor.encoder_patch_size)
- use_thumbnail = images_kwargs.get("use_thumbnail", self.image_processor.use_thumbnail)
-
- thumbnail_tokens = 0
- if use_thumbnail:
- image_height, image_width = image_size
- num_patches_height = image_height // encoder_patch_size
- num_patches_width = image_width // encoder_patch_size
- dwn_num_patches_height = math.ceil(num_patches_height / downsample_factor)
- dwn_num_patches_width = math.ceil(num_patches_width / downsample_factor)
- thumbnail_tokens = dwn_num_patches_height * dwn_num_patches_width

- num_patches_tile = tile_size // encoder_patch_size
- dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor)
- tile_tokens = dwn_num_patches_tile * dwn_num_patches_tile
+ tokens_per_tile = self._compute_tokens_per_tile(tile_size, encoder_patch_size, downsample_factor)
+ tokens_for_image = self._compute_tokens_for_image(image_size, encoder_patch_size, downsample_factor)

- return thumbnail_tokens, tile_tokens
+ return tokens_per_tile, tokens_for_image

  def batch_decode(self, *args, **kwargs):
  """
@@ -174,9 +174,8 @@ class LightGlueImageProcessorFast(BaseImageProcessorFast):
  stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

  # Return in same format as slow processor
- image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs

- return BatchFeature(data={"pixel_values": image_pairs})
+ return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)

  def post_process_keypoint_matching(
  self,
@@ -21,6 +21,7 @@ import torch
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+ from ... import initialization as init
  from ...activations import ACT2FN
  from ...modeling_layers import GradientCheckpointingLayer
  from ...modeling_outputs import (
@@ -279,11 +280,9 @@ class LiltSelfAttention(nn.Module):
  new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  context_layer = context_layer.view(*new_context_layer_shape)

- outputs = (
- ((context_layer, layout_context_layer), attention_probs)
- if output_attentions
- else ((context_layer, layout_context_layer),)
- )
+ outputs = (context_layer, layout_context_layer)
+ if output_attentions:
+ outputs = outputs + (attention_probs,)

  return outputs

@@ -327,9 +326,9 @@ class LiltAttention(nn.Module):
  attention_mask,
  output_attentions,
  )
- attention_output = self.output(self_outputs[0][0], hidden_states)
- layout_attention_output = self.layout_output(self_outputs[0][1], layout_inputs)
- outputs = ((attention_output, layout_attention_output),) + self_outputs[1:] # add attentions if we output them
+ attention_output = self.output(self_outputs[0], hidden_states)
+ layout_attention_output = self.layout_output(self_outputs[1], layout_inputs)
+ outputs = (attention_output, layout_attention_output) + self_outputs[2:] # add attentions if we output them
  return outputs


@@ -395,10 +394,10 @@ class LiltLayer(GradientCheckpointingLayer):
  attention_mask,
  output_attentions=output_attentions,
  )
- attention_output = self_attention_outputs[0][0]
- layout_attention_output = self_attention_outputs[0][1]
+ attention_output = self_attention_outputs[0]
+ layout_attention_output = self_attention_outputs[1]

- outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+ outputs = self_attention_outputs[2:] # add self attentions if we output attention weights

  layer_output = apply_chunking_to_forward(
  self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
@@ -406,7 +405,7 @@ class LiltLayer(GradientCheckpointingLayer):
  layout_layer_output = apply_chunking_to_forward(
  self.layout_feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layout_attention_output
  )
- outputs = ((layer_output, layout_layer_output),) + outputs
+ outputs = (layer_output, layout_layer_output) + outputs

  return outputs

@@ -451,11 +450,11 @@ class LiltEncoder(nn.Module):
  output_attentions,
  )

- hidden_states = layer_outputs[0][0]
- layout_inputs = layer_outputs[0][1]
+ hidden_states = layer_outputs[0]
+ layout_inputs = layer_outputs[1]

  if output_attentions:
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ all_self_attentions = all_self_attentions + (layer_outputs[2],)

  if output_hidden_states:
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -500,6 +499,11 @@ class LiltPreTrainedModel(PreTrainedModel):
  supports_gradient_checkpointing = True
  _no_split_modules = []

+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, LiltTextEmbeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+

  @auto_docstring
  class LiltModel(LiltPreTrainedModel):
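The Lilt changes above flatten the previously nested output tuples: attention modules now return (context, layout_context, *optional attention_probs) instead of ((context, layout_context), attention_probs), so callers index [0]/[1] directly and slice from [2:]. A toy sketch of the convention (not part of the diff):

import torch

context, layout_context = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
attention_probs = torch.randn(1, 2, 4, 4)
output_attentions = True

outputs = (context, layout_context)
if output_attentions:
    outputs = outputs + (attention_probs,)

hidden_states, layout_inputs = outputs[0], outputs[1]  # was outputs[0][0], outputs[0][1]
attentions = outputs[2:]                               # was outputs[1:]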
@@ -87,7 +87,7 @@ class LlamaRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
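The rotary-embedding change above keeps original_inv_freq as a non-persistent buffer cloned from inv_freq, so the pristine copy moves with the module across devices and is untouched by later in-place updates to inv_freq. A minimal sketch (toy module, not the library class):

import torch
from torch import nn

class RotaryBufferDemo(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # clone so rescaling inv_freq in place never corrupts the reference copy
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = RotaryBufferDemo()
m.inv_freq.mul_(2.0)  # e.g. a dynamic RoPE rescaling step
assert not torch.equal(m.inv_freq, m.original_inv_freq)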
@@ -419,10 +419,9 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
  )
  grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1)
  processed_images = reorder_images(grouped_processed_images, grouped_images_index)
- aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index)
+ aspect_ratios = reorder_images(grouped_aspect_ratios, grouped_images_index)

  processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
- aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list
  return BatchFeature(
  data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors
  )