transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -24,8 +24,9 @@ import sys
24
24
  import warnings
25
25
  from abc import abstractmethod
26
26
  from collections import defaultdict
27
- from collections.abc import Callable, Iterator, Sequence
27
+ from collections.abc import Callable, Iterator
28
28
  from contextlib import contextmanager
29
+ from dataclasses import dataclass, field
29
30
  from enum import Enum
30
31
  from functools import partial, wraps
31
32
  from itertools import cycle
@@ -77,9 +78,8 @@ from .integrations.tensor_parallel import (
77
78
  ALL_PARALLEL_STYLES,
78
79
  _get_parameter_tp_plan,
79
80
  distribute_model,
81
+ gather_state_dict_for_save,
80
82
  initialize_tensor_parallelism,
81
- repack_weights,
82
- replace_state_dict_local_with_dtensor,
83
83
  shard_and_distribute_module,
84
84
  verify_tp_plan,
85
85
  )
@@ -106,25 +106,26 @@ from .utils import (
106
106
  copy_func,
107
107
  has_file,
108
108
  is_accelerate_available,
109
+ is_bitsandbytes_available,
110
+ is_env_variable_true,
109
111
  is_flash_attn_2_available,
110
112
  is_flash_attn_3_available,
111
113
  is_grouped_mm_available,
112
114
  is_kernels_available,
113
115
  is_torch_flex_attn_available,
114
- is_torch_greater_or_equal,
115
116
  is_torch_mlu_available,
116
117
  is_torch_npu_available,
117
118
  is_torch_xpu_available,
118
119
  logging,
119
120
  )
120
- from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
121
+ from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder, is_flash_attention_requested
121
122
  from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
122
123
  from .utils.import_utils import (
123
124
  is_huggingface_hub_greater_or_equal,
124
125
  is_sagemaker_mp_enabled,
125
126
  is_tracing,
126
127
  )
127
- from .utils.loading_report import log_state_dict_report
128
+ from .utils.loading_report import LoadStateDictInfo, log_state_dict_report
128
129
  from .utils.quantization_config import QuantizationMethod
129
130
 
130
131
 
@@ -134,9 +135,6 @@ if is_accelerate_available():
134
135
 
135
136
 
136
137
  _torch_distributed_available = torch.distributed.is_available()
137
- _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
138
- if _is_dtensor_available:
139
- from torch.distributed.tensor import DTensor
140
138
 
141
139
  if is_sagemaker_mp_enabled():
142
140
  import smdistributed.modelparallel.torch as smp
@@ -162,6 +160,33 @@ FLASH_ATTN_KERNEL_FALLBACK = {
162
160
  }
163
161
 
164
162
 
163
+ @dataclass(frozen=True)
164
+ class LoadStateDictConfig:
165
+ """
166
+ Config for loading weights. This allows bundling arguments that are just
167
+ passed around.
168
+ """
169
+
170
+ pretrained_model_name_or_path: str | None = None
171
+ download_kwargs: DownloadKwargs | None = field(default_factory=DownloadKwargs)
172
+ use_safetensors: bool | None = None
173
+ ignore_mismatched_sizes: bool = False
174
+ sharded_metadata: dict | None = None
175
+ device_map: dict | None = None
176
+ disk_offload_folder: str | None = None
177
+ offload_buffers: bool = False
178
+ dtype: torch.dtype | None = None
179
+ dtype_plan: dict = field(default_factory=dict)
180
+ hf_quantizer: HfQuantizer | None = None
181
+ device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
182
+ weights_only: bool = True
183
+ weight_mapping: list[WeightConverter | WeightRenaming] | None = None
184
+
185
+ @property
186
+ def is_quantized(self) -> bool:
187
+ return self.hf_quantizer is not None
188
+
189
+
165
190
  def is_local_dist_rank_0():
166
191
  return (
167
192
  torch.distributed.is_available()
@@ -223,8 +248,7 @@ def get_torch_context_manager_or_global_device():
223
248
  is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
224
249
  """
225
250
  device_in_context = torch.tensor([]).device
226
- # `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
227
- default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
251
+ default_device = torch.get_default_device()
228
252
  # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
229
253
  if device_in_context == default_device:
230
254
  if default_device != torch.device("cpu"):
@@ -252,23 +276,20 @@ str_to_torch_dtype = {
252
276
  "U8": torch.uint8,
253
277
  "I8": torch.int8,
254
278
  "I16": torch.int16,
279
+ "U16": torch.uint16,
255
280
  "F16": torch.float16,
256
281
  "BF16": torch.bfloat16,
257
282
  "I32": torch.int32,
283
+ "U32": torch.uint32,
258
284
  "F32": torch.float32,
259
285
  "F64": torch.float64,
260
286
  "I64": torch.int64,
287
+ "U64": torch.uint64,
261
288
  "F8_E4M3": torch.float8_e4m3fn,
262
289
  "F8_E5M2": torch.float8_e5m2,
263
290
  }
264
291
 
265
292
 
266
- if is_torch_greater_or_equal("2.3.0"):
267
- str_to_torch_dtype["U16"] = torch.uint16
268
- str_to_torch_dtype["U32"] = torch.uint32
269
- str_to_torch_dtype["U64"] = torch.uint64
270
-
271
-
272
293
  def load_state_dict(
273
294
  checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
274
295
  ) -> dict[str, torch.Tensor]:
@@ -472,15 +493,16 @@ def _get_resolved_checkpoint_files(
472
493
  variant: str | None,
473
494
  gguf_file: str | None,
474
495
  use_safetensors: bool | None,
475
- download_kwargs: DownloadKwargs,
476
- user_agent: dict,
496
+ user_agent: dict | None,
477
497
  is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in
478
498
  transformers_explicit_filename: str | None = None,
499
+ download_kwargs: DownloadKwargs | None = None,
479
500
  ) -> tuple[list[str] | None, dict | None]:
480
501
  """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
481
502
  checkpoints are sharded.
482
503
  This function will download the data if necessary.
483
504
  """
505
+ download_kwargs = download_kwargs or DownloadKwargs()
484
506
  cache_dir = download_kwargs.get("cache_dir")
485
507
  force_download = download_kwargs.get("force_download", False)
486
508
  proxies = download_kwargs.get("proxies")
@@ -493,17 +515,19 @@ def _get_resolved_checkpoint_files(
493
515
  if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
494
516
  ".safetensors.index.json"
495
517
  ):
496
- raise ValueError(
497
- "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
498
- "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
499
- f"{transformers_explicit_filename}"
500
- )
518
+ if transformers_explicit_filename != "adapter_model.bin":
519
+ raise ValueError(
520
+ "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
521
+ "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
522
+ f"{transformers_explicit_filename}"
523
+ )
501
524
 
502
525
  is_sharded = False
503
526
 
504
527
  if pretrained_model_name_or_path is not None and gguf_file is None:
505
528
  pretrained_model_name_or_path = str(pretrained_model_name_or_path)
506
529
  is_local = os.path.isdir(pretrained_model_name_or_path)
530
+ # If the file is a local folder (but not in the HF_HOME cache, even if it's technically local)
507
531
  if is_local:
508
532
  if transformers_explicit_filename is not None:
509
533
  # If the filename is explicitly defined, load this by default.
@@ -562,25 +586,38 @@ def _get_resolved_checkpoint_files(
562
586
  else:
563
587
  filename = _add_variant(WEIGHTS_NAME, variant)
564
588
 
589
+ # Prepare set of kwargs for hub functions
590
+ has_file_kwargs = {
591
+ "revision": revision,
592
+ "proxies": proxies,
593
+ "token": token,
594
+ "cache_dir": cache_dir,
595
+ "local_files_only": local_files_only,
596
+ }
597
+ cached_file_kwargs = {
598
+ "force_download": force_download,
599
+ "user_agent": user_agent,
600
+ "subfolder": subfolder,
601
+ "_raise_exceptions_for_gated_repo": False,
602
+ "_raise_exceptions_for_missing_entries": False,
603
+ "_commit_hash": commit_hash,
604
+ **has_file_kwargs,
605
+ }
606
+ can_auto_convert = (
607
+ not is_offline_mode() # for obvious reasons
608
+ # If we are in a CI environment or in a pytest run, we prevent the conversion
609
+ and not is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
610
+ and not is_remote_code # converter bot does not work on remote code
611
+ and subfolder == "" # converter bot does not work on subfolders
612
+ )
613
+
565
614
  try:
566
615
  # Load from URL or cache if already cached
567
- cached_file_kwargs = {
568
- "cache_dir": cache_dir,
569
- "force_download": force_download,
570
- "proxies": proxies,
571
- "local_files_only": local_files_only,
572
- "token": token,
573
- "user_agent": user_agent,
574
- "revision": revision,
575
- "subfolder": subfolder,
576
- "_raise_exceptions_for_gated_repo": False,
577
- "_raise_exceptions_for_missing_entries": False,
578
- "_commit_hash": commit_hash,
579
- }
580
- resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
581
-
582
616
  # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
583
617
  # result when internet is up, the repo and revision exist, but the file does not.
618
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
619
+
620
+ # Try safetensors files first if not already found
584
621
  if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
585
622
  # Maybe the checkpoint is sharded, we try to grab the index name in this case.
586
623
  resolved_archive_file = cached_file(
@@ -591,7 +628,7 @@ def _get_resolved_checkpoint_files(
591
628
  if resolved_archive_file is not None:
592
629
  is_sharded = True
593
630
  elif use_safetensors:
594
- if revision == "main" and not is_offline_mode():
631
+ if revision == "main" and can_auto_convert:
595
632
  resolved_archive_file, revision, is_sharded = auto_conversion(
596
633
  pretrained_model_name_or_path, **cached_file_kwargs
597
634
  )
@@ -608,6 +645,8 @@ def _get_resolved_checkpoint_files(
608
645
  resolved_archive_file = cached_file(
609
646
  pretrained_model_name_or_path, filename, **cached_file_kwargs
610
647
  )
648
+
649
+ # Then try `.bin` files
611
650
  if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
612
651
  # Maybe the checkpoint is sharded, we try to grab the index name in this case.
613
652
  resolved_archive_file = cached_file(
@@ -617,67 +656,38 @@ def _get_resolved_checkpoint_files(
617
656
  )
618
657
  if resolved_archive_file is not None:
619
658
  is_sharded = True
620
- if not local_files_only and not is_offline_mode():
621
- if resolved_archive_file is not None:
622
- # In a CI environment (CircleCI / Github Actions workflow runs) or in a pytest run,
623
- # we set `DISABLE_SAFETENSORS_CONVERSION=true` to prevent the conversion.
624
- if (
625
- filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
626
- and os.getenv("DISABLE_SAFETENSORS_CONVERSION", None) != "true"
627
- ):
628
- # If the PyTorch file was found, check if there is a safetensors file on the repository
629
- # If there is no safetensors file on the repositories, start an auto conversion
630
- safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
631
- has_file_kwargs = {
632
- "revision": revision,
633
- "proxies": proxies,
634
- "token": token,
635
- "cache_dir": cache_dir,
636
- "local_files_only": local_files_only,
637
- }
638
- cached_file_kwargs = {
639
- "cache_dir": cache_dir,
640
- "force_download": force_download,
641
- "local_files_only": local_files_only,
642
- "user_agent": user_agent,
643
- "subfolder": subfolder,
644
- "_raise_exceptions_for_gated_repo": False,
645
- "_raise_exceptions_for_missing_entries": False,
646
- "_commit_hash": commit_hash,
647
- **has_file_kwargs,
648
- }
649
- if (
650
- not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
651
- and not is_remote_code
652
- ):
653
- Thread(
654
- target=auto_conversion,
655
- args=(pretrained_model_name_or_path,),
656
- kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
657
- name="Thread-auto_conversion",
658
- ).start()
659
+
660
+ # If we have a match, but it's `.bin` format, try to launch safetensors conversion for next time
661
+ if resolved_archive_file is not None:
662
+ safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
663
+ if (
664
+ filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
665
+ and not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
666
+ and can_auto_convert
667
+ ):
668
+ Thread(
669
+ target=auto_conversion,
670
+ args=(pretrained_model_name_or_path,),
671
+ kwargs={"ignore_errors_during_conversion": False, **cached_file_kwargs},
672
+ name="Thread-auto_conversion",
673
+ ).start()
674
+
675
+ # If no match, raise appropriare errors
676
+ else:
677
+ # Otherwise, no PyTorch file was found
678
+ if variant is not None and has_file(
679
+ pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
680
+ ):
681
+ raise OSError(
682
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
683
+ f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
684
+ f" {variant}. Use `variant=None` to load this model from those weights."
685
+ )
659
686
  else:
660
- # Otherwise, no PyTorch file was found
661
- has_file_kwargs = {
662
- "revision": revision,
663
- "proxies": proxies,
664
- "token": token,
665
- "cache_dir": cache_dir,
666
- "local_files_only": local_files_only,
667
- }
668
- if variant is not None and has_file(
669
- pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
670
- ):
671
- raise OSError(
672
- f"{pretrained_model_name_or_path} does not appear to have a file named"
673
- f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
674
- f" {variant}. Use `variant=None` to load this model from those weights."
675
- )
676
- else:
677
- raise OSError(
678
- f"{pretrained_model_name_or_path} does not appear to have a file named"
679
- f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
680
- )
687
+ raise OSError(
688
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
689
+ f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
690
+ )
681
691
 
682
692
  except OSError:
683
693
  # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
@@ -922,7 +932,7 @@ class ModuleUtilsMixin:
922
932
  # Provided a padding mask of dimensions [batch_size, seq_length]
923
933
  # - if the model is a decoder, apply a causal mask in addition to the padding mask
924
934
  # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
925
- if self.config.is_decoder:
935
+ if getattr(self.config, "is_decoder", None):
926
936
  extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
927
937
  input_shape, attention_mask
928
938
  )
@@ -1095,83 +1105,67 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1095
1105
  - **can_record_outputs** (dict):
1096
1106
  """
1097
1107
 
1098
- config_class = None
1099
- base_model_prefix = ""
1100
- main_input_name = "input_ids"
1101
- model_tags = None
1102
-
1103
- _checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
1104
-
1108
+ # General model properties
1109
+ config_class: type[PreTrainedConfig] | None = None
1105
1110
  _auto_class = None
1106
- _no_split_modules = None
1107
- _skip_keys_device_placement = None
1108
-
1109
- _keep_in_fp32_modules = None
1110
- # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
1111
- # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
1112
- _keep_in_fp32_modules_strict = None
1113
-
1114
- dtype_plan: dict[str, torch.dtype] | None = None
1115
-
1116
- # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
1117
- # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
1118
- _keys_to_ignore_on_load_missing = None
1119
- # a list of `re` patterns of `state_dict` keys that should be removed from the list of
1120
- # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
1121
- # warnings.
1122
- _keys_to_ignore_on_load_unexpected = None
1123
- # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
1124
- # trained, but which are either deterministic or tied variables)
1125
- _keys_to_ignore_on_save = None
1126
- # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
1127
- _tied_weights_keys = None
1128
-
1129
- supports_gradient_checkpointing = False
1130
- _is_stateful = False
1131
-
1132
- # Flash Attention support
1133
- _supports_flash_attn = False
1134
-
1135
- # SDPA support
1136
- _supports_sdpa = False
1137
-
1138
- # Flex Attention support
1139
- _supports_flex_attn = False
1140
-
1141
- _can_compile_fullgraph = False
1142
-
1143
- # A tensor parallel plan to be applied to the model when TP is enabled. For
1144
- # top-level models, this attribute is currently defined in respective model
1145
- # code. For base models, this attribute comes from
1146
- # `config.base_model_tp_plan` during `__init__`.
1147
- # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
1148
- # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
1149
- # for example.
1150
- _tp_plan = None
1151
-
1152
- # tensor parallel degree to which model is sharded to.
1153
- _tp_size = None
1154
-
1155
- # A pipeline parallel plan specifying the layers which may not be present
1156
- # on all ranks when PP is enabled. For top-level models, this attribute is
1157
- # currently defined in respective model code. For base models, this
1158
- # attribute comes from `config.base_model_pp_plan` during `post_init`.
1159
- #
1160
- # The variable names for the inputs and outputs of the specified layers can
1161
- # be indexed using the `PipelineParallel` enum as follows:
1162
- # - `_pp_plan["layers"][PipelineParallel.inputs]`
1163
- # - `_pp_plan["layers"][PipelineParallel.outputs]`
1164
- _pp_plan = None
1111
+ base_model_prefix: str = ""
1112
+ _is_stateful: bool = False
1113
+ model_tags: list[str] | None = None
1165
1114
 
1115
+ # Input-related properties
1116
+ main_input_name: str = "input_ids"
1117
+ # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
1118
+ # Possible values are: text, image, video, audio and time
1119
+ input_modalities: str | list[str] = "text"
1120
+
1121
+ # Device-map related properties
1122
+ _no_split_modules: set[str] | list[str] | None = None
1123
+ _skip_keys_device_placement: str | list[str] | None = None
1124
+
1125
+ # Specific dtype upcasting
1126
+ # `_keep_in_fp32_modules` will upcast to fp32 only if the requested dtype is fp16
1127
+ # `_keep_in_fp32_modules_strict` will upcast to fp32 independently if the requested dtype is fp16 or bf16
1128
+ _keep_in_fp32_modules: set[str] | list[str] | None = None
1129
+ _keep_in_fp32_modules_strict: set[str] | list[str] | None = None
1130
+
1131
+ # Loading-specific properties
1132
+ # A dictionary `{"target": "source"}` of checkpoint keys that are potentially tied to one another
1133
+ _tied_weights_keys: dict[str, str] = None
1134
+ # Used for BC support in VLMs, not meant to be used by new models
1135
+ _checkpoint_conversion_mapping: dict[str, str] = {}
1136
+ # A list of `re` patterns describing keys to ignore if they are missing from checkpoints to avoid warnings
1137
+ _keys_to_ignore_on_load_missing: list[str] | None = None
1138
+ # A list of `re` patterns describing keys to ignore if they are unexpected in the checkpoints to avoid warnings
1139
+ _keys_to_ignore_on_load_unexpected: list[str] | None = None
1140
+ # A list of keys to ignore when saving the model
1141
+ _keys_to_ignore_on_save: list[str] | None = None
1142
+
1143
+ # Attention interfaces support properties
1144
+ _supports_sdpa: bool = False
1145
+ _supports_flash_attn: bool = False
1146
+ _supports_flex_attn: bool = False
1147
+
1148
+ # Tensor-parallelism-related properties
1149
+ # A tensor parallel plan of the form `{"model.layer.mlp.param": "colwise"}` to be applied to the model when TP is enabled.
1150
+ # For top-level models, this attribute is currently defined in respective model code. For base models, this attribute comes
1151
+ # from `config.base_model_tp_plan` during `post_init`.
1152
+ _tp_plan: dict[str, str] = None
1153
+ # Tensor parallel degree to which model is sharded to
1154
+ _tp_size = None
1155
+ # A pipeline parallel plan specifying the layers which may not be present on all ranks when PP is enabled. For top-level
1156
+ # models, this attribute is currently defined in respective model code. For base models, it comes from
1157
+ # `config.base_model_pp_plan` during `post_init`.
1158
+ _pp_plan: dict[str, PipelineParallel] | None = None
1159
+
1160
+ # Advanced functionalities support
1161
+ supports_gradient_checkpointing: bool = False
1162
+ _can_compile_fullgraph: bool = False
1166
1163
  # This flag signal that the model can be used as an efficient backend in TGI and vLLM
1167
1164
  # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
1168
1165
  # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
1169
- _supports_attention_backend = False
1170
- _can_record_outputs = None
1171
-
1172
- # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
1173
- # Possible values are: text, image, video, audio and time
1174
- input_modalities: str | list[str] = "text" # most models are text
1166
+ _supports_attention_backend: bool = False
1167
+ # A mapping describing what outputs can be captured by `check_model_inputs` decorator during the forward pass
1168
+ _can_record_outputs: dict | None = None
1175
1169
 
1176
1170
  @property
1177
1171
  @torch._dynamo.allow_in_graph
@@ -1256,6 +1250,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1256
1250
  f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
1257
1251
  )
1258
1252
  self.config = config
1253
+ self.name_or_path = config.name_or_path
1259
1254
 
1260
1255
  # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
1261
1256
  # setting it recursively)
@@ -1281,38 +1276,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1281
1276
  loss_type = None
1282
1277
  self.loss_type = loss_type
1283
1278
 
1284
- self.name_or_path = config.name_or_path
1285
- self.warnings_issued = {}
1286
- # Overwrite the class attribute to make it an instance attribute, so models like
1287
- # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
1288
- # when a different component (e.g. language_model) is used.
1289
- self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
1290
- self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
1291
- self.dtype_plan = {}
1292
-
1293
- if isinstance(self._keep_in_fp32_modules, list):
1294
- self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
1295
- if isinstance(self._keep_in_fp32_modules_strict, list):
1296
- self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
1297
-
1298
- self._no_split_modules = self._no_split_modules or []
1299
1279
  _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
1300
1280
 
1301
1281
  def post_init(self):
1302
1282
  """
1303
1283
  A method executed at the end of each Transformer model initialization, to execute code that needs the model's
1304
1284
  modules properly initialized (such as weight initialization).
1285
+ It is also used to obtain all correct static properties (parallelism plans, tied_weights_keys, _keep_in_fp32_modules, etc)
1286
+ correctly in the case of composite models (that is, the top level model should know about those properties from its children).
1305
1287
  """
1306
1288
  # Attach the different parallel plans and tied weight keys to the top-most model, so that everything is
1307
1289
  # easily available
1308
1290
  self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
1309
- # Current submodel should register its tied weights
1310
- self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
1311
1291
  # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
1312
1292
  if self.base_model is self:
1313
1293
  self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
1314
1294
  self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
1315
1295
  self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
1296
+ # Current submodel should register its tied weights
1297
+ self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
1298
+ # Current submodel should register its `_keep_in_fp32_modules`
1299
+ self._keep_in_fp32_modules = set(self._keep_in_fp32_modules or [])
1300
+ self._keep_in_fp32_modules_strict = set(self._keep_in_fp32_modules_strict or [])
1301
+ # Current submodel must register its `_no_split_modules` as well
1302
+ self._no_split_modules = set(self._no_split_modules or [])
1303
+
1304
+ # Iterate over children only: as the final model is created, this is enough to gather the properties from all submodels.
1305
+ # This works because the way the `__init__` and `post_init` are called on all submodules is depth-first in the graph
1316
1306
  for name, module in self.named_children():
1317
1307
  # Parallel plans
1318
1308
  if plan := getattr(module, "_ep_plan", None):
@@ -1324,6 +1314,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1324
1314
  # Always attach the keys of the children (if the children's config says to NOT tie, then it's empty)
1325
1315
  if tied_keys := getattr(module, "all_tied_weights_keys", None):
1326
1316
  self.all_tied_weights_keys.update({f"{name}.{k}": f"{name}.{v}" for k, v in tied_keys.copy().items()})
1317
+ # Record keep_in_fp_32 modules from the children as well
1318
+ if keep_fp32 := getattr(module, "_keep_in_fp32_modules", None):
1319
+ self._keep_in_fp32_modules.update(keep_fp32)
1320
+ if keep_fp32_strict := getattr(module, "_keep_in_fp32_modules_strict", None):
1321
+ self._keep_in_fp32_modules_strict.update(keep_fp32_strict)
1322
+ # Record `_no_split_modules` from the children
1323
+ if no_split := getattr(module, "_no_split_modules", None):
1324
+ self._no_split_modules.update(no_split)
1327
1325
 
1328
1326
  # Maybe initialize the weights and tie the keys
1329
1327
  self.init_weights()
@@ -1842,7 +1840,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1842
1840
  )
1843
1841
 
1844
1842
  # preload flash attention here to allow compile with fullgraph
1845
- if "flash" in applicable_attn_implementation:
1843
+ if is_flash_attention_requested(requested_attention_implementation=applicable_attn_implementation):
1846
1844
  lazy_import_flash_attention(applicable_attn_implementation)
1847
1845
 
1848
1846
  return applicable_attn_implementation
@@ -1919,15 +1917,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1919
1917
  """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
1920
1918
  opening the file, but avoids maintaining yet another property flag.
1921
1919
  """
1922
- class_file = sys.modules[cls.__module__].__file__
1923
- with open(class_file, "r") as f:
1920
+ class_module = sys.modules[cls.__module__]
1921
+ # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
1922
+ if not hasattr(class_module, "__file__"):
1923
+ return False
1924
+ class_file = class_module.__file__
1925
+ with open(class_file, "r", encoding="utf-8") as f:
1924
1926
  code = f.read()
1925
1927
  # heuristic -> if we find those patterns, the model uses the correct interface
1926
1928
  if re.search(r"class \w+Attention\(nn.Module\)", code):
1927
- return (
1928
- "eager_attention_forward" in code
1929
- and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
1930
- )
1929
+ return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
1931
1930
  else:
1932
1931
  # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
1933
1932
  return True
@@ -1937,8 +1936,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1937
1936
  """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
1938
1937
  opening the file, but avoids maintaining yet another property flag.
1939
1938
  """
1940
- class_file = sys.modules[cls.__module__].__file__
1941
- with open(class_file, "r") as f:
1939
+ class_module = sys.modules[cls.__module__]
1940
+ # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
1941
+ if not hasattr(class_module, "__file__"):
1942
+ return False
1943
+ class_file = class_module.__file__
1944
+ with open(class_file, "r", encoding="utf-8") as f:
1942
1945
  code = f.read()
1943
1946
  # heuristic -> if we the use_experts_implementation decorator is used, then we can set it
1944
1947
  return "@use_experts_implementation" in code
@@ -2404,7 +2407,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2404
2407
 
2405
2408
  tied_mapping = self._tied_weights_keys
2406
2409
  # If the config does not specify any tying, return empty dict
2407
- if not self.config.tie_word_embeddings:
2410
+ # NOTE: not all modules have `tie_word_embeddings` attr, for example vision-only
2411
+ # modules do not have any word embeddings!
2412
+ tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
2413
+ if not tie_word_embeddings:
2408
2414
  return {}
2409
2415
  # If None, return empty dict
2410
2416
  elif tied_mapping is None:
@@ -2542,35 +2548,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2542
2548
  if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
2543
2549
  output_embeddings.out_features = input_embeddings.num_embeddings
2544
2550
 
2545
- def _get_no_split_modules(self, device_map: str):
2546
- """
2547
- Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
2548
- get the underlying `_no_split_modules`.
2549
-
2550
- Args:
2551
- device_map (`str`):
2552
- The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
2553
-
2554
- Returns:
2555
- `list[str]`: List of modules that should not be split
2556
- """
2557
- _no_split_modules = set()
2558
- modules_to_check = [self]
2559
- while len(modules_to_check) > 0:
2560
- module = modules_to_check.pop(-1)
2561
- # if the module does not appear in _no_split_modules, we also check the children
2562
- if module.__class__.__name__ not in _no_split_modules:
2563
- if isinstance(module, PreTrainedModel):
2564
- if module._no_split_modules is None:
2565
- raise ValueError(
2566
- f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
2567
- "class needs to implement the `_no_split_modules` attribute."
2568
- )
2569
- else:
2570
- _no_split_modules = _no_split_modules | set(module._no_split_modules)
2571
- modules_to_check += list(module.children())
2572
- return list(_no_split_modules)
2573
-
2574
2551
  def resize_token_embeddings(
2575
2552
  self,
2576
2553
  new_num_tokens: int | None = None,
@@ -2654,10 +2631,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2654
2631
  new_num_tokens = new_embeddings.weight.shape[0]
2655
2632
 
2656
2633
  # if word embeddings are not tied, make sure that lm head is resized as well
2657
- if (
2658
- self.get_output_embeddings() is not None
2659
- and not self.config.get_text_config(decoder=True).tie_word_embeddings
2660
- ):
2634
+ if self.get_output_embeddings() is not None:
2661
2635
  old_lm_head = self.get_output_embeddings()
2662
2636
  if isinstance(old_lm_head, torch.nn.Embedding):
2663
2637
  new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
@@ -3038,15 +3012,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3038
3012
 
3039
3013
  def init_weights(self):
3040
3014
  """
3041
- Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
3015
+ Initialize and tie the weights if needed. If using a custom `PreTrainedModel`, you need to implement any
3042
3016
  initialization logic in `_init_weights`.
3043
3017
  """
3044
3018
  # If we are initializing on meta device, there is no point in trying to run inits
3045
3019
  if get_torch_context_manager_or_global_device() != torch.device("meta"):
3046
3020
  # Initialize weights
3047
3021
  self.initialize_weights()
3048
- # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
3049
- self.tie_weights(recompute_mapping=False)
3022
+ # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
3023
+ self.tie_weights(recompute_mapping=False)
3050
3024
 
3051
3025
  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
3052
3026
  """
@@ -3063,7 +3037,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3063
3037
  raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
3064
3038
 
3065
3039
  if gradient_checkpointing_kwargs is None:
3066
- gradient_checkpointing_kwargs = {"use_reentrant": True}
3040
+ gradient_checkpointing_kwargs = {"use_reentrant": False}
3067
3041
 
3068
3042
  gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
3069
3043
 
@@ -3316,16 +3290,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3316
3290
  if ignore_key in state_dict:
3317
3291
  del state_dict[ignore_key]
3318
3292
 
3319
- # If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
3320
- # therefore we replace them with DTensors that are equivalently sharded
3293
+ # If model was sharded with TP, gather full tensors for saving
3321
3294
  if self._tp_size is not None:
3322
- state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
3295
+ state_dict = gather_state_dict_for_save(state_dict, self._tp_plan, self._device_mesh, self._tp_size)
3323
3296
 
3324
3297
  # Remove tied weights as safetensors do not handle them
3325
3298
  state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)
3326
3299
 
3327
3300
  # Revert all renaming and/or weight operations
3328
- if save_original_format:
3301
+ if save_original_format and not _hf_peft_config_loaded:
3329
3302
  state_dict = revert_weight_conversion(model_to_save, state_dict)
3330
3303
 
3331
3304
  # Shard the model if it is too big.
@@ -3377,13 +3350,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3377
3350
  # Get the tensor, and remove it from state_dict to avoid keeping the ref
3378
3351
  tensor = state_dict.pop(tensor_name)
3379
3352
 
3380
- # In case of TP, get the full parameter back
3381
- if _is_dtensor_available and isinstance(tensor, DTensor):
3382
- tensor = tensor.full_tensor()
3383
- # to get the correctly ordered tensor we need to repack if packed
3384
- if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
3385
- tensor = repack_weights(tensor, -1, self._tp_size, 2)
3386
-
3387
3353
  # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
3388
3354
  # but it would otherwise not be contained in the saved shard if we were to simply move the file
3389
3355
  # or something
@@ -3541,10 +3507,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3541
3507
  " desired `dtype` by passing the correct `dtype` argument."
3542
3508
  )
3543
3509
 
3544
- if getattr(self, "is_loaded_in_8bit", False):
3510
+ if getattr(self, "is_loaded_in_8bit", False) and not is_bitsandbytes_available("0.48"):
3545
3511
  raise ValueError(
3546
- "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
3547
- " model has already been set to the correct devices and casted to the correct `dtype`."
3512
+ "You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()."
3548
3513
  )
3549
3514
  elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
3550
3515
  if dtype_present_in_args:
@@ -3577,7 +3542,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3577
3542
  @classmethod
3578
3543
  def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
3579
3544
  # Need to instantiate with correct dtype
3580
- init_contexts = [local_torch_dtype(dtype, cls.__name__)]
3545
+ init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights()]
3581
3546
  if is_deepspeed_zero3_enabled():
3582
3547
  import deepspeed
3583
3548
 
@@ -3598,7 +3563,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3598
3563
 
3599
3564
  return init_contexts
3600
3565
 
3601
- def set_use_kernels(self, use_kernels, kernel_config):
3566
+ def _get_dtype_plan(self, dtype: torch.dtype) -> dict:
3567
+ """Create the dtype_plan describing modules/parameters that should use the `keep_in_fp32` flag."""
3568
+ dtype_plan = {}
3569
+
3570
+ # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
3571
+ # in case of force loading a model that should stay in bf16 in fp16
3572
+ # See https://github.com/huggingface/transformers/issues/20287 for details.
3573
+ if self._keep_in_fp32_modules is not None and dtype == torch.float16:
3574
+ dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
3575
+
3576
+ # The _keep_in_fp32_modules_strict was introduced to always force upcast to fp32, for both fp16 and bf16
3577
+ if self._keep_in_fp32_modules_strict is not None and dtype in (torch.float16, torch.bfloat16):
3578
+ dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
3579
+
3580
+ return dtype_plan
3581
+
3582
+ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None):
3583
+ """
3584
+ Set whether or not to use the `kernels` library to kernelize some layers of the model.
3585
+ Args:
3586
+ use_kernels (`bool`):
3587
+ Whether or not to use the `kernels` library to kernelize some layers of the model.
3588
+ kernel_config (`KernelConfig`, *optional*):
3589
+ The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
3590
+ """
3602
3591
  if use_kernels:
3603
3592
  if not is_kernels_available():
3604
3593
  raise ValueError(
@@ -3641,7 +3630,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3641
3630
  local_files_only: bool = False,
3642
3631
  token: str | bool | None = None,
3643
3632
  revision: str = "main",
3644
- use_safetensors: bool | None = True,
3633
+ use_safetensors: bool | None = None,
3645
3634
  weights_only: bool = True,
3646
3635
  **kwargs,
3647
3636
  ) -> SpecificPreTrainedModelType:
@@ -4040,6 +4029,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4040
4029
  use_kernels=use_kernels,
4041
4030
  )
4042
4031
 
4032
+ # Create the dtype_plan to potentially use the `keep_in_fp32` flags (this needs to be called on the already
4033
+ # instantiated model, as the flags can be modified by instances sometimes)
4034
+ dtype_plan = model._get_dtype_plan(dtype)
4035
+
4043
4036
  # Obtain the weight conversion mapping for this model if any are registered
4044
4037
  weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
4045
4038
 
@@ -4051,29 +4044,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4051
4044
  device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
4052
4045
 
4053
4046
  # Finalize model weight initialization
4054
- model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
4055
- model,
4056
- state_dict,
4057
- checkpoint_files,
4058
- pretrained_model_name_or_path,
4047
+ load_config = LoadStateDictConfig(
4048
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
4059
4049
  ignore_mismatched_sizes=ignore_mismatched_sizes,
4060
4050
  sharded_metadata=sharded_metadata,
4061
4051
  device_map=device_map,
4062
4052
  disk_offload_folder=offload_folder,
4063
4053
  offload_buffers=offload_buffers,
4064
4054
  dtype=dtype,
4055
+ dtype_plan=dtype_plan,
4065
4056
  hf_quantizer=hf_quantizer,
4066
4057
  device_mesh=device_mesh,
4067
4058
  weights_only=weights_only,
4068
4059
  weight_mapping=weight_conversions,
4060
+ use_safetensors=use_safetensors,
4061
+ download_kwargs=download_kwargs,
4069
4062
  )
4070
-
4063
+ loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
4064
+ loading_info = cls._finalize_model_loading(model, load_config, loading_info)
4071
4065
  model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
4072
4066
  model.set_use_kernels(use_kernels, kernel_config)
4073
4067
 
4074
4068
  # If it is a model with generation capabilities, attempt to load generation files (generation config,
4075
4069
  # custom generate function)
4076
- if model.can_generate() and hasattr(model, "adjust_generation_fn"):
4070
+ if model.can_generate() and hasattr(model, "adjust_generation_fn") and not gguf_file:
4077
4071
  model.adjust_generation_fn(
4078
4072
  generation_config,
4079
4073
  from_auto_class,
@@ -4086,7 +4080,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4086
4080
 
4087
4081
  # If the device_map has more than 1 device: dispatch model with hooks on all devices
4088
4082
  if device_map is not None and len(set(device_map.values())) > 1:
4089
- accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
4083
+ accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, disk_offload_index, offload_buffers)
4090
4084
 
4091
4085
  if hf_quantizer is not None:
4092
4086
  model.hf_quantizer = hf_quantizer
@@ -4095,44 +4089,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4095
4089
  ) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
4096
4090
 
4097
4091
  if _adapter_model_path is not None:
4098
- adapter_kwargs["key_mapping"] = key_mapping
4099
- model.load_adapter(
4092
+ if token is not None:
4093
+ adapter_kwargs["token"] = token
4094
+ loading_info = model.load_adapter(
4100
4095
  _adapter_model_path,
4101
4096
  adapter_name=adapter_name,
4102
- token=token,
4097
+ load_config=load_config,
4103
4098
  adapter_kwargs=adapter_kwargs,
4104
4099
  )
4105
4100
 
4106
4101
  if output_loading_info:
4107
- loading_info = {
4108
- "missing_keys": missing_keys,
4109
- "unexpected_keys": unexpected_keys,
4110
- "mismatched_keys": mismatched_keys,
4111
- "error_msgs": error_msgs,
4112
- }
4113
- return model, loading_info
4102
+ return model, loading_info.to_dict()
4114
4103
  return model
4115
4104
 
4116
- @classmethod
4105
+ @staticmethod
4117
4106
  def _load_pretrained_model(
4118
- cls,
4119
4107
  model: "PreTrainedModel",
4120
4108
  state_dict: dict | None,
4121
4109
  checkpoint_files: list[str] | None,
4122
- pretrained_model_name_or_path: str | None,
4123
- ignore_mismatched_sizes: bool = False,
4124
- sharded_metadata: dict | None = None,
4125
- device_map: dict | None = None,
4126
- disk_offload_folder: str | None = None,
4127
- offload_buffers: bool = False,
4128
- dtype: torch.dtype | None = None,
4129
- hf_quantizer: HfQuantizer | None = None,
4130
- device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
4131
- weights_only: bool = True,
4132
- weight_mapping: Sequence[WeightConverter | WeightRenaming] | None = None,
4133
- ):
4134
- is_quantized = hf_quantizer is not None
4135
- is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
4110
+ load_config: LoadStateDictConfig,
4111
+ ) -> tuple[LoadStateDictInfo, dict]:
4112
+ """Perform the actual loading of some checkpoints into a `model`, by reading them from disk and dispatching them accordingly."""
4113
+ is_quantized = load_config.is_quantized
4114
+ is_hqq_or_quark = is_quantized and load_config.hf_quantizer.quantization_config.quant_method in {
4136
4115
  QuantizationMethod.HQQ,
4137
4116
  QuantizationMethod.QUARK,
4138
4117
  }
@@ -4146,21 +4125,21 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4146
4125
  # This offload index if for params explicitly on the "disk" in the device_map
4147
4126
  disk_offload_index = None
4148
4127
  # Prepare parameters offloading if needed
4149
- if device_map is not None and "disk" in device_map.values():
4128
+ if load_config.device_map is not None and "disk" in load_config.device_map.values():
4150
4129
  disk_offload_index = accelerate_disk_offload(
4151
4130
  model,
4152
- disk_offload_folder,
4131
+ load_config.disk_offload_folder,
4153
4132
  checkpoint_files,
4154
- device_map,
4155
- sharded_metadata,
4156
- dtype,
4157
- weight_mapping,
4133
+ load_config.device_map,
4134
+ load_config.sharded_metadata,
4135
+ load_config.dtype,
4136
+ load_config.weight_mapping,
4158
4137
  )
4159
4138
 
4160
4139
  # Warmup cuda to load the weights much faster on devices
4161
- if device_map is not None and not is_hqq_or_quark:
4162
- expanded_device_map = expand_device_map(device_map, expected_keys)
4163
- caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
4140
+ if load_config.device_map is not None and not is_hqq_or_quark:
4141
+ expanded_device_map = expand_device_map(load_config.device_map, expected_keys)
4142
+ caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
4164
4143
 
4165
4144
  error_msgs = []
4166
4145
 
@@ -4168,24 +4147,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4168
4147
  if state_dict is None:
4169
4148
  merged_state_dict = {}
4170
4149
  for ckpt_file in checkpoint_files:
4171
- merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
4150
+ merged_state_dict.update(
4151
+ load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
4152
+ )
4172
4153
  state_dict = merged_state_dict
4173
- error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
4154
+ error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
4174
4155
  # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
4175
- unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set()
4156
+ loading_info = LoadStateDictInfo(
4157
+ missing_keys=missing_keys,
4158
+ error_msgs=error_msgs,
4159
+ unexpected_keys=set(),
4160
+ mismatched_keys=set(),
4161
+ conversion_errors={},
4162
+ )
4176
4163
  else:
4177
4164
  all_pointer = set()
4178
- # Checkpoints are safetensors
4179
- if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
4165
+ if state_dict is not None:
4166
+ merged_state_dict = state_dict
4167
+ elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
4180
4168
  merged_state_dict = {}
4181
4169
  for file in checkpoint_files:
4182
4170
  file_pointer = safe_open(file, framework="pt", device="cpu")
4183
4171
  all_pointer.add(file_pointer)
4184
4172
  for k in file_pointer.keys():
4185
4173
  merged_state_dict[k] = file_pointer.get_slice(k) # don't materialize yet
4186
- # User passed an explicit state_dict
4187
- elif state_dict is not None:
4188
- merged_state_dict = state_dict
4189
4174
  # Checkpoints are .bin
4190
4175
  elif checkpoint_files is not None:
4191
4176
  merged_state_dict = {}
@@ -4194,58 +4179,58 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4194
4179
  else:
4195
4180
  raise ValueError("Neither a state dict nor checkpoint files were found.")
4196
4181
 
4197
- missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
4198
- convert_and_load_state_dict_in_model(
4199
- model=model,
4200
- state_dict=merged_state_dict,
4201
- weight_mapping=weight_mapping,
4202
- tp_plan=model._tp_plan,
4203
- hf_quantizer=hf_quantizer,
4204
- dtype=dtype,
4205
- device_map=device_map,
4206
- dtype_plan=model.dtype_plan,
4207
- device_mesh=device_mesh,
4208
- disk_offload_index=disk_offload_index,
4209
- disk_offload_folder=disk_offload_folder,
4210
- offload_buffers=offload_buffers,
4211
- )
4182
+ loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
4183
+ model=model,
4184
+ state_dict=merged_state_dict,
4185
+ load_config=load_config,
4186
+ tp_plan=model._tp_plan,
4187
+ disk_offload_index=disk_offload_index,
4212
4188
  )
4213
4189
 
4214
4190
  # finally close all opened file pointers
4215
4191
  for k in all_pointer:
4216
4192
  k.__exit__(None, None, None)
4217
4193
 
4218
- # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
4219
- model.mark_tied_weights_as_initialized()
4220
-
4221
- # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
4222
- # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
4223
- missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
4224
- model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)
4225
-
4226
- # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
4227
- model._initialize_missing_keys(is_quantized)
4194
+ return loading_info, disk_offload_index
4228
4195
 
4229
- # Tie the weights
4230
- model.tie_weights(missing_keys=missing_keys, recompute_mapping=False)
4196
+ @staticmethod
4197
+ def _finalize_model_loading(
4198
+ model, load_config: LoadStateDictConfig, loading_info: LoadStateDictInfo
4199
+ ) -> LoadStateDictInfo:
4200
+ """Perform all post processing operations after having loaded some checkpoints into a model, such as moving
4201
+ missing keys from meta device to their expected device, reinitializing missing weights according to proper
4202
+ distributions, tying the weights and logging the loading report."""
4203
+ try:
4204
+ # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
4205
+ model.mark_tied_weights_as_initialized()
4206
+
4207
+ # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
4208
+ # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
4209
+ model._move_missing_keys_from_meta_to_device(
4210
+ loading_info.missing_and_mismatched(),
4211
+ load_config.device_map,
4212
+ load_config.device_mesh,
4213
+ load_config.hf_quantizer,
4214
+ )
4231
4215
 
4232
- # Adjust missing and unexpected keys
4233
- missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)
4216
+ # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
4217
+ model._initialize_missing_keys(load_config.is_quantized)
4234
4218
 
4235
- log_state_dict_report(
4236
- model=model,
4237
- pretrained_model_name_or_path=pretrained_model_name_or_path,
4238
- logger=logger,
4239
- error_msgs=error_msgs,
4240
- unexpected_keys=unexpected_keys,
4241
- missing_keys=missing_keys,
4242
- mismatched_keys=mismatched_keys,
4243
- mismatched_shapes=mismatched_keys,
4244
- conversion_errors=conversion_errors,
4245
- ignore_mismatched_sizes=ignore_mismatched_sizes,
4246
- )
4219
+ # Tie the weights
4220
+ model.tie_weights(missing_keys=loading_info.missing_keys, recompute_mapping=False)
4221
+
4222
+ # Adjust missing and unexpected keys
4223
+ model._adjust_missing_and_unexpected_keys(loading_info)
4224
+ finally:
4225
+ log_state_dict_report(
4226
+ model=model,
4227
+ pretrained_model_name_or_path=load_config.pretrained_model_name_or_path,
4228
+ ignore_mismatched_sizes=load_config.ignore_mismatched_sizes,
4229
+ loading_info=loading_info,
4230
+ logger=logger,
4231
+ )
4247
4232
 
4248
- return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
4233
+ return loading_info
4249
4234
 
4250
4235
  def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
4251
4236
  module_keys = {".".join(key.split(".")[:-1]) for key in names}
@@ -4314,15 +4299,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4314
4299
 
4315
4300
  # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
4316
4301
  # attention_mask or not. In this case, we should still show a warning because this is a rare case.
4302
+ # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
4303
+ sep_token_id = getattr(self.config, "sep_token_id", None)
4317
4304
  if (
4318
4305
  (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
4319
4306
  or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
4320
- or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
4307
+ or (sep_token_id is not None and sep_token_id == self.config.pad_token_id)
4321
4308
  ):
4322
4309
  warn_string += (
4323
4310
  f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
4324
4311
  f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
4325
- f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
4312
+ f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
4326
4313
  )
4327
4314
 
4328
4315
  logger.warning_once(warn_string)
@@ -4499,11 +4486,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4499
4486
  else:
4500
4487
  self.initialize_weights()
4501
4488
 
4502
- def _adjust_missing_and_unexpected_keys(
4503
- self, missing_keys: set[str], unexpected_keys: set[str]
4504
- ) -> tuple[set[str], set[str]]:
4489
+ def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) -> None:
4505
4490
  """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
4506
- raising unneeded warnings/errors.
4491
+ raising unneeded warnings/errors. This is performed in-place.
4507
4492
  """
4508
4493
  # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
4509
4494
  # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4521,13 +4506,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4521
4506
 
4522
4507
  # Clean-up missing keys
4523
4508
  if ignore_missing_regex is not None:
4524
- missing_keys = {key for key in missing_keys if ignore_missing_regex.search(key) is None}
4509
+ loading_info.missing_keys = {
4510
+ key for key in loading_info.missing_keys if ignore_missing_regex.search(key) is None
4511
+ }
4525
4512
 
4526
4513
  # Clean-up unexpected keys
4527
4514
  if ignore_unexpected_regex is not None:
4528
- unexpected_keys = {key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None}
4529
-
4530
- return missing_keys, unexpected_keys
4515
+ loading_info.unexpected_keys = {
4516
+ key for key in loading_info.unexpected_keys if ignore_unexpected_regex.search(key) is None
4517
+ }
4531
4518
 
4532
4519
  def mark_tied_weights_as_initialized(self):
4533
4520
  """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them
@@ -4709,7 +4696,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
4709
4696
  ) - torch_accelerator_module.memory_allocated(index)
4710
4697
  byte_count = int(max(0, byte_count - unused_memory))
4711
4698
  # We divide by 2 here as we allocate in fp16
4712
- _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
4699
+ _ = torch.empty(int(byte_count // 2), dtype=torch.float16, device=device, requires_grad=False)
4713
4700
 
4714
4701
 
4715
4702
  class AttentionInterface(GeneralInterface):
@@ -4732,6 +4719,20 @@ class AttentionInterface(GeneralInterface):
4732
4719
  "paged|eager": eager_paged_attention_forward,
4733
4720
  }
4734
4721
 
4722
+ def get_interface(self, attn_implementation: str, default: Callable) -> Callable:
4723
+ """Return the requested `attn_implementation`. Also strictly check its validity, and raise if invalid."""
4724
+ if attn_implementation is None:
4725
+ logger.warning_once(
4726
+ "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`. This "
4727
+ "is expected if you use an Attention Module as a standalone Module. If this is not the case, something went "
4728
+ "wrong with the dispatch of `config._attn_implementation`"
4729
+ )
4730
+ elif attn_implementation != "eager" and attn_implementation not in self:
4731
+ raise KeyError(
4732
+ f"`{attn_implementation}` is not a valid attention implementation registered in the `AttentionInterface`"
4733
+ )
4734
+ return super().get(attn_implementation, default)
4735
+
4735
4736
 
4736
4737
  # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
4737
4738
  ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()