transformers 5.0.0rc2__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
1
- # coding=utf-8
2
1
  # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
2
  # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
3
  #
@@ -25,13 +24,14 @@ import sys
25
24
  import warnings
26
25
  from abc import abstractmethod
27
26
  from collections import defaultdict
28
- from collections.abc import Callable, Iterator, Sequence
27
+ from collections.abc import Callable, Iterator
29
28
  from contextlib import contextmanager
29
+ from dataclasses import dataclass, field
30
30
  from enum import Enum
31
31
  from functools import partial, wraps
32
32
  from itertools import cycle
33
33
  from threading import Thread
34
- from typing import Optional, TypeVar, Union, get_type_hints
34
+ from typing import Optional, TypeVar, get_type_hints
35
35
  from zipfile import is_zipfile
36
36
 
37
37
  import torch
@@ -78,9 +78,8 @@ from .integrations.tensor_parallel import (
78
78
  ALL_PARALLEL_STYLES,
79
79
  _get_parameter_tp_plan,
80
80
  distribute_model,
81
+ gather_state_dict_for_save,
81
82
  initialize_tensor_parallelism,
82
- repack_weights,
83
- replace_state_dict_local_with_dtensor,
84
83
  shard_and_distribute_module,
85
84
  verify_tp_plan,
86
85
  )
@@ -107,25 +106,26 @@ from .utils import (
107
106
  copy_func,
108
107
  has_file,
109
108
  is_accelerate_available,
109
+ is_bitsandbytes_available,
110
+ is_env_variable_true,
110
111
  is_flash_attn_2_available,
111
112
  is_flash_attn_3_available,
112
113
  is_grouped_mm_available,
113
114
  is_kernels_available,
114
115
  is_torch_flex_attn_available,
115
- is_torch_greater_or_equal,
116
116
  is_torch_mlu_available,
117
117
  is_torch_npu_available,
118
118
  is_torch_xpu_available,
119
119
  logging,
120
120
  )
121
- from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
121
+ from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder, is_flash_attention_requested
122
122
  from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
123
123
  from .utils.import_utils import (
124
124
  is_huggingface_hub_greater_or_equal,
125
125
  is_sagemaker_mp_enabled,
126
126
  is_tracing,
127
127
  )
128
- from .utils.loading_report import log_state_dict_report
128
+ from .utils.loading_report import LoadStateDictInfo, log_state_dict_report
129
129
  from .utils.quantization_config import QuantizationMethod
130
130
 
131
131
 
@@ -135,9 +135,6 @@ if is_accelerate_available():
135
135
 
136
136
 
137
137
  _torch_distributed_available = torch.distributed.is_available()
138
- _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
139
- if _is_dtensor_available:
140
- from torch.distributed.tensor import DTensor
141
138
 
142
139
  if is_sagemaker_mp_enabled():
143
140
  import smdistributed.modelparallel.torch as smp
@@ -163,6 +160,33 @@ FLASH_ATTN_KERNEL_FALLBACK = {
163
160
  }
164
161
 
165
162
 
163
+ @dataclass(frozen=True)
164
+ class LoadStateDictConfig:
165
+ """
166
+ Config for loading weights. This allows bundling arguments that are just
167
+ passed around.
168
+ """
169
+
170
+ pretrained_model_name_or_path: str | None = None
171
+ download_kwargs: DownloadKwargs | None = field(default_factory=DownloadKwargs)
172
+ use_safetensors: bool | None = None
173
+ ignore_mismatched_sizes: bool = False
174
+ sharded_metadata: dict | None = None
175
+ device_map: dict | None = None
176
+ disk_offload_folder: str | None = None
177
+ offload_buffers: bool = False
178
+ dtype: torch.dtype | None = None
179
+ dtype_plan: dict = field(default_factory=dict)
180
+ hf_quantizer: HfQuantizer | None = None
181
+ device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
182
+ weights_only: bool = True
183
+ weight_mapping: list[WeightConverter | WeightRenaming] | None = None
184
+
185
+ @property
186
+ def is_quantized(self) -> bool:
187
+ return self.hf_quantizer is not None
188
+
189
+
166
190
  def is_local_dist_rank_0():
167
191
  return (
168
192
  torch.distributed.is_available()
@@ -224,8 +248,7 @@ def get_torch_context_manager_or_global_device():
224
248
  is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
225
249
  """
226
250
  device_in_context = torch.tensor([]).device
227
- # `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
228
- default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
251
+ default_device = torch.get_default_device()
229
252
  # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
230
253
  if device_in_context == default_device:
231
254
  if default_device != torch.device("cpu"):
@@ -253,25 +276,22 @@ str_to_torch_dtype = {
253
276
  "U8": torch.uint8,
254
277
  "I8": torch.int8,
255
278
  "I16": torch.int16,
279
+ "U16": torch.uint16,
256
280
  "F16": torch.float16,
257
281
  "BF16": torch.bfloat16,
258
282
  "I32": torch.int32,
283
+ "U32": torch.uint32,
259
284
  "F32": torch.float32,
260
285
  "F64": torch.float64,
261
286
  "I64": torch.int64,
287
+ "U64": torch.uint64,
262
288
  "F8_E4M3": torch.float8_e4m3fn,
263
289
  "F8_E5M2": torch.float8_e5m2,
264
290
  }
265
291
 
266
292
 
267
- if is_torch_greater_or_equal("2.3.0"):
268
- str_to_torch_dtype["U16"] = torch.uint16
269
- str_to_torch_dtype["U32"] = torch.uint32
270
- str_to_torch_dtype["U64"] = torch.uint64
271
-
272
-
273
293
  def load_state_dict(
274
- checkpoint_file: Union[str, os.PathLike], map_location: Union[str, torch.device] = "cpu", weights_only: bool = True
294
+ checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
275
295
  ) -> dict[str, torch.Tensor]:
276
296
  """
277
297
  Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
@@ -461,7 +481,7 @@ def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor
461
481
  setattr(parent, param_type, tensor)
462
482
 
463
483
 
464
- def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
484
+ def _add_variant(weights_name: str, variant: str | None = None) -> str:
465
485
  if variant is not None:
466
486
  path, name = weights_name.rsplit(".", 1)
467
487
  weights_name = f"{path}.{variant}.{name}"
@@ -469,19 +489,20 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
469
489
 
470
490
 
471
491
  def _get_resolved_checkpoint_files(
472
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
473
- variant: Optional[str],
474
- gguf_file: Optional[str],
475
- use_safetensors: Optional[bool],
476
- download_kwargs: DownloadKwargs,
477
- user_agent: dict,
492
+ pretrained_model_name_or_path: str | os.PathLike | None,
493
+ variant: str | None,
494
+ gguf_file: str | None,
495
+ use_safetensors: bool | None,
496
+ user_agent: dict | None,
478
497
  is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in
479
- transformers_explicit_filename: Optional[str] = None,
480
- ) -> tuple[Optional[list[str]], Optional[dict]]:
498
+ transformers_explicit_filename: str | None = None,
499
+ download_kwargs: DownloadKwargs | None = None,
500
+ ) -> tuple[list[str] | None, dict | None]:
481
501
  """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
482
502
  checkpoints are sharded.
483
503
  This function will download the data if necessary.
484
504
  """
505
+ download_kwargs = download_kwargs or DownloadKwargs()
485
506
  cache_dir = download_kwargs.get("cache_dir")
486
507
  force_download = download_kwargs.get("force_download", False)
487
508
  proxies = download_kwargs.get("proxies")
@@ -494,17 +515,19 @@ def _get_resolved_checkpoint_files(
494
515
  if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
495
516
  ".safetensors.index.json"
496
517
  ):
497
- raise ValueError(
498
- "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
499
- "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
500
- f"{transformers_explicit_filename}"
501
- )
518
+ if transformers_explicit_filename != "adapter_model.bin":
519
+ raise ValueError(
520
+ "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
521
+ "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
522
+ f"{transformers_explicit_filename}"
523
+ )
502
524
 
503
525
  is_sharded = False
504
526
 
505
527
  if pretrained_model_name_or_path is not None and gguf_file is None:
506
528
  pretrained_model_name_or_path = str(pretrained_model_name_or_path)
507
529
  is_local = os.path.isdir(pretrained_model_name_or_path)
530
+ # If the file is a local folder (but not in the HF_HOME cache, even if it's technically local)
508
531
  if is_local:
509
532
  if transformers_explicit_filename is not None:
510
533
  # If the filename is explicitly defined, load this by default.
@@ -563,25 +586,38 @@ def _get_resolved_checkpoint_files(
563
586
  else:
564
587
  filename = _add_variant(WEIGHTS_NAME, variant)
565
588
 
589
+ # Prepare set of kwargs for hub functions
590
+ has_file_kwargs = {
591
+ "revision": revision,
592
+ "proxies": proxies,
593
+ "token": token,
594
+ "cache_dir": cache_dir,
595
+ "local_files_only": local_files_only,
596
+ }
597
+ cached_file_kwargs = {
598
+ "force_download": force_download,
599
+ "user_agent": user_agent,
600
+ "subfolder": subfolder,
601
+ "_raise_exceptions_for_gated_repo": False,
602
+ "_raise_exceptions_for_missing_entries": False,
603
+ "_commit_hash": commit_hash,
604
+ **has_file_kwargs,
605
+ }
606
+ can_auto_convert = (
607
+ not is_offline_mode() # for obvious reasons
608
+ # If we are in a CI environment or in a pytest run, we prevent the conversion
609
+ and not is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
610
+ and not is_remote_code # converter bot does not work on remote code
611
+ and subfolder == "" # converter bot does not work on subfolders
612
+ )
613
+
566
614
  try:
567
615
  # Load from URL or cache if already cached
568
- cached_file_kwargs = {
569
- "cache_dir": cache_dir,
570
- "force_download": force_download,
571
- "proxies": proxies,
572
- "local_files_only": local_files_only,
573
- "token": token,
574
- "user_agent": user_agent,
575
- "revision": revision,
576
- "subfolder": subfolder,
577
- "_raise_exceptions_for_gated_repo": False,
578
- "_raise_exceptions_for_missing_entries": False,
579
- "_commit_hash": commit_hash,
580
- }
581
- resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
582
-
583
616
  # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
584
617
  # result when internet is up, the repo and revision exist, but the file does not.
618
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
619
+
620
+ # Try safetensors files first if not already found
585
621
  if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
586
622
  # Maybe the checkpoint is sharded, we try to grab the index name in this case.
587
623
  resolved_archive_file = cached_file(
@@ -592,7 +628,7 @@ def _get_resolved_checkpoint_files(
592
628
  if resolved_archive_file is not None:
593
629
  is_sharded = True
594
630
  elif use_safetensors:
595
- if revision == "main" and not is_offline_mode():
631
+ if revision == "main" and can_auto_convert:
596
632
  resolved_archive_file, revision, is_sharded = auto_conversion(
597
633
  pretrained_model_name_or_path, **cached_file_kwargs
598
634
  )
@@ -609,6 +645,8 @@ def _get_resolved_checkpoint_files(
609
645
  resolved_archive_file = cached_file(
610
646
  pretrained_model_name_or_path, filename, **cached_file_kwargs
611
647
  )
648
+
649
+ # Then try `.bin` files
612
650
  if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
613
651
  # Maybe the checkpoint is sharded, we try to grab the index name in this case.
614
652
  resolved_archive_file = cached_file(
@@ -618,67 +656,38 @@ def _get_resolved_checkpoint_files(
618
656
  )
619
657
  if resolved_archive_file is not None:
620
658
  is_sharded = True
621
- if not local_files_only and not is_offline_mode():
622
- if resolved_archive_file is not None:
623
- # In a CI environment (CircleCI / Github Actions workflow runs) or in a pytest run,
624
- # we set `DISABLE_SAFETENSORS_CONVERSION=true` to prevent the conversion.
625
- if (
626
- filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
627
- and os.getenv("DISABLE_SAFETENSORS_CONVERSION", None) != "true"
628
- ):
629
- # If the PyTorch file was found, check if there is a safetensors file on the repository
630
- # If there is no safetensors file on the repositories, start an auto conversion
631
- safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
632
- has_file_kwargs = {
633
- "revision": revision,
634
- "proxies": proxies,
635
- "token": token,
636
- "cache_dir": cache_dir,
637
- "local_files_only": local_files_only,
638
- }
639
- cached_file_kwargs = {
640
- "cache_dir": cache_dir,
641
- "force_download": force_download,
642
- "local_files_only": local_files_only,
643
- "user_agent": user_agent,
644
- "subfolder": subfolder,
645
- "_raise_exceptions_for_gated_repo": False,
646
- "_raise_exceptions_for_missing_entries": False,
647
- "_commit_hash": commit_hash,
648
- **has_file_kwargs,
649
- }
650
- if (
651
- not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
652
- and not is_remote_code
653
- ):
654
- Thread(
655
- target=auto_conversion,
656
- args=(pretrained_model_name_or_path,),
657
- kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
658
- name="Thread-auto_conversion",
659
- ).start()
659
+
660
+ # If we have a match, but it's `.bin` format, try to launch safetensors conversion for next time
661
+ if resolved_archive_file is not None:
662
+ safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
663
+ if (
664
+ filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
665
+ and not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
666
+ and can_auto_convert
667
+ ):
668
+ Thread(
669
+ target=auto_conversion,
670
+ args=(pretrained_model_name_or_path,),
671
+ kwargs={"ignore_errors_during_conversion": False, **cached_file_kwargs},
672
+ name="Thread-auto_conversion",
673
+ ).start()
674
+
675
+ # If no match, raise appropriare errors
676
+ else:
677
+ # Otherwise, no PyTorch file was found
678
+ if variant is not None and has_file(
679
+ pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
680
+ ):
681
+ raise OSError(
682
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
683
+ f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
684
+ f" {variant}. Use `variant=None` to load this model from those weights."
685
+ )
660
686
  else:
661
- # Otherwise, no PyTorch file was found
662
- has_file_kwargs = {
663
- "revision": revision,
664
- "proxies": proxies,
665
- "token": token,
666
- "cache_dir": cache_dir,
667
- "local_files_only": local_files_only,
668
- }
669
- if variant is not None and has_file(
670
- pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
671
- ):
672
- raise OSError(
673
- f"{pretrained_model_name_or_path} does not appear to have a file named"
674
- f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
675
- f" {variant}. Use `variant=None` to load this model from those weights."
676
- )
677
- else:
678
- raise OSError(
679
- f"{pretrained_model_name_or_path} does not appear to have a file named"
680
- f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
681
- )
687
+ raise OSError(
688
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
689
+ f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
690
+ )
682
691
 
683
692
  except OSError:
684
693
  # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
@@ -745,13 +754,13 @@ def _get_resolved_checkpoint_files(
745
754
 
746
755
 
747
756
  def _get_dtype(
748
- dtype: Optional[Union[str, torch.dtype, dict]],
749
- checkpoint_files: Optional[list[str]],
757
+ dtype: str | torch.dtype | dict | None,
758
+ checkpoint_files: list[str] | None,
750
759
  config: PreTrainedConfig,
751
- sharded_metadata: Optional[dict],
752
- state_dict: Optional[dict],
760
+ sharded_metadata: dict | None,
761
+ state_dict: dict | None,
753
762
  weights_only: bool,
754
- hf_quantizer: Optional[HfQuantizer] = None,
763
+ hf_quantizer: HfQuantizer | None = None,
755
764
  ) -> tuple[PreTrainedConfig, torch.dtype]:
756
765
  """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
757
766
  inferred dtype. We do the following:
@@ -760,7 +769,6 @@ def _get_dtype(
760
769
  2. Else, use the dtype provided as a dict or str
761
770
  """
762
771
  is_sharded = sharded_metadata is not None
763
- asked_dtype = dtype
764
772
 
765
773
  if dtype is not None:
766
774
  if isinstance(dtype, str):
@@ -807,6 +815,13 @@ def _get_dtype(
807
815
  if isinstance(dtype, dict):
808
816
  main_dtype = dtype.get("", torch.get_default_dtype())
809
817
  main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
818
+
819
+ logger.warning_once(
820
+ "Using different dtypes per module is deprecated and will be removed in future versions "
821
+ "Setting different dtypes per backbone model might cause device errors downstream, therefore "
822
+ f"setting the dtype={main_dtype} for all modules."
823
+ )
824
+
810
825
  else:
811
826
  main_dtype = dtype
812
827
 
@@ -814,17 +829,7 @@ def _get_dtype(
814
829
  config.dtype = main_dtype
815
830
  for sub_config_key in config.sub_configs:
816
831
  if (sub_config := getattr(config, sub_config_key)) is not None:
817
- # The dtype was "auto" -> try to read the subconfig dtype value if any
818
- if asked_dtype == "auto":
819
- sub_dtype = getattr(sub_config, "dtype", main_dtype)
820
- sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
821
- # The dtype was provided as a dict, try to see if we match the subconfig name
822
- elif isinstance(dtype, dict):
823
- sub_dtype = dtype.get(sub_config_key, main_dtype)
824
- sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
825
- else:
826
- sub_dtype = main_dtype
827
- sub_config.dtype = sub_dtype
832
+ sub_config.dtype = main_dtype
828
833
 
829
834
  return config, main_dtype
830
835
 
@@ -877,13 +882,8 @@ class ModuleUtilsMixin:
877
882
  return encoder_extended_attention_mask
878
883
 
879
884
  @staticmethod
880
- def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
881
- if device is not None:
882
- warnings.warn(
883
- "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
884
- )
885
- else:
886
- device = attention_mask.device
885
+ def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
886
+ device = attention_mask.device
887
887
  batch_size, seq_length = input_shape
888
888
  seq_ids = torch.arange(seq_length, device=device)
889
889
  causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -907,8 +907,7 @@ class ModuleUtilsMixin:
907
907
  self,
908
908
  attention_mask: Tensor,
909
909
  input_shape: tuple[int, ...],
910
- device: Optional[torch.device] = None,
911
- dtype: Optional[torch.dtype] = None,
910
+ dtype: torch.dtype | None = None,
912
911
  ) -> Tensor:
913
912
  """
914
913
  Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -925,12 +924,6 @@ class ModuleUtilsMixin:
925
924
  if dtype is None:
926
925
  dtype = self.dtype
927
926
 
928
- if not (attention_mask.dim() == 2 and self.config.is_decoder):
929
- # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
930
- if device is not None:
931
- warnings.warn(
932
- "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
933
- )
934
927
  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
935
928
  # ourselves in which case we just need to make it broadcastable to all heads.
936
929
  if attention_mask.dim() == 3:
@@ -939,9 +932,9 @@ class ModuleUtilsMixin:
939
932
  # Provided a padding mask of dimensions [batch_size, seq_length]
940
933
  # - if the model is a decoder, apply a causal mask in addition to the padding mask
941
934
  # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
942
- if self.config.is_decoder:
935
+ if getattr(self.config, "is_decoder", None):
943
936
  extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
944
- input_shape, attention_mask, device
937
+ input_shape, attention_mask
945
938
  )
946
939
  else:
947
940
  extended_attention_mask = attention_mask[:, None, None, :]
@@ -1112,83 +1105,67 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1112
1105
  - **can_record_outputs** (dict):
1113
1106
  """
1114
1107
 
1115
- config_class = None
1116
- base_model_prefix = ""
1117
- main_input_name = "input_ids"
1118
- model_tags = None
1119
-
1120
- _checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
1121
-
1108
+ # General model properties
1109
+ config_class: type[PreTrainedConfig] | None = None
1122
1110
  _auto_class = None
1123
- _no_split_modules = None
1124
- _skip_keys_device_placement = None
1125
-
1126
- _keep_in_fp32_modules = None
1127
- # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
1128
- # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
1129
- _keep_in_fp32_modules_strict = None
1130
-
1131
- dtype_plan: Optional[dict[str, torch.dtype]] = None
1132
-
1133
- # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
1134
- # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
1135
- _keys_to_ignore_on_load_missing = None
1136
- # a list of `re` patterns of `state_dict` keys that should be removed from the list of
1137
- # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
1138
- # warnings.
1139
- _keys_to_ignore_on_load_unexpected = None
1140
- # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
1141
- # trained, but which are either deterministic or tied variables)
1142
- _keys_to_ignore_on_save = None
1143
- # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
1144
- _tied_weights_keys = None
1145
-
1146
- supports_gradient_checkpointing = False
1147
- _is_stateful = False
1148
-
1149
- # Flash Attention support
1150
- _supports_flash_attn = False
1151
-
1152
- # SDPA support
1153
- _supports_sdpa = False
1154
-
1155
- # Flex Attention support
1156
- _supports_flex_attn = False
1157
-
1158
- _can_compile_fullgraph = False
1159
-
1160
- # A tensor parallel plan to be applied to the model when TP is enabled. For
1161
- # top-level models, this attribute is currently defined in respective model
1162
- # code. For base models, this attribute comes from
1163
- # `config.base_model_tp_plan` during `__init__`.
1164
- # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
1165
- # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
1166
- # for example.
1167
- _tp_plan = None
1168
-
1169
- # tensor parallel degree to which model is sharded to.
1170
- _tp_size = None
1171
-
1172
- # A pipeline parallel plan specifying the layers which may not be present
1173
- # on all ranks when PP is enabled. For top-level models, this attribute is
1174
- # currently defined in respective model code. For base models, this
1175
- # attribute comes from `config.base_model_pp_plan` during `post_init`.
1176
- #
1177
- # The variable names for the inputs and outputs of the specified layers can
1178
- # be indexed using the `PipelineParallel` enum as follows:
1179
- # - `_pp_plan["layers"][PipelineParallel.inputs]`
1180
- # - `_pp_plan["layers"][PipelineParallel.outputs]`
1181
- _pp_plan = None
1111
+ base_model_prefix: str = ""
1112
+ _is_stateful: bool = False
1113
+ model_tags: list[str] | None = None
1182
1114
 
1115
+ # Input-related properties
1116
+ main_input_name: str = "input_ids"
1117
+ # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
1118
+ # Possible values are: text, image, video, audio and time
1119
+ input_modalities: str | list[str] = "text"
1120
+
1121
+ # Device-map related properties
1122
+ _no_split_modules: set[str] | list[str] | None = None
1123
+ _skip_keys_device_placement: str | list[str] | None = None
1124
+
1125
+ # Specific dtype upcasting
1126
+ # `_keep_in_fp32_modules` will upcast to fp32 only if the requested dtype is fp16
1127
+ # `_keep_in_fp32_modules_strict` will upcast to fp32 independently if the requested dtype is fp16 or bf16
1128
+ _keep_in_fp32_modules: set[str] | list[str] | None = None
1129
+ _keep_in_fp32_modules_strict: set[str] | list[str] | None = None
1130
+
1131
+ # Loading-specific properties
1132
+ # A dictionary `{"target": "source"}` of checkpoint keys that are potentially tied to one another
1133
+ _tied_weights_keys: dict[str, str] = None
1134
+ # Used for BC support in VLMs, not meant to be used by new models
1135
+ _checkpoint_conversion_mapping: dict[str, str] = {}
1136
+ # A list of `re` patterns describing keys to ignore if they are missing from checkpoints to avoid warnings
1137
+ _keys_to_ignore_on_load_missing: list[str] | None = None
1138
+ # A list of `re` patterns describing keys to ignore if they are unexpected in the checkpoints to avoid warnings
1139
+ _keys_to_ignore_on_load_unexpected: list[str] | None = None
1140
+ # A list of keys to ignore when saving the model
1141
+ _keys_to_ignore_on_save: list[str] | None = None
1142
+
1143
+ # Attention interfaces support properties
1144
+ _supports_sdpa: bool = False
1145
+ _supports_flash_attn: bool = False
1146
+ _supports_flex_attn: bool = False
1147
+
1148
+ # Tensor-parallelism-related properties
1149
+ # A tensor parallel plan of the form `{"model.layer.mlp.param": "colwise"}` to be applied to the model when TP is enabled.
1150
+ # For top-level models, this attribute is currently defined in respective model code. For base models, this attribute comes
1151
+ # from `config.base_model_tp_plan` during `post_init`.
1152
+ _tp_plan: dict[str, str] = None
1153
+ # Tensor parallel degree to which model is sharded to
1154
+ _tp_size = None
1155
+ # A pipeline parallel plan specifying the layers which may not be present on all ranks when PP is enabled. For top-level
1156
+ # models, this attribute is currently defined in respective model code. For base models, it comes from
1157
+ # `config.base_model_pp_plan` during `post_init`.
1158
+ _pp_plan: dict[str, PipelineParallel] | None = None
1159
+
1160
+ # Advanced functionalities support
1161
+ supports_gradient_checkpointing: bool = False
1162
+ _can_compile_fullgraph: bool = False
1183
1163
  # This flag signal that the model can be used as an efficient backend in TGI and vLLM
1184
1164
  # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
1185
1165
  # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
1186
- _supports_attention_backend = False
1187
- _can_record_outputs = None
1188
-
1189
- # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
1190
- # Possible values are: text, image, video, audio and time
1191
- input_modalities: Union[str, list[str]] = "text" # most models are text
1166
+ _supports_attention_backend: bool = False
1167
+ # A mapping describing what outputs can be captured by `check_model_inputs` decorator during the forward pass
1168
+ _can_record_outputs: dict | None = None
1192
1169
 
1193
1170
  @property
1194
1171
  @torch._dynamo.allow_in_graph
@@ -1273,6 +1250,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1273
1250
  f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
1274
1251
  )
1275
1252
  self.config = config
1253
+ self.name_or_path = config.name_or_path
1276
1254
 
1277
1255
  # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
1278
1256
  # setting it recursively)
@@ -1298,38 +1276,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1298
1276
  loss_type = None
1299
1277
  self.loss_type = loss_type
1300
1278
 
1301
- self.name_or_path = config.name_or_path
1302
- self.warnings_issued = {}
1303
- # Overwrite the class attribute to make it an instance attribute, so models like
1304
- # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
1305
- # when a different component (e.g. language_model) is used.
1306
- self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
1307
- self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
1308
- self.dtype_plan = {}
1309
-
1310
- if isinstance(self._keep_in_fp32_modules, list):
1311
- self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
1312
- if isinstance(self._keep_in_fp32_modules_strict, list):
1313
- self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
1314
-
1315
- self._no_split_modules = self._no_split_modules or []
1316
1279
  _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
1317
1280
 
1318
1281
  def post_init(self):
1319
1282
  """
1320
1283
  A method executed at the end of each Transformer model initialization, to execute code that needs the model's
1321
1284
  modules properly initialized (such as weight initialization).
1285
+ It is also used to obtain all correct static properties (parallelism plans, tied_weights_keys, _keep_in_fp32_modules, etc)
1286
+ correctly in the case of composite models (that is, the top level model should know about those properties from its children).
1322
1287
  """
1323
1288
  # Attach the different parallel plans and tied weight keys to the top-most model, so that everything is
1324
1289
  # easily available
1325
1290
  self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
1326
- # Current submodel should register its tied weights
1327
- self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
1328
1291
  # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
1329
1292
  if self.base_model is self:
1330
1293
  self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
1331
1294
  self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
1332
1295
  self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
1296
+ # Current submodel should register its tied weights
1297
+ self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
1298
+ # Current submodel should register its `_keep_in_fp32_modules`
1299
+ self._keep_in_fp32_modules = set(self._keep_in_fp32_modules or [])
1300
+ self._keep_in_fp32_modules_strict = set(self._keep_in_fp32_modules_strict or [])
1301
+ # Current submodel must register its `_no_split_modules` as well
1302
+ self._no_split_modules = set(self._no_split_modules or [])
1303
+
1304
+ # Iterate over children only: as the final model is created, this is enough to gather the properties from all submodels.
1305
+ # This works because the way the `__init__` and `post_init` are called on all submodules is depth-first in the graph
1333
1306
  for name, module in self.named_children():
1334
1307
  # Parallel plans
1335
1308
  if plan := getattr(module, "_ep_plan", None):
@@ -1341,6 +1314,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1341
1314
  # Always attach the keys of the children (if the children's config says to NOT tie, then it's empty)
1342
1315
  if tied_keys := getattr(module, "all_tied_weights_keys", None):
1343
1316
  self.all_tied_weights_keys.update({f"{name}.{k}": f"{name}.{v}" for k, v in tied_keys.copy().items()})
1317
+ # Record keep_in_fp_32 modules from the children as well
1318
+ if keep_fp32 := getattr(module, "_keep_in_fp32_modules", None):
1319
+ self._keep_in_fp32_modules.update(keep_fp32)
1320
+ if keep_fp32_strict := getattr(module, "_keep_in_fp32_modules_strict", None):
1321
+ self._keep_in_fp32_modules_strict.update(keep_fp32_strict)
1322
+ # Record `_no_split_modules` from the children
1323
+ if no_split := getattr(module, "_no_split_modules", None):
1324
+ self._no_split_modules.update(no_split)
1344
1325
 
1345
1326
  # Maybe initialize the weights and tie the keys
1346
1327
  self.init_weights()
@@ -1417,7 +1398,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1417
1398
  # Remove the attribute now that is has been consumed, so it's no saved in the config.
1418
1399
  delattr(self.config, "gradient_checkpointing")
1419
1400
 
1420
- def add_model_tags(self, tags: Union[list[str], str]) -> None:
1401
+ def add_model_tags(self, tags: list[str] | str) -> None:
1421
1402
  r"""
1422
1403
  Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
1423
1404
  not overwrite existing tags in the model.
@@ -1784,7 +1765,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1784
1765
  return True
1785
1766
 
1786
1767
  def _check_and_adjust_attn_implementation(
1787
- self, attn_implementation: Optional[str], is_init_check: bool = False
1768
+ self, attn_implementation: str | None, is_init_check: bool = False
1788
1769
  ) -> str:
1789
1770
  """
1790
1771
  Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
@@ -1859,12 +1840,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1859
1840
  )
1860
1841
 
1861
1842
  # preload flash attention here to allow compile with fullgraph
1862
- if "flash" in applicable_attn_implementation:
1843
+ if is_flash_attention_requested(requested_attention_implementation=applicable_attn_implementation):
1863
1844
  lazy_import_flash_attention(applicable_attn_implementation)
1864
1845
 
1865
1846
  return applicable_attn_implementation
1866
1847
 
1867
- def _check_and_adjust_experts_implementation(self, experts_implementation: Optional[str]) -> str:
1848
+ def _check_and_adjust_experts_implementation(self, experts_implementation: str | None) -> str:
1868
1849
  """
1869
1850
  Check that the `experts_implementation` exists and is supported by the models.
1870
1851
 
@@ -1877,7 +1858,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1877
1858
  applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
1878
1859
  return applicable_experts_implementation
1879
1860
 
1880
- def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
1861
+ def get_correct_attn_implementation(self, requested_attention: str | None, is_init_check: bool = False) -> str:
1881
1862
  applicable_attention = "sdpa" if requested_attention is None else requested_attention
1882
1863
  if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
1883
1864
  message = (
@@ -1911,7 +1892,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1911
1892
 
1912
1893
  return applicable_attention
1913
1894
 
1914
- def get_correct_experts_implementation(self, requested_experts: Optional[str]) -> str:
1895
+ def get_correct_experts_implementation(self, requested_experts: str | None) -> str:
1915
1896
  applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
1916
1897
  if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
1917
1898
  message = (
@@ -1936,15 +1917,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1936
1917
  """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
1937
1918
  opening the file, but avoids maintaining yet another property flag.
1938
1919
  """
1939
- class_file = sys.modules[cls.__module__].__file__
1940
- with open(class_file, "r") as f:
1920
+ class_module = sys.modules[cls.__module__]
1921
+ # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
1922
+ if not hasattr(class_module, "__file__"):
1923
+ return False
1924
+ class_file = class_module.__file__
1925
+ with open(class_file, "r", encoding="utf-8") as f:
1941
1926
  code = f.read()
1942
1927
  # heuristic -> if we find those patterns, the model uses the correct interface
1943
1928
  if re.search(r"class \w+Attention\(nn.Module\)", code):
1944
- return (
1945
- "eager_attention_forward" in code
1946
- and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
1947
- )
1929
+ return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
1948
1930
  else:
1949
1931
  # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
1950
1932
  return True
@@ -1954,13 +1936,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1954
1936
  """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
1955
1937
  opening the file, but avoids maintaining yet another property flag.
1956
1938
  """
1957
- class_file = sys.modules[cls.__module__].__file__
1958
- with open(class_file, "r") as f:
1939
+ class_module = sys.modules[cls.__module__]
1940
+ # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
1941
+ if not hasattr(class_module, "__file__"):
1942
+ return False
1943
+ class_file = class_module.__file__
1944
+ with open(class_file, "r", encoding="utf-8") as f:
1959
1945
  code = f.read()
1960
1946
  # heuristic -> if we the use_experts_implementation decorator is used, then we can set it
1961
1947
  return "@use_experts_implementation" in code
1962
1948
 
1963
- def set_attn_implementation(self, attn_implementation: Union[str, dict]):
1949
+ def set_attn_implementation(self, attn_implementation: str | dict):
1964
1950
  """
1965
1951
  Set the requested `attn_implementation` for this model.
1966
1952
 
@@ -2059,7 +2045,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2059
2045
  if hasattr(subconfig, "_attn_was_changed"):
2060
2046
  del subconfig._attn_was_changed
2061
2047
 
2062
- def set_experts_implementation(self, experts_implementation: Union[str, dict]):
2048
+ def set_experts_implementation(self, experts_implementation: str | dict):
2063
2049
  """
2064
2050
  Set the requested `experts_implementation` for this model.
2065
2051
 
@@ -2162,7 +2148,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2162
2148
  if hasattr(self, "_require_grads_hook"):
2163
2149
  del self._require_grads_hook
2164
2150
 
2165
- def get_encoder(self, modality: Optional[str] = None):
2151
+ def get_encoder(self, modality: str | None = None):
2166
2152
  """
2167
2153
  Best-effort lookup of the *encoder* module. If provided with `modality` argument,
2168
2154
  it looks for a modality-specific encoder in multimodal models (e.g. "image_encoder")
@@ -2194,7 +2180,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2194
2180
  # If this is a base transformer model (no encoder/model attributes), return self
2195
2181
  return self
2196
2182
 
2197
- def set_encoder(self, encoder, modality: Optional[str] = None):
2183
+ def set_encoder(self, encoder, modality: str | None = None):
2198
2184
  """
2199
2185
  Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
2200
2186
  """
@@ -2421,7 +2407,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2421
2407
 
2422
2408
  tied_mapping = self._tied_weights_keys
2423
2409
  # If the config does not specify any tying, return empty dict
2424
- if not self.config.tie_word_embeddings:
2410
+ # NOTE: not all modules have `tie_word_embeddings` attr, for example vision-only
2411
+ # modules do not have any word embeddings!
2412
+ tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
2413
+ if not tie_word_embeddings:
2425
2414
  return {}
2426
2415
  # If None, return empty dict
2427
2416
  elif tied_mapping is None:
@@ -2467,7 +2456,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2467
2456
 
2468
2457
  return expanded_tied_weights
2469
2458
 
2470
- def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
2459
+ def tie_weights(self, missing_keys: set[str] | None = None, recompute_mapping: bool = True):
2471
2460
  """
2472
2461
  Tie the model weights. If `recompute_mapping=False` (default when called internally), it will rely on the
2473
2462
  `model.all_tied_weights_keys` attribute, containing the `{target: source}` mapping for the tied params.
@@ -2559,39 +2548,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2559
2548
  if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
2560
2549
  output_embeddings.out_features = input_embeddings.num_embeddings
2561
2550
 
2562
- def _get_no_split_modules(self, device_map: str):
2563
- """
2564
- Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
2565
- get the underlying `_no_split_modules`.
2566
-
2567
- Args:
2568
- device_map (`str`):
2569
- The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
2570
-
2571
- Returns:
2572
- `list[str]`: List of modules that should not be split
2573
- """
2574
- _no_split_modules = set()
2575
- modules_to_check = [self]
2576
- while len(modules_to_check) > 0:
2577
- module = modules_to_check.pop(-1)
2578
- # if the module does not appear in _no_split_modules, we also check the children
2579
- if module.__class__.__name__ not in _no_split_modules:
2580
- if isinstance(module, PreTrainedModel):
2581
- if module._no_split_modules is None:
2582
- raise ValueError(
2583
- f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
2584
- "class needs to implement the `_no_split_modules` attribute."
2585
- )
2586
- else:
2587
- _no_split_modules = _no_split_modules | set(module._no_split_modules)
2588
- modules_to_check += list(module.children())
2589
- return list(_no_split_modules)
2590
-
2591
2551
  def resize_token_embeddings(
2592
2552
  self,
2593
- new_num_tokens: Optional[int] = None,
2594
- pad_to_multiple_of: Optional[int] = None,
2553
+ new_num_tokens: int | None = None,
2554
+ pad_to_multiple_of: int | None = None,
2595
2555
  mean_resizing: bool = True,
2596
2556
  ) -> nn.Embedding:
2597
2557
  """
@@ -2671,10 +2631,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2671
2631
  new_num_tokens = new_embeddings.weight.shape[0]
2672
2632
 
2673
2633
  # if word embeddings are not tied, make sure that lm head is resized as well
2674
- if (
2675
- self.get_output_embeddings() is not None
2676
- and not self.config.get_text_config(decoder=True).tie_word_embeddings
2677
- ):
2634
+ if self.get_output_embeddings() is not None:
2678
2635
  old_lm_head = self.get_output_embeddings()
2679
2636
  if isinstance(old_lm_head, torch.nn.Embedding):
2680
2637
  new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
@@ -2692,8 +2649,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2692
2649
  def _get_resized_embeddings(
2693
2650
  self,
2694
2651
  old_embeddings: nn.Embedding,
2695
- new_num_tokens: Optional[int] = None,
2696
- pad_to_multiple_of: Optional[int] = None,
2652
+ new_num_tokens: int | None = None,
2653
+ pad_to_multiple_of: int | None = None,
2697
2654
  mean_resizing: bool = True,
2698
2655
  ) -> nn.Embedding:
2699
2656
  """
@@ -2850,7 +2807,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2850
2807
  def _get_resized_lm_head(
2851
2808
  self,
2852
2809
  old_lm_head: nn.Linear,
2853
- new_num_tokens: Optional[int] = None,
2810
+ new_num_tokens: int | None = None,
2854
2811
  transposed: bool = False,
2855
2812
  mean_resizing: bool = True,
2856
2813
  ) -> nn.Linear:
@@ -3047,7 +3004,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3047
3004
  f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
3048
3005
  )
3049
3006
 
3050
- def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
3007
+ def get_position_embeddings(self) -> nn.Embedding | tuple[nn.Embedding]:
3051
3008
  raise NotImplementedError(
3052
3009
  f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
3053
3010
  f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
@@ -3055,15 +3012,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3055
3012
 
3056
3013
  def init_weights(self):
3057
3014
  """
3058
- Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
3015
+ Initialize and tie the weights if needed. If using a custom `PreTrainedModel`, you need to implement any
3059
3016
  initialization logic in `_init_weights`.
3060
3017
  """
3061
3018
  # If we are initializing on meta device, there is no point in trying to run inits
3062
3019
  if get_torch_context_manager_or_global_device() != torch.device("meta"):
3063
3020
  # Initialize weights
3064
3021
  self.initialize_weights()
3065
- # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
3066
- self.tie_weights(recompute_mapping=False)
3022
+ # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
3023
+ self.tie_weights(recompute_mapping=False)
3067
3024
 
3068
3025
  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
3069
3026
  """
@@ -3080,7 +3037,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3080
3037
  raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
3081
3038
 
3082
3039
  if gradient_checkpointing_kwargs is None:
3083
- gradient_checkpointing_kwargs = {"use_reentrant": True}
3040
+ gradient_checkpointing_kwargs = {"use_reentrant": False}
3084
3041
 
3085
3042
  gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
3086
3043
 
@@ -3158,13 +3115,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3158
3115
 
3159
3116
  def save_pretrained(
3160
3117
  self,
3161
- save_directory: Union[str, os.PathLike],
3118
+ save_directory: str | os.PathLike,
3162
3119
  is_main_process: bool = True,
3163
- state_dict: Optional[dict] = None,
3120
+ state_dict: dict | None = None,
3164
3121
  push_to_hub: bool = False,
3165
- max_shard_size: Union[int, str] = "50GB",
3166
- variant: Optional[str] = None,
3167
- token: Optional[Union[str, bool]] = None,
3122
+ max_shard_size: int | str = "50GB",
3123
+ variant: str | None = None,
3124
+ token: str | bool | None = None,
3168
3125
  save_peft_format: bool = True,
3169
3126
  save_original_format: bool = True,
3170
3127
  **kwargs,
@@ -3231,12 +3188,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3231
3188
  " the logger on the traceback to understand the reason why the quantized model is not serializable."
3232
3189
  )
3233
3190
 
3234
- if "save_config" in kwargs:
3235
- warnings.warn(
3236
- "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
3237
- )
3238
- is_main_process = kwargs.pop("save_config")
3239
-
3240
3191
  # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
3241
3192
  if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
3242
3193
  raise ImportError(
@@ -3339,16 +3290,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3339
3290
  if ignore_key in state_dict:
3340
3291
  del state_dict[ignore_key]
3341
3292
 
3342
- # If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
3343
- # therefore we replace them with DTensors that are equivalently sharded
3293
+ # If model was sharded with TP, gather full tensors for saving
3344
3294
  if self._tp_size is not None:
3345
- state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
3295
+ state_dict = gather_state_dict_for_save(state_dict, self._tp_plan, self._device_mesh, self._tp_size)
3346
3296
 
3347
3297
  # Remove tied weights as safetensors do not handle them
3348
3298
  state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)
3349
3299
 
3350
3300
  # Revert all renaming and/or weight operations
3351
- if save_original_format:
3301
+ if save_original_format and not _hf_peft_config_loaded:
3352
3302
  state_dict = revert_weight_conversion(model_to_save, state_dict)
3353
3303
 
3354
3304
  # Shard the model if it is too big.
@@ -3400,13 +3350,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3400
3350
  # Get the tensor, and remove it from state_dict to avoid keeping the ref
3401
3351
  tensor = state_dict.pop(tensor_name)
3402
3352
 
3403
- # In case of TP, get the full parameter back
3404
- if _is_dtensor_available and isinstance(tensor, DTensor):
3405
- tensor = tensor.full_tensor()
3406
- # to get the correctly ordered tensor we need to repack if packed
3407
- if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
3408
- tensor = repack_weights(tensor, -1, self._tp_size, 2)
3409
-
3410
3353
  # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
3411
3354
  # but it would otherwise not be contained in the saved shard if we were to simply move the file
3412
3355
  # or something
@@ -3564,10 +3507,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3564
3507
  " desired `dtype` by passing the correct `dtype` argument."
3565
3508
  )
3566
3509
 
3567
- if getattr(self, "is_loaded_in_8bit", False):
3510
+ if getattr(self, "is_loaded_in_8bit", False) and not is_bitsandbytes_available("0.48"):
3568
3511
  raise ValueError(
3569
- "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
3570
- " model has already been set to the correct devices and casted to the correct `dtype`."
3512
+ "You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()."
3571
3513
  )
3572
3514
  elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
3573
3515
  if dtype_present_in_args:
@@ -3600,7 +3542,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3600
3542
  @classmethod
3601
3543
  def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
3602
3544
  # Need to instantiate with correct dtype
3603
- init_contexts = [local_torch_dtype(dtype, cls.__name__)]
3545
+ init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights()]
3604
3546
  if is_deepspeed_zero3_enabled():
3605
3547
  import deepspeed
3606
3548
 
@@ -3621,7 +3563,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3621
3563
 
3622
3564
  return init_contexts
3623
3565
 
3624
- def set_use_kernels(self, use_kernels, kernel_config):
3566
+ def _get_dtype_plan(self, dtype: torch.dtype) -> dict:
3567
+ """Create the dtype_plan describing modules/parameters that should use the `keep_in_fp32` flag."""
3568
+ dtype_plan = {}
3569
+
3570
+ # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
3571
+ # in case of force loading a model that should stay in bf16 in fp16
3572
+ # See https://github.com/huggingface/transformers/issues/20287 for details.
3573
+ if self._keep_in_fp32_modules is not None and dtype == torch.float16:
3574
+ dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
3575
+
3576
+ # The _keep_in_fp32_modules_strict was introduced to always force upcast to fp32, for both fp16 and bf16
3577
+ if self._keep_in_fp32_modules_strict is not None and dtype in (torch.float16, torch.bfloat16):
3578
+ dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
3579
+
3580
+ return dtype_plan
3581
+
3582
+ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None):
3583
+ """
3584
+ Set whether or not to use the `kernels` library to kernelize some layers of the model.
3585
+ Args:
3586
+ use_kernels (`bool`):
3587
+ Whether or not to use the `kernels` library to kernelize some layers of the model.
3588
+ kernel_config (`KernelConfig`, *optional*):
3589
+ The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
3590
+ """
3625
3591
  if use_kernels:
3626
3592
  if not is_kernels_available():
3627
3593
  raise ValueError(
@@ -3655,16 +3621,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3655
3621
  @classmethod
3656
3622
  def from_pretrained(
3657
3623
  cls: type[SpecificPreTrainedModelType],
3658
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
3624
+ pretrained_model_name_or_path: str | os.PathLike | None,
3659
3625
  *model_args,
3660
- config: Optional[Union[PreTrainedConfig, str, os.PathLike]] = None,
3661
- cache_dir: Optional[Union[str, os.PathLike]] = None,
3626
+ config: PreTrainedConfig | str | os.PathLike | None = None,
3627
+ cache_dir: str | os.PathLike | None = None,
3662
3628
  ignore_mismatched_sizes: bool = False,
3663
3629
  force_download: bool = False,
3664
3630
  local_files_only: bool = False,
3665
- token: Optional[Union[str, bool]] = None,
3631
+ token: str | bool | None = None,
3666
3632
  revision: str = "main",
3667
- use_safetensors: Optional[bool] = True,
3633
+ use_safetensors: bool | None = None,
3668
3634
  weights_only: bool = True,
3669
3635
  **kwargs,
3670
3636
  ) -> SpecificPreTrainedModelType:
@@ -4063,6 +4029,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4063
4029
  use_kernels=use_kernels,
4064
4030
  )
4065
4031
 
4032
+ # Create the dtype_plan to potentially use the `keep_in_fp32` flags (this needs to be called on the already
4033
+ # instantiated model, as the flags can be modified by instances sometimes)
4034
+ dtype_plan = model._get_dtype_plan(dtype)
4035
+
4066
4036
  # Obtain the weight conversion mapping for this model if any are registered
4067
4037
  weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
4068
4038
 
@@ -4074,29 +4044,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4074
4044
  device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
4075
4045
 
4076
4046
  # Finalize model weight initialization
4077
- model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
4078
- model,
4079
- state_dict,
4080
- checkpoint_files,
4081
- pretrained_model_name_or_path,
4047
+ load_config = LoadStateDictConfig(
4048
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
4082
4049
  ignore_mismatched_sizes=ignore_mismatched_sizes,
4083
4050
  sharded_metadata=sharded_metadata,
4084
4051
  device_map=device_map,
4085
4052
  disk_offload_folder=offload_folder,
4086
4053
  offload_buffers=offload_buffers,
4087
4054
  dtype=dtype,
4055
+ dtype_plan=dtype_plan,
4088
4056
  hf_quantizer=hf_quantizer,
4089
4057
  device_mesh=device_mesh,
4090
4058
  weights_only=weights_only,
4091
4059
  weight_mapping=weight_conversions,
4060
+ use_safetensors=use_safetensors,
4061
+ download_kwargs=download_kwargs,
4092
4062
  )
4093
-
4063
+ loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
4064
+ loading_info = cls._finalize_model_loading(model, load_config, loading_info)
4094
4065
  model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
4095
4066
  model.set_use_kernels(use_kernels, kernel_config)
4096
4067
 
4097
4068
  # If it is a model with generation capabilities, attempt to load generation files (generation config,
4098
4069
  # custom generate function)
4099
- if model.can_generate() and hasattr(model, "adjust_generation_fn"):
4070
+ if model.can_generate() and hasattr(model, "adjust_generation_fn") and not gguf_file:
4100
4071
  model.adjust_generation_fn(
4101
4072
  generation_config,
4102
4073
  from_auto_class,
@@ -4109,7 +4080,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4109
4080
 
4110
4081
  # If the device_map has more than 1 device: dispatch model with hooks on all devices
4111
4082
  if device_map is not None and len(set(device_map.values())) > 1:
4112
- accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
4083
+ accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, disk_offload_index, offload_buffers)
4113
4084
 
4114
4085
  if hf_quantizer is not None:
4115
4086
  model.hf_quantizer = hf_quantizer
@@ -4118,44 +4089,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4118
4089
  ) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
4119
4090
 
4120
4091
  if _adapter_model_path is not None:
4121
- adapter_kwargs["key_mapping"] = key_mapping
4122
- model.load_adapter(
4092
+ if token is not None:
4093
+ adapter_kwargs["token"] = token
4094
+ loading_info = model.load_adapter(
4123
4095
  _adapter_model_path,
4124
4096
  adapter_name=adapter_name,
4125
- token=token,
4097
+ load_config=load_config,
4126
4098
  adapter_kwargs=adapter_kwargs,
4127
4099
  )
4128
4100
 
4129
4101
  if output_loading_info:
4130
- loading_info = {
4131
- "missing_keys": missing_keys,
4132
- "unexpected_keys": unexpected_keys,
4133
- "mismatched_keys": mismatched_keys,
4134
- "error_msgs": error_msgs,
4135
- }
4136
- return model, loading_info
4102
+ return model, loading_info.to_dict()
4137
4103
  return model
4138
4104
 
4139
- @classmethod
4105
+ @staticmethod
4140
4106
  def _load_pretrained_model(
4141
- cls,
4142
4107
  model: "PreTrainedModel",
4143
- state_dict: Optional[dict],
4144
- checkpoint_files: Optional[list[str]],
4145
- pretrained_model_name_or_path: Optional[str],
4146
- ignore_mismatched_sizes: bool = False,
4147
- sharded_metadata: Optional[dict] = None,
4148
- device_map: Optional[dict] = None,
4149
- disk_offload_folder: Optional[str] = None,
4150
- offload_buffers: bool = False,
4151
- dtype: Optional[torch.dtype] = None,
4152
- hf_quantizer: Optional[HfQuantizer] = None,
4153
- device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
4154
- weights_only: bool = True,
4155
- weight_mapping: Optional[Sequence[WeightConverter | WeightRenaming]] = None,
4156
- ):
4157
- is_quantized = hf_quantizer is not None
4158
- is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
4108
+ state_dict: dict | None,
4109
+ checkpoint_files: list[str] | None,
4110
+ load_config: LoadStateDictConfig,
4111
+ ) -> tuple[LoadStateDictInfo, dict]:
4112
+ """Perform the actual loading of some checkpoints into a `model`, by reading them from disk and dispatching them accordingly."""
4113
+ is_quantized = load_config.is_quantized
4114
+ is_hqq_or_quark = is_quantized and load_config.hf_quantizer.quantization_config.quant_method in {
4159
4115
  QuantizationMethod.HQQ,
4160
4116
  QuantizationMethod.QUARK,
4161
4117
  }
@@ -4169,21 +4125,21 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4169
4125
  # This offload index if for params explicitly on the "disk" in the device_map
4170
4126
  disk_offload_index = None
4171
4127
  # Prepare parameters offloading if needed
4172
- if device_map is not None and "disk" in device_map.values():
4128
+ if load_config.device_map is not None and "disk" in load_config.device_map.values():
4173
4129
  disk_offload_index = accelerate_disk_offload(
4174
4130
  model,
4175
- disk_offload_folder,
4131
+ load_config.disk_offload_folder,
4176
4132
  checkpoint_files,
4177
- device_map,
4178
- sharded_metadata,
4179
- dtype,
4180
- weight_mapping,
4133
+ load_config.device_map,
4134
+ load_config.sharded_metadata,
4135
+ load_config.dtype,
4136
+ load_config.weight_mapping,
4181
4137
  )
4182
4138
 
4183
4139
  # Warmup cuda to load the weights much faster on devices
4184
- if device_map is not None and not is_hqq_or_quark:
4185
- expanded_device_map = expand_device_map(device_map, expected_keys)
4186
- caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
4140
+ if load_config.device_map is not None and not is_hqq_or_quark:
4141
+ expanded_device_map = expand_device_map(load_config.device_map, expected_keys)
4142
+ caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
4187
4143
 
4188
4144
  error_msgs = []
4189
4145
 
@@ -4191,24 +4147,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4191
4147
  if state_dict is None:
4192
4148
  merged_state_dict = {}
4193
4149
  for ckpt_file in checkpoint_files:
4194
- merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
4150
+ merged_state_dict.update(
4151
+ load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
4152
+ )
4195
4153
  state_dict = merged_state_dict
4196
- error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
4154
+ error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
4197
4155
  # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
4198
- unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set()
4156
+ loading_info = LoadStateDictInfo(
4157
+ missing_keys=missing_keys,
4158
+ error_msgs=error_msgs,
4159
+ unexpected_keys=set(),
4160
+ mismatched_keys=set(),
4161
+ conversion_errors={},
4162
+ )
4199
4163
  else:
4200
4164
  all_pointer = set()
4201
- # Checkpoints are safetensors
4202
- if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
4165
+ if state_dict is not None:
4166
+ merged_state_dict = state_dict
4167
+ elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
4203
4168
  merged_state_dict = {}
4204
4169
  for file in checkpoint_files:
4205
4170
  file_pointer = safe_open(file, framework="pt", device="cpu")
4206
4171
  all_pointer.add(file_pointer)
4207
4172
  for k in file_pointer.keys():
4208
4173
  merged_state_dict[k] = file_pointer.get_slice(k) # don't materialize yet
4209
- # User passed an explicit state_dict
4210
- elif state_dict is not None:
4211
- merged_state_dict = state_dict
4212
4174
  # Checkpoints are .bin
4213
4175
  elif checkpoint_files is not None:
4214
4176
  merged_state_dict = {}
@@ -4217,58 +4179,58 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4217
4179
  else:
4218
4180
  raise ValueError("Neither a state dict nor checkpoint files were found.")
4219
4181
 
4220
- missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
4221
- convert_and_load_state_dict_in_model(
4222
- model=model,
4223
- state_dict=merged_state_dict,
4224
- weight_mapping=weight_mapping,
4225
- tp_plan=model._tp_plan,
4226
- hf_quantizer=hf_quantizer,
4227
- dtype=dtype,
4228
- device_map=device_map,
4229
- dtype_plan=model.dtype_plan,
4230
- device_mesh=device_mesh,
4231
- disk_offload_index=disk_offload_index,
4232
- disk_offload_folder=disk_offload_folder,
4233
- offload_buffers=offload_buffers,
4234
- )
4182
+ loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
4183
+ model=model,
4184
+ state_dict=merged_state_dict,
4185
+ load_config=load_config,
4186
+ tp_plan=model._tp_plan,
4187
+ disk_offload_index=disk_offload_index,
4235
4188
  )
4236
4189
 
4237
4190
  # finally close all opened file pointers
4238
4191
  for k in all_pointer:
4239
4192
  k.__exit__(None, None, None)
4240
4193
 
4241
- # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
4242
- model.mark_tied_weights_as_initialized()
4243
-
4244
- # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
4245
- # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
4246
- missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
4247
- model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)
4194
+ return loading_info, disk_offload_index
4248
4195
 
4249
- # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
4250
- model._initialize_missing_keys(is_quantized)
4251
-
4252
- # Tie the weights
4253
- model.tie_weights(missing_keys=missing_keys, recompute_mapping=False)
4196
+ @staticmethod
4197
+ def _finalize_model_loading(
4198
+ model, load_config: LoadStateDictConfig, loading_info: LoadStateDictInfo
4199
+ ) -> LoadStateDictInfo:
4200
+ """Perform all post processing operations after having loaded some checkpoints into a model, such as moving
4201
+ missing keys from meta device to their expected device, reinitializing missing weights according to proper
4202
+ distributions, tying the weights and logging the loading report."""
4203
+ try:
4204
+ # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
4205
+ model.mark_tied_weights_as_initialized()
4206
+
4207
+ # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
4208
+ # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
4209
+ model._move_missing_keys_from_meta_to_device(
4210
+ loading_info.missing_and_mismatched(),
4211
+ load_config.device_map,
4212
+ load_config.device_mesh,
4213
+ load_config.hf_quantizer,
4214
+ )
4254
4215
 
4255
- # Adjust missing and unexpected keys
4256
- missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)
4216
+ # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
4217
+ model._initialize_missing_keys(load_config.is_quantized)
4257
4218
 
4258
- log_state_dict_report(
4259
- model=model,
4260
- pretrained_model_name_or_path=pretrained_model_name_or_path,
4261
- logger=logger,
4262
- error_msgs=error_msgs,
4263
- unexpected_keys=unexpected_keys,
4264
- missing_keys=missing_keys,
4265
- mismatched_keys=mismatched_keys,
4266
- mismatched_shapes=mismatched_keys,
4267
- conversion_errors=conversion_errors,
4268
- ignore_mismatched_sizes=ignore_mismatched_sizes,
4269
- )
4219
+ # Tie the weights
4220
+ model.tie_weights(missing_keys=loading_info.missing_keys, recompute_mapping=False)
4221
+
4222
+ # Adjust missing and unexpected keys
4223
+ model._adjust_missing_and_unexpected_keys(loading_info)
4224
+ finally:
4225
+ log_state_dict_report(
4226
+ model=model,
4227
+ pretrained_model_name_or_path=load_config.pretrained_model_name_or_path,
4228
+ ignore_mismatched_sizes=load_config.ignore_mismatched_sizes,
4229
+ loading_info=loading_info,
4230
+ logger=logger,
4231
+ )
4270
4232
 
4271
- return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
4233
+ return loading_info
4272
4234
 
4273
4235
  def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
4274
4236
  module_keys = {".".join(key.split(".")[:-1]) for key in names}
@@ -4337,15 +4299,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4337
4299
 
4338
4300
  # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
4339
4301
  # attention_mask or not. In this case, we should still show a warning because this is a rare case.
4302
+ # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
4303
+ sep_token_id = getattr(self.config, "sep_token_id", None)
4340
4304
  if (
4341
4305
  (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
4342
4306
  or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
4343
- or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
4307
+ or (sep_token_id is not None and sep_token_id == self.config.pad_token_id)
4344
4308
  ):
4345
4309
  warn_string += (
4346
4310
  f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
4347
4311
  f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
4348
- f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
4312
+ f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
4349
4313
  )
4350
4314
 
4351
4315
  logger.warning_once(warn_string)
@@ -4430,7 +4394,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4430
4394
  )
4431
4395
  self._use_kernels = False
4432
4396
 
4433
- def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
4397
+ def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable:
4434
4398
  """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
4435
4399
  non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
4436
4400
  want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -4522,11 +4486,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4522
4486
  else:
4523
4487
  self.initialize_weights()
4524
4488
 
4525
- def _adjust_missing_and_unexpected_keys(
4526
- self, missing_keys: set[str], unexpected_keys: set[str]
4527
- ) -> tuple[set[str], set[str]]:
4489
+ def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) -> None:
4528
4490
  """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
4529
- raising unneeded warnings/errors.
4491
+ raising unneeded warnings/errors. This is performed in-place.
4530
4492
  """
4531
4493
  # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
4532
4494
  # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4544,13 +4506,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4544
4506
 
4545
4507
  # Clean-up missing keys
4546
4508
  if ignore_missing_regex is not None:
4547
- missing_keys = {key for key in missing_keys if ignore_missing_regex.search(key) is None}
4509
+ loading_info.missing_keys = {
4510
+ key for key in loading_info.missing_keys if ignore_missing_regex.search(key) is None
4511
+ }
4548
4512
 
4549
4513
  # Clean-up unexpected keys
4550
4514
  if ignore_unexpected_regex is not None:
4551
- unexpected_keys = {key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None}
4552
-
4553
- return missing_keys, unexpected_keys
4515
+ loading_info.unexpected_keys = {
4516
+ key for key in loading_info.unexpected_keys if ignore_unexpected_regex.search(key) is None
4517
+ }
4554
4518
 
4555
4519
  def mark_tied_weights_as_initialized(self):
4556
4520
  """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them
@@ -4640,7 +4604,7 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
4640
4604
  return model
4641
4605
 
4642
4606
 
4643
- def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
4607
+ def is_accelerator_device(device: str | int | torch.device) -> bool:
4644
4608
  """Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
4645
4609
  a proper `torch.device`.
4646
4610
  """
@@ -4651,7 +4615,7 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
4651
4615
 
4652
4616
 
4653
4617
  def get_total_byte_count(
4654
- model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: Optional[HfQuantizer] = None
4618
+ model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: HfQuantizer | None = None
4655
4619
  ):
4656
4620
  """
4657
4621
  This utility function calculates the total bytes count needed to load the model on each device.
@@ -4684,7 +4648,7 @@ def get_total_byte_count(
4684
4648
  return total_byte_count
4685
4649
 
4686
4650
 
4687
- def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
4651
+ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: HfQuantizer | None):
4688
4652
  """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
4689
4653
  device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
4690
4654
  the model, which is actually the loading speed bottleneck.
@@ -4732,7 +4696,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
4732
4696
  ) - torch_accelerator_module.memory_allocated(index)
4733
4697
  byte_count = int(max(0, byte_count - unused_memory))
4734
4698
  # We divide by 2 here as we allocate in fp16
4735
- _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
4699
+ _ = torch.empty(int(byte_count // 2), dtype=torch.float16, device=device, requires_grad=False)
4736
4700
 
4737
4701
 
4738
4702
  class AttentionInterface(GeneralInterface):
@@ -4755,6 +4719,20 @@ class AttentionInterface(GeneralInterface):
4755
4719
  "paged|eager": eager_paged_attention_forward,
4756
4720
  }
4757
4721
 
4722
+ def get_interface(self, attn_implementation: str, default: Callable) -> Callable:
4723
+ """Return the requested `attn_implementation`. Also strictly check its validity, and raise if invalid."""
4724
+ if attn_implementation is None:
4725
+ logger.warning_once(
4726
+ "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`. This "
4727
+ "is expected if you use an Attention Module as a standalone Module. If this is not the case, something went "
4728
+ "wrong with the dispatch of `config._attn_implementation`"
4729
+ )
4730
+ elif attn_implementation != "eager" and attn_implementation not in self:
4731
+ raise KeyError(
4732
+ f"`{attn_implementation}` is not a valid attention implementation registered in the `AttentionInterface`"
4733
+ )
4734
+ return super().get(attn_implementation, default)
4735
+
4758
4736
 
4759
4737
  # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
4760
4738
  ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()