transformers 5.0.0__py3-none-any.whl → 5.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
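
As a rough, hedged illustration of how a file-level comparison like the listing below can be reproduced locally, here is a minimal Python sketch. It assumes both wheel archives have already been downloaded next to the script (for example via `pip download transformers==5.0.0 --no-deps` and `pip download --pre transformers==5.0.0rc0 --no-deps`); the local filenames are placeholder assumptions, and the sketch only diffs the member file lists rather than the per-file line counts shown below.

    # Minimal sketch, not the registry's own tooling: compare the file listings
    # of two locally downloaded wheels (the paths below are placeholder assumptions).
    import zipfile

    BASE_WHEEL = "transformers-5.0.0-py3-none-any.whl"
    TARGET_WHEEL = "transformers-5.0.0rc0-py3-none-any.whl"

    def members(path: str) -> set[str]:
        # A wheel is a zip archive; namelist() returns every packaged file.
        with zipfile.ZipFile(path) as wheel:
            return set(wheel.namelist())

    base_files, target_files = members(BASE_WHEEL), members(TARGET_WHEEL)
    print("only in 5.0.0rc0:", sorted(target_files - base_files)[:10])
    print("only in 5.0.0:   ", sorted(base_files - target_files)[:10])
    print("present in both: ", len(base_files & target_files))
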
Files changed (1606)
  1. transformers/__init__.py +36 -55
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +33 -32
  4. transformers/cache_utils.py +139 -32
  5. transformers/cli/chat.py +3 -3
  6. transformers/cli/serve.py +19 -49
  7. transformers/cli/transformers.py +1 -2
  8. transformers/configuration_utils.py +155 -129
  9. transformers/conversion_mapping.py +22 -158
  10. transformers/convert_slow_tokenizer.py +17 -227
  11. transformers/core_model_loading.py +185 -528
  12. transformers/data/data_collator.py +4 -12
  13. transformers/data/processors/glue.py +1 -0
  14. transformers/data/processors/utils.py +1 -0
  15. transformers/data/processors/xnli.py +1 -0
  16. transformers/dependency_versions_check.py +1 -0
  17. transformers/dependency_versions_table.py +7 -5
  18. transformers/distributed/configuration_utils.py +2 -1
  19. transformers/dynamic_module_utils.py +25 -24
  20. transformers/feature_extraction_sequence_utils.py +23 -19
  21. transformers/feature_extraction_utils.py +33 -64
  22. transformers/file_utils.py +1 -0
  23. transformers/generation/__init__.py +1 -11
  24. transformers/generation/candidate_generator.py +33 -80
  25. transformers/generation/configuration_utils.py +133 -189
  26. transformers/generation/continuous_batching/__init__.py +1 -4
  27. transformers/generation/continuous_batching/cache.py +25 -83
  28. transformers/generation/continuous_batching/cache_manager.py +45 -155
  29. transformers/generation/continuous_batching/continuous_api.py +147 -270
  30. transformers/generation/continuous_batching/requests.py +3 -51
  31. transformers/generation/continuous_batching/scheduler.py +105 -160
  32. transformers/generation/logits_process.py +128 -0
  33. transformers/generation/stopping_criteria.py +1 -1
  34. transformers/generation/streamers.py +1 -0
  35. transformers/generation/utils.py +123 -122
  36. transformers/generation/watermarking.py +6 -8
  37. transformers/hf_argparser.py +13 -9
  38. transformers/hyperparameter_search.py +2 -1
  39. transformers/image_processing_base.py +23 -12
  40. transformers/image_processing_utils.py +15 -11
  41. transformers/image_processing_utils_fast.py +75 -85
  42. transformers/image_transforms.py +42 -73
  43. transformers/image_utils.py +32 -30
  44. transformers/initialization.py +0 -37
  45. transformers/integrations/__init__.py +2 -16
  46. transformers/integrations/accelerate.py +113 -58
  47. transformers/integrations/aqlm.py +66 -36
  48. transformers/integrations/awq.py +516 -45
  49. transformers/integrations/bitnet.py +105 -47
  50. transformers/integrations/bitsandbytes.py +202 -91
  51. transformers/integrations/deepspeed.py +4 -161
  52. transformers/integrations/eetq.py +82 -84
  53. transformers/integrations/executorch.py +1 -1
  54. transformers/integrations/fbgemm_fp8.py +145 -190
  55. transformers/integrations/finegrained_fp8.py +215 -249
  56. transformers/integrations/flash_attention.py +3 -3
  57. transformers/integrations/flex_attention.py +1 -1
  58. transformers/integrations/fp_quant.py +0 -90
  59. transformers/integrations/ggml.py +2 -11
  60. transformers/integrations/higgs.py +62 -37
  61. transformers/integrations/hub_kernels.py +8 -65
  62. transformers/integrations/integration_utils.py +3 -47
  63. transformers/integrations/mistral.py +0 -12
  64. transformers/integrations/mxfp4.py +80 -33
  65. transformers/integrations/peft.py +191 -483
  66. transformers/integrations/quanto.py +56 -77
  67. transformers/integrations/spqr.py +90 -42
  68. transformers/integrations/tensor_parallel.py +221 -167
  69. transformers/integrations/torchao.py +43 -35
  70. transformers/integrations/vptq.py +59 -40
  71. transformers/kernels/__init__.py +0 -0
  72. transformers/{models/pe_audio_video/processing_pe_audio_video.py → kernels/falcon_mamba/__init__.py} +3 -12
  73. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +529 -0
  74. transformers/loss/loss_utils.py +0 -2
  75. transformers/masking_utils.py +55 -51
  76. transformers/model_debugging_utils.py +5 -4
  77. transformers/modelcard.py +194 -15
  78. transformers/modeling_attn_mask_utils.py +19 -19
  79. transformers/modeling_flash_attention_utils.py +27 -27
  80. transformers/modeling_gguf_pytorch_utils.py +24 -79
  81. transformers/modeling_layers.py +22 -21
  82. transformers/modeling_outputs.py +253 -242
  83. transformers/modeling_rope_utils.py +117 -138
  84. transformers/modeling_utils.py +739 -850
  85. transformers/models/__init__.py +0 -27
  86. transformers/models/afmoe/configuration_afmoe.py +33 -40
  87. transformers/models/afmoe/modeling_afmoe.py +54 -42
  88. transformers/models/afmoe/modular_afmoe.py +33 -23
  89. transformers/models/aimv2/configuration_aimv2.py +10 -2
  90. transformers/models/aimv2/modeling_aimv2.py +42 -47
  91. transformers/models/aimv2/modular_aimv2.py +19 -17
  92. transformers/models/albert/configuration_albert.py +2 -8
  93. transformers/models/albert/modeling_albert.py +69 -70
  94. transformers/models/albert/tokenization_albert.py +14 -5
  95. transformers/models/align/configuration_align.py +6 -8
  96. transformers/models/align/modeling_align.py +89 -94
  97. transformers/models/align/processing_align.py +30 -2
  98. transformers/models/altclip/configuration_altclip.py +7 -4
  99. transformers/models/altclip/modeling_altclip.py +103 -114
  100. transformers/models/altclip/processing_altclip.py +15 -2
  101. transformers/models/apertus/__init__.py +1 -0
  102. transformers/models/apertus/configuration_apertus.py +28 -23
  103. transformers/models/apertus/modeling_apertus.py +40 -39
  104. transformers/models/apertus/modular_apertus.py +38 -37
  105. transformers/models/arcee/configuration_arcee.py +30 -25
  106. transformers/models/arcee/modeling_arcee.py +39 -36
  107. transformers/models/arcee/modular_arcee.py +23 -20
  108. transformers/models/aria/configuration_aria.py +44 -31
  109. transformers/models/aria/image_processing_aria.py +27 -25
  110. transformers/models/aria/modeling_aria.py +106 -110
  111. transformers/models/aria/modular_aria.py +127 -118
  112. transformers/models/aria/processing_aria.py +35 -28
  113. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +1 -0
  114. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +6 -3
  115. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +8 -6
  116. transformers/models/audioflamingo3/__init__.py +1 -0
  117. transformers/models/audioflamingo3/configuration_audioflamingo3.py +1 -0
  118. transformers/models/audioflamingo3/modeling_audioflamingo3.py +49 -58
  119. transformers/models/audioflamingo3/modular_audioflamingo3.py +43 -53
  120. transformers/models/audioflamingo3/processing_audioflamingo3.py +30 -33
  121. transformers/models/auto/auto_factory.py +7 -6
  122. transformers/models/auto/configuration_auto.py +5 -66
  123. transformers/models/auto/feature_extraction_auto.py +10 -14
  124. transformers/models/auto/image_processing_auto.py +41 -32
  125. transformers/models/auto/modeling_auto.py +188 -46
  126. transformers/models/auto/processing_auto.py +11 -24
  127. transformers/models/auto/tokenization_auto.py +588 -171
  128. transformers/models/auto/video_processing_auto.py +10 -12
  129. transformers/models/autoformer/configuration_autoformer.py +7 -4
  130. transformers/models/autoformer/modeling_autoformer.py +101 -104
  131. transformers/models/aya_vision/configuration_aya_vision.py +1 -4
  132. transformers/models/aya_vision/modeling_aya_vision.py +102 -71
  133. transformers/models/aya_vision/modular_aya_vision.py +74 -46
  134. transformers/models/aya_vision/processing_aya_vision.py +53 -25
  135. transformers/models/bamba/configuration_bamba.py +39 -34
  136. transformers/models/bamba/modeling_bamba.py +86 -82
  137. transformers/models/bamba/modular_bamba.py +72 -70
  138. transformers/models/bark/configuration_bark.py +8 -6
  139. transformers/models/bark/generation_configuration_bark.py +5 -3
  140. transformers/models/bark/modeling_bark.py +57 -54
  141. transformers/models/bark/processing_bark.py +41 -19
  142. transformers/models/bart/configuration_bart.py +6 -9
  143. transformers/models/bart/modeling_bart.py +126 -135
  144. transformers/models/barthez/tokenization_barthez.py +11 -3
  145. transformers/models/bartpho/tokenization_bartpho.py +7 -6
  146. transformers/models/beit/configuration_beit.py +11 -0
  147. transformers/models/beit/image_processing_beit.py +56 -53
  148. transformers/models/beit/image_processing_beit_fast.py +12 -10
  149. transformers/models/beit/modeling_beit.py +60 -69
  150. transformers/models/bert/configuration_bert.py +2 -12
  151. transformers/models/bert/modeling_bert.py +122 -114
  152. transformers/models/bert/tokenization_bert.py +23 -8
  153. transformers/models/bert/tokenization_bert_legacy.py +5 -3
  154. transformers/models/bert_generation/configuration_bert_generation.py +2 -17
  155. transformers/models/bert_generation/modeling_bert_generation.py +49 -49
  156. transformers/models/bert_generation/tokenization_bert_generation.py +3 -2
  157. transformers/models/bert_japanese/tokenization_bert_japanese.py +6 -5
  158. transformers/models/bertweet/tokenization_bertweet.py +3 -1
  159. transformers/models/big_bird/configuration_big_bird.py +9 -12
  160. transformers/models/big_bird/modeling_big_bird.py +109 -116
  161. transformers/models/big_bird/tokenization_big_bird.py +43 -16
  162. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  163. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +117 -130
  164. transformers/models/biogpt/configuration_biogpt.py +2 -8
  165. transformers/models/biogpt/modeling_biogpt.py +76 -72
  166. transformers/models/biogpt/modular_biogpt.py +66 -62
  167. transformers/models/biogpt/tokenization_biogpt.py +5 -3
  168. transformers/models/bit/configuration_bit.py +1 -0
  169. transformers/models/bit/image_processing_bit.py +24 -21
  170. transformers/models/bit/image_processing_bit_fast.py +1 -0
  171. transformers/models/bit/modeling_bit.py +12 -25
  172. transformers/models/bitnet/configuration_bitnet.py +28 -23
  173. transformers/models/bitnet/modeling_bitnet.py +39 -36
  174. transformers/models/bitnet/modular_bitnet.py +6 -4
  175. transformers/models/blenderbot/configuration_blenderbot.py +5 -8
  176. transformers/models/blenderbot/modeling_blenderbot.py +96 -77
  177. transformers/models/blenderbot/tokenization_blenderbot.py +24 -18
  178. transformers/models/blenderbot_small/configuration_blenderbot_small.py +5 -8
  179. transformers/models/blenderbot_small/modeling_blenderbot_small.py +69 -79
  180. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +3 -1
  181. transformers/models/blip/configuration_blip.py +10 -9
  182. transformers/models/blip/image_processing_blip.py +20 -17
  183. transformers/models/blip/image_processing_blip_fast.py +1 -0
  184. transformers/models/blip/modeling_blip.py +108 -117
  185. transformers/models/blip/modeling_blip_text.py +65 -73
  186. transformers/models/blip/processing_blip.py +36 -5
  187. transformers/models/blip_2/configuration_blip_2.py +2 -2
  188. transformers/models/blip_2/modeling_blip_2.py +118 -146
  189. transformers/models/blip_2/processing_blip_2.py +38 -8
  190. transformers/models/bloom/configuration_bloom.py +2 -5
  191. transformers/models/bloom/modeling_bloom.py +104 -77
  192. transformers/models/blt/configuration_blt.py +86 -94
  193. transformers/models/blt/modeling_blt.py +81 -238
  194. transformers/models/blt/modular_blt.py +65 -228
  195. transformers/models/bridgetower/configuration_bridgetower.py +2 -7
  196. transformers/models/bridgetower/image_processing_bridgetower.py +35 -34
  197. transformers/models/bridgetower/image_processing_bridgetower_fast.py +16 -13
  198. transformers/models/bridgetower/modeling_bridgetower.py +119 -141
  199. transformers/models/bridgetower/processing_bridgetower.py +16 -2
  200. transformers/models/bros/configuration_bros.py +18 -24
  201. transformers/models/bros/modeling_bros.py +80 -90
  202. transformers/models/bros/processing_bros.py +12 -2
  203. transformers/models/byt5/tokenization_byt5.py +6 -4
  204. transformers/models/camembert/configuration_camembert.py +2 -8
  205. transformers/models/camembert/modeling_camembert.py +195 -196
  206. transformers/models/camembert/modular_camembert.py +54 -51
  207. transformers/models/camembert/tokenization_camembert.py +13 -6
  208. transformers/models/canine/configuration_canine.py +2 -4
  209. transformers/models/canine/modeling_canine.py +75 -84
  210. transformers/models/canine/tokenization_canine.py +1 -2
  211. transformers/models/chameleon/configuration_chameleon.py +34 -29
  212. transformers/models/chameleon/image_processing_chameleon.py +24 -21
  213. transformers/models/chameleon/image_processing_chameleon_fast.py +6 -5
  214. transformers/models/chameleon/modeling_chameleon.py +93 -142
  215. transformers/models/chameleon/processing_chameleon.py +41 -16
  216. transformers/models/chinese_clip/configuration_chinese_clip.py +8 -10
  217. transformers/models/chinese_clip/image_processing_chinese_clip.py +24 -21
  218. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +1 -0
  219. transformers/models/chinese_clip/modeling_chinese_clip.py +92 -96
  220. transformers/models/chinese_clip/processing_chinese_clip.py +15 -2
  221. transformers/models/clap/configuration_clap.py +9 -4
  222. transformers/models/clap/feature_extraction_clap.py +12 -11
  223. transformers/models/clap/modeling_clap.py +123 -136
  224. transformers/models/clap/processing_clap.py +15 -2
  225. transformers/models/clip/configuration_clip.py +2 -4
  226. transformers/models/clip/image_processing_clip.py +24 -21
  227. transformers/models/clip/image_processing_clip_fast.py +1 -9
  228. transformers/models/clip/modeling_clip.py +65 -65
  229. transformers/models/clip/processing_clip.py +14 -2
  230. transformers/models/clip/tokenization_clip.py +46 -21
  231. transformers/models/clipseg/configuration_clipseg.py +2 -4
  232. transformers/models/clipseg/modeling_clipseg.py +109 -119
  233. transformers/models/clipseg/processing_clipseg.py +42 -19
  234. transformers/models/clvp/configuration_clvp.py +5 -15
  235. transformers/models/clvp/feature_extraction_clvp.py +10 -7
  236. transformers/models/clvp/modeling_clvp.py +146 -155
  237. transformers/models/clvp/number_normalizer.py +2 -1
  238. transformers/models/clvp/processing_clvp.py +20 -3
  239. transformers/models/clvp/tokenization_clvp.py +64 -1
  240. transformers/models/code_llama/tokenization_code_llama.py +44 -18
  241. transformers/models/codegen/configuration_codegen.py +4 -4
  242. transformers/models/codegen/modeling_codegen.py +53 -63
  243. transformers/models/codegen/tokenization_codegen.py +47 -17
  244. transformers/models/cohere/configuration_cohere.py +30 -25
  245. transformers/models/cohere/modeling_cohere.py +42 -40
  246. transformers/models/cohere/modular_cohere.py +29 -26
  247. transformers/models/cohere/tokenization_cohere.py +46 -15
  248. transformers/models/cohere2/configuration_cohere2.py +32 -31
  249. transformers/models/cohere2/modeling_cohere2.py +44 -42
  250. transformers/models/cohere2/modular_cohere2.py +54 -54
  251. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +14 -13
  252. transformers/models/cohere2_vision/modeling_cohere2_vision.py +58 -59
  253. transformers/models/cohere2_vision/modular_cohere2_vision.py +46 -45
  254. transformers/models/cohere2_vision/processing_cohere2_vision.py +36 -6
  255. transformers/models/colpali/configuration_colpali.py +1 -0
  256. transformers/models/colpali/modeling_colpali.py +16 -14
  257. transformers/models/colpali/modular_colpali.py +51 -11
  258. transformers/models/colpali/processing_colpali.py +52 -14
  259. transformers/models/colqwen2/modeling_colqwen2.py +28 -28
  260. transformers/models/colqwen2/modular_colqwen2.py +74 -37
  261. transformers/models/colqwen2/processing_colqwen2.py +52 -16
  262. transformers/models/conditional_detr/configuration_conditional_detr.py +2 -1
  263. transformers/models/conditional_detr/image_processing_conditional_detr.py +70 -67
  264. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +36 -36
  265. transformers/models/conditional_detr/modeling_conditional_detr.py +87 -99
  266. transformers/models/conditional_detr/modular_conditional_detr.py +3 -49
  267. transformers/models/convbert/configuration_convbert.py +8 -11
  268. transformers/models/convbert/modeling_convbert.py +87 -94
  269. transformers/models/convbert/tokenization_convbert.py +1 -0
  270. transformers/models/convnext/configuration_convnext.py +1 -0
  271. transformers/models/convnext/image_processing_convnext.py +23 -20
  272. transformers/models/convnext/image_processing_convnext_fast.py +21 -16
  273. transformers/models/convnext/modeling_convnext.py +12 -9
  274. transformers/models/convnextv2/configuration_convnextv2.py +1 -0
  275. transformers/models/convnextv2/modeling_convnextv2.py +12 -9
  276. transformers/models/cpm/tokenization_cpm.py +7 -6
  277. transformers/models/cpm/tokenization_cpm_fast.py +5 -3
  278. transformers/models/cpmant/configuration_cpmant.py +1 -4
  279. transformers/models/cpmant/modeling_cpmant.py +40 -38
  280. transformers/models/cpmant/tokenization_cpmant.py +3 -1
  281. transformers/models/csm/configuration_csm.py +66 -58
  282. transformers/models/csm/generation_csm.py +35 -31
  283. transformers/models/csm/modeling_csm.py +85 -85
  284. transformers/models/csm/modular_csm.py +58 -58
  285. transformers/models/csm/processing_csm.py +68 -25
  286. transformers/models/ctrl/configuration_ctrl.py +1 -16
  287. transformers/models/ctrl/modeling_ctrl.py +44 -54
  288. transformers/models/ctrl/tokenization_ctrl.py +1 -0
  289. transformers/models/cvt/configuration_cvt.py +1 -0
  290. transformers/models/cvt/modeling_cvt.py +16 -20
  291. transformers/models/cwm/__init__.py +1 -0
  292. transformers/models/cwm/configuration_cwm.py +12 -8
  293. transformers/models/cwm/modeling_cwm.py +39 -37
  294. transformers/models/cwm/modular_cwm.py +12 -10
  295. transformers/models/d_fine/configuration_d_fine.py +5 -7
  296. transformers/models/d_fine/modeling_d_fine.py +128 -138
  297. transformers/models/d_fine/modular_d_fine.py +18 -33
  298. transformers/models/dab_detr/configuration_dab_detr.py +3 -6
  299. transformers/models/dab_detr/modeling_dab_detr.py +75 -81
  300. transformers/models/dac/configuration_dac.py +1 -0
  301. transformers/models/dac/feature_extraction_dac.py +9 -6
  302. transformers/models/dac/modeling_dac.py +26 -24
  303. transformers/models/data2vec/configuration_data2vec_audio.py +2 -4
  304. transformers/models/data2vec/configuration_data2vec_text.py +3 -11
  305. transformers/models/data2vec/configuration_data2vec_vision.py +1 -0
  306. transformers/models/data2vec/modeling_data2vec_audio.py +56 -57
  307. transformers/models/data2vec/modeling_data2vec_text.py +93 -98
  308. transformers/models/data2vec/modeling_data2vec_vision.py +45 -49
  309. transformers/models/data2vec/modular_data2vec_audio.py +1 -6
  310. transformers/models/data2vec/modular_data2vec_text.py +54 -58
  311. transformers/models/dbrx/configuration_dbrx.py +22 -36
  312. transformers/models/dbrx/modeling_dbrx.py +45 -42
  313. transformers/models/dbrx/modular_dbrx.py +33 -31
  314. transformers/models/deberta/configuration_deberta.py +1 -6
  315. transformers/models/deberta/modeling_deberta.py +60 -64
  316. transformers/models/deberta/tokenization_deberta.py +21 -9
  317. transformers/models/deberta_v2/configuration_deberta_v2.py +1 -6
  318. transformers/models/deberta_v2/modeling_deberta_v2.py +65 -71
  319. transformers/models/deberta_v2/tokenization_deberta_v2.py +29 -11
  320. transformers/models/decision_transformer/configuration_decision_transformer.py +2 -3
  321. transformers/models/decision_transformer/modeling_decision_transformer.py +56 -60
  322. transformers/models/deepseek_v2/configuration_deepseek_v2.py +44 -39
  323. transformers/models/deepseek_v2/modeling_deepseek_v2.py +43 -43
  324. transformers/models/deepseek_v2/modular_deepseek_v2.py +49 -48
  325. transformers/models/deepseek_v3/configuration_deepseek_v3.py +45 -40
  326. transformers/models/deepseek_v3/modeling_deepseek_v3.py +42 -45
  327. transformers/models/deepseek_v3/modular_deepseek_v3.py +9 -14
  328. transformers/models/deepseek_vl/configuration_deepseek_vl.py +3 -2
  329. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +26 -25
  330. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +10 -10
  331. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -57
  332. transformers/models/deepseek_vl/modular_deepseek_vl.py +43 -14
  333. transformers/models/deepseek_vl/processing_deepseek_vl.py +41 -10
  334. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +5 -3
  335. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
  336. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +24 -20
  337. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +61 -109
  338. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +118 -146
  339. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +44 -12
  340. transformers/models/deformable_detr/configuration_deformable_detr.py +3 -2
  341. transformers/models/deformable_detr/image_processing_deformable_detr.py +61 -59
  342. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +28 -28
  343. transformers/models/deformable_detr/modeling_deformable_detr.py +82 -88
  344. transformers/models/deformable_detr/modular_deformable_detr.py +3 -1
  345. transformers/models/deit/configuration_deit.py +1 -0
  346. transformers/models/deit/image_processing_deit.py +21 -18
  347. transformers/models/deit/image_processing_deit_fast.py +1 -0
  348. transformers/models/deit/modeling_deit.py +22 -24
  349. transformers/models/depth_anything/configuration_depth_anything.py +4 -2
  350. transformers/models/depth_anything/modeling_depth_anything.py +10 -10
  351. transformers/models/depth_pro/configuration_depth_pro.py +1 -0
  352. transformers/models/depth_pro/image_processing_depth_pro.py +23 -22
  353. transformers/models/depth_pro/image_processing_depth_pro_fast.py +10 -8
  354. transformers/models/depth_pro/modeling_depth_pro.py +27 -31
  355. transformers/models/detr/configuration_detr.py +2 -1
  356. transformers/models/detr/image_processing_detr.py +66 -64
  357. transformers/models/detr/image_processing_detr_fast.py +34 -33
  358. transformers/models/detr/modeling_detr.py +79 -95
  359. transformers/models/dia/configuration_dia.py +15 -9
  360. transformers/models/dia/feature_extraction_dia.py +9 -6
  361. transformers/models/dia/generation_dia.py +50 -48
  362. transformers/models/dia/modeling_dia.py +69 -78
  363. transformers/models/dia/modular_dia.py +56 -64
  364. transformers/models/dia/processing_dia.py +29 -39
  365. transformers/models/dia/tokenization_dia.py +6 -3
  366. transformers/models/diffllama/configuration_diffllama.py +30 -25
  367. transformers/models/diffllama/modeling_diffllama.py +49 -46
  368. transformers/models/diffllama/modular_diffllama.py +19 -17
  369. transformers/models/dinat/configuration_dinat.py +1 -0
  370. transformers/models/dinat/modeling_dinat.py +44 -47
  371. transformers/models/dinov2/configuration_dinov2.py +1 -0
  372. transformers/models/dinov2/modeling_dinov2.py +15 -15
  373. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
  374. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +15 -16
  375. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +9 -9
  376. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +7 -4
  377. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +6 -3
  378. transformers/models/dinov3_vit/configuration_dinov3_vit.py +8 -5
  379. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +9 -7
  380. transformers/models/dinov3_vit/modeling_dinov3_vit.py +18 -19
  381. transformers/models/dinov3_vit/modular_dinov3_vit.py +15 -16
  382. transformers/models/distilbert/configuration_distilbert.py +2 -8
  383. transformers/models/distilbert/modeling_distilbert.py +55 -55
  384. transformers/models/distilbert/tokenization_distilbert.py +1 -13
  385. transformers/models/doge/__init__.py +1 -0
  386. transformers/models/doge/configuration_doge.py +32 -39
  387. transformers/models/doge/modeling_doge.py +49 -45
  388. transformers/models/doge/modular_doge.py +63 -71
  389. transformers/models/donut/configuration_donut_swin.py +1 -0
  390. transformers/models/donut/image_processing_donut.py +29 -26
  391. transformers/models/donut/image_processing_donut_fast.py +15 -9
  392. transformers/models/donut/modeling_donut_swin.py +58 -62
  393. transformers/models/donut/processing_donut.py +26 -5
  394. transformers/models/dots1/configuration_dots1.py +33 -41
  395. transformers/models/dots1/modeling_dots1.py +45 -54
  396. transformers/models/dots1/modular_dots1.py +4 -5
  397. transformers/models/dpr/configuration_dpr.py +2 -19
  398. transformers/models/dpr/modeling_dpr.py +39 -42
  399. transformers/models/dpr/tokenization_dpr.py +9 -19
  400. transformers/models/dpr/tokenization_dpr_fast.py +9 -7
  401. transformers/models/dpt/configuration_dpt.py +2 -1
  402. transformers/models/dpt/image_processing_dpt.py +66 -65
  403. transformers/models/dpt/image_processing_dpt_fast.py +20 -18
  404. transformers/models/dpt/modeling_dpt.py +30 -32
  405. transformers/models/dpt/modular_dpt.py +17 -15
  406. transformers/models/edgetam/configuration_edgetam.py +3 -2
  407. transformers/models/edgetam/modeling_edgetam.py +86 -86
  408. transformers/models/edgetam/modular_edgetam.py +26 -21
  409. transformers/models/edgetam_video/__init__.py +1 -0
  410. transformers/models/edgetam_video/configuration_edgetam_video.py +1 -0
  411. transformers/models/edgetam_video/modeling_edgetam_video.py +158 -169
  412. transformers/models/edgetam_video/modular_edgetam_video.py +37 -30
  413. transformers/models/efficientloftr/configuration_efficientloftr.py +5 -4
  414. transformers/models/efficientloftr/image_processing_efficientloftr.py +16 -14
  415. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +9 -9
  416. transformers/models/efficientloftr/modeling_efficientloftr.py +38 -59
  417. transformers/models/efficientloftr/modular_efficientloftr.py +3 -1
  418. transformers/models/efficientnet/configuration_efficientnet.py +1 -0
  419. transformers/models/efficientnet/image_processing_efficientnet.py +32 -28
  420. transformers/models/efficientnet/image_processing_efficientnet_fast.py +19 -17
  421. transformers/models/efficientnet/modeling_efficientnet.py +15 -19
  422. transformers/models/electra/configuration_electra.py +3 -13
  423. transformers/models/electra/modeling_electra.py +103 -108
  424. transformers/models/emu3/configuration_emu3.py +17 -13
  425. transformers/models/emu3/image_processing_emu3.py +39 -44
  426. transformers/models/emu3/modeling_emu3.py +108 -148
  427. transformers/models/emu3/modular_emu3.py +73 -115
  428. transformers/models/emu3/processing_emu3.py +43 -18
  429. transformers/models/encodec/configuration_encodec.py +4 -2
  430. transformers/models/encodec/feature_extraction_encodec.py +13 -10
  431. transformers/models/encodec/modeling_encodec.py +29 -39
  432. transformers/models/encoder_decoder/configuration_encoder_decoder.py +2 -12
  433. transformers/models/encoder_decoder/modeling_encoder_decoder.py +43 -37
  434. transformers/models/eomt/configuration_eomt.py +1 -0
  435. transformers/models/eomt/image_processing_eomt.py +56 -66
  436. transformers/models/eomt/image_processing_eomt_fast.py +33 -76
  437. transformers/models/eomt/modeling_eomt.py +18 -23
  438. transformers/models/eomt/modular_eomt.py +13 -18
  439. transformers/models/ernie/configuration_ernie.py +3 -24
  440. transformers/models/ernie/modeling_ernie.py +132 -127
  441. transformers/models/ernie/modular_ernie.py +103 -97
  442. transformers/models/ernie4_5/configuration_ernie4_5.py +27 -23
  443. transformers/models/ernie4_5/modeling_ernie4_5.py +38 -36
  444. transformers/models/ernie4_5/modular_ernie4_5.py +4 -3
  445. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +36 -32
  446. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +55 -56
  447. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +46 -18
  448. transformers/models/esm/configuration_esm.py +15 -11
  449. transformers/models/esm/modeling_esm.py +34 -38
  450. transformers/models/esm/modeling_esmfold.py +49 -53
  451. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  452. transformers/models/esm/openfold_utils/loss.py +2 -1
  453. transformers/models/esm/openfold_utils/protein.py +16 -15
  454. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  455. transformers/models/esm/tokenization_esm.py +4 -2
  456. transformers/models/evolla/configuration_evolla.py +40 -50
  457. transformers/models/evolla/modeling_evolla.py +66 -71
  458. transformers/models/evolla/modular_evolla.py +47 -53
  459. transformers/models/evolla/processing_evolla.py +35 -23
  460. transformers/models/exaone4/configuration_exaone4.py +25 -23
  461. transformers/models/exaone4/modeling_exaone4.py +38 -35
  462. transformers/models/exaone4/modular_exaone4.py +46 -44
  463. transformers/models/falcon/configuration_falcon.py +26 -31
  464. transformers/models/falcon/modeling_falcon.py +80 -82
  465. transformers/models/falcon_h1/configuration_falcon_h1.py +51 -45
  466. transformers/models/falcon_h1/modeling_falcon_h1.py +82 -85
  467. transformers/models/falcon_h1/modular_falcon_h1.py +51 -56
  468. transformers/models/falcon_mamba/configuration_falcon_mamba.py +2 -1
  469. transformers/models/falcon_mamba/modeling_falcon_mamba.py +82 -75
  470. transformers/models/falcon_mamba/modular_falcon_mamba.py +45 -28
  471. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +6 -2
  472. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +60 -76
  473. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +3 -2
  474. transformers/models/flaubert/configuration_flaubert.py +5 -10
  475. transformers/models/flaubert/modeling_flaubert.py +143 -145
  476. transformers/models/flaubert/tokenization_flaubert.py +5 -3
  477. transformers/models/flava/configuration_flava.py +6 -5
  478. transformers/models/flava/image_processing_flava.py +67 -66
  479. transformers/models/flava/image_processing_flava_fast.py +49 -46
  480. transformers/models/flava/modeling_flava.py +136 -153
  481. transformers/models/flava/processing_flava.py +12 -2
  482. transformers/models/flex_olmo/__init__.py +1 -0
  483. transformers/models/flex_olmo/configuration_flex_olmo.py +32 -28
  484. transformers/models/flex_olmo/modeling_flex_olmo.py +47 -47
  485. transformers/models/flex_olmo/modular_flex_olmo.py +44 -40
  486. transformers/models/florence2/configuration_florence2.py +1 -0
  487. transformers/models/florence2/modeling_florence2.py +69 -111
  488. transformers/models/florence2/modular_florence2.py +101 -104
  489. transformers/models/florence2/processing_florence2.py +47 -18
  490. transformers/models/fnet/configuration_fnet.py +2 -6
  491. transformers/models/fnet/modeling_fnet.py +80 -83
  492. transformers/models/fnet/tokenization_fnet.py +1 -0
  493. transformers/models/focalnet/configuration_focalnet.py +1 -0
  494. transformers/models/focalnet/modeling_focalnet.py +45 -51
  495. transformers/models/fsmt/configuration_fsmt.py +17 -12
  496. transformers/models/fsmt/modeling_fsmt.py +48 -49
  497. transformers/models/fsmt/tokenization_fsmt.py +5 -3
  498. transformers/models/funnel/configuration_funnel.py +1 -8
  499. transformers/models/funnel/modeling_funnel.py +93 -99
  500. transformers/models/funnel/tokenization_funnel.py +27 -17
  501. transformers/models/fuyu/configuration_fuyu.py +34 -28
  502. transformers/models/fuyu/image_processing_fuyu.py +31 -29
  503. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  504. transformers/models/fuyu/modeling_fuyu.py +53 -53
  505. transformers/models/fuyu/processing_fuyu.py +34 -23
  506. transformers/models/gemma/configuration_gemma.py +30 -25
  507. transformers/models/gemma/modeling_gemma.py +50 -46
  508. transformers/models/gemma/modular_gemma.py +47 -42
  509. transformers/models/gemma/tokenization_gemma.py +30 -10
  510. transformers/models/gemma2/configuration_gemma2.py +35 -30
  511. transformers/models/gemma2/modeling_gemma2.py +42 -39
  512. transformers/models/gemma2/modular_gemma2.py +66 -63
  513. transformers/models/gemma3/configuration_gemma3.py +44 -44
  514. transformers/models/gemma3/image_processing_gemma3.py +31 -29
  515. transformers/models/gemma3/image_processing_gemma3_fast.py +13 -11
  516. transformers/models/gemma3/modeling_gemma3.py +207 -159
  517. transformers/models/gemma3/modular_gemma3.py +204 -153
  518. transformers/models/gemma3/processing_gemma3.py +5 -5
  519. transformers/models/gemma3n/configuration_gemma3n.py +26 -36
  520. transformers/models/gemma3n/feature_extraction_gemma3n.py +11 -9
  521. transformers/models/gemma3n/modeling_gemma3n.py +356 -222
  522. transformers/models/gemma3n/modular_gemma3n.py +207 -230
  523. transformers/models/gemma3n/processing_gemma3n.py +26 -12
  524. transformers/models/git/configuration_git.py +8 -5
  525. transformers/models/git/modeling_git.py +204 -266
  526. transformers/models/git/processing_git.py +14 -2
  527. transformers/models/glm/configuration_glm.py +28 -24
  528. transformers/models/glm/modeling_glm.py +40 -37
  529. transformers/models/glm/modular_glm.py +7 -4
  530. transformers/models/glm4/configuration_glm4.py +28 -24
  531. transformers/models/glm4/modeling_glm4.py +42 -40
  532. transformers/models/glm4/modular_glm4.py +10 -8
  533. transformers/models/glm46v/configuration_glm46v.py +1 -0
  534. transformers/models/glm46v/image_processing_glm46v.py +40 -35
  535. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  536. transformers/models/glm46v/modeling_glm46v.py +90 -137
  537. transformers/models/glm46v/modular_glm46v.py +3 -4
  538. transformers/models/glm46v/processing_glm46v.py +41 -7
  539. transformers/models/glm46v/video_processing_glm46v.py +11 -9
  540. transformers/models/glm4_moe/configuration_glm4_moe.py +32 -40
  541. transformers/models/glm4_moe/modeling_glm4_moe.py +42 -45
  542. transformers/models/glm4_moe/modular_glm4_moe.py +34 -42
  543. transformers/models/glm4v/configuration_glm4v.py +20 -18
  544. transformers/models/glm4v/image_processing_glm4v.py +40 -34
  545. transformers/models/glm4v/image_processing_glm4v_fast.py +9 -8
  546. transformers/models/glm4v/modeling_glm4v.py +205 -254
  547. transformers/models/glm4v/modular_glm4v.py +224 -210
  548. transformers/models/glm4v/processing_glm4v.py +41 -7
  549. transformers/models/glm4v/video_processing_glm4v.py +11 -9
  550. transformers/models/glm4v_moe/configuration_glm4v_moe.py +125 -136
  551. transformers/models/glm4v_moe/modeling_glm4v_moe.py +368 -377
  552. transformers/models/glm4v_moe/modular_glm4v_moe.py +169 -83
  553. transformers/models/glpn/configuration_glpn.py +1 -0
  554. transformers/models/glpn/image_processing_glpn.py +12 -11
  555. transformers/models/glpn/image_processing_glpn_fast.py +13 -11
  556. transformers/models/glpn/modeling_glpn.py +14 -16
  557. transformers/models/got_ocr2/configuration_got_ocr2.py +12 -4
  558. transformers/models/got_ocr2/image_processing_got_ocr2.py +24 -22
  559. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +11 -9
  560. transformers/models/got_ocr2/modeling_got_ocr2.py +80 -77
  561. transformers/models/got_ocr2/modular_got_ocr2.py +51 -54
  562. transformers/models/got_ocr2/processing_got_ocr2.py +63 -42
  563. transformers/models/gpt2/configuration_gpt2.py +2 -13
  564. transformers/models/gpt2/modeling_gpt2.py +115 -120
  565. transformers/models/gpt2/tokenization_gpt2.py +46 -15
  566. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +2 -5
  567. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +89 -79
  568. transformers/models/gpt_neo/configuration_gpt_neo.py +2 -9
  569. transformers/models/gpt_neo/modeling_gpt_neo.py +67 -83
  570. transformers/models/gpt_neox/configuration_gpt_neox.py +25 -25
  571. transformers/models/gpt_neox/modeling_gpt_neox.py +75 -76
  572. transformers/models/gpt_neox/modular_gpt_neox.py +66 -67
  573. transformers/models/gpt_neox/tokenization_gpt_neox.py +51 -9
  574. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +19 -24
  575. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +47 -46
  576. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +3 -1
  577. transformers/models/gpt_oss/configuration_gpt_oss.py +28 -46
  578. transformers/models/gpt_oss/modeling_gpt_oss.py +121 -83
  579. transformers/models/gpt_oss/modular_gpt_oss.py +103 -64
  580. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  581. transformers/models/gptj/configuration_gptj.py +4 -4
  582. transformers/models/gptj/modeling_gptj.py +87 -101
  583. transformers/models/granite/configuration_granite.py +33 -28
  584. transformers/models/granite/modeling_granite.py +46 -44
  585. transformers/models/granite/modular_granite.py +31 -29
  586. transformers/models/granite_speech/configuration_granite_speech.py +1 -0
  587. transformers/models/granite_speech/feature_extraction_granite_speech.py +3 -1
  588. transformers/models/granite_speech/modeling_granite_speech.py +52 -82
  589. transformers/models/granite_speech/processing_granite_speech.py +4 -11
  590. transformers/models/granitemoe/configuration_granitemoe.py +36 -31
  591. transformers/models/granitemoe/modeling_granitemoe.py +46 -41
  592. transformers/models/granitemoe/modular_granitemoe.py +27 -22
  593. transformers/models/granitemoehybrid/__init__.py +1 -0
  594. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +47 -46
  595. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +93 -97
  596. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +21 -54
  597. transformers/models/granitemoeshared/configuration_granitemoeshared.py +37 -33
  598. transformers/models/granitemoeshared/modeling_granitemoeshared.py +61 -54
  599. transformers/models/granitemoeshared/modular_granitemoeshared.py +21 -19
  600. transformers/models/grounding_dino/configuration_grounding_dino.py +4 -6
  601. transformers/models/grounding_dino/image_processing_grounding_dino.py +62 -60
  602. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +29 -28
  603. transformers/models/grounding_dino/modeling_grounding_dino.py +140 -155
  604. transformers/models/grounding_dino/modular_grounding_dino.py +3 -2
  605. transformers/models/grounding_dino/processing_grounding_dino.py +38 -10
  606. transformers/models/groupvit/configuration_groupvit.py +2 -4
  607. transformers/models/groupvit/modeling_groupvit.py +93 -107
  608. transformers/models/helium/configuration_helium.py +29 -25
  609. transformers/models/helium/modeling_helium.py +40 -38
  610. transformers/models/helium/modular_helium.py +7 -3
  611. transformers/models/herbert/tokenization_herbert.py +28 -10
  612. transformers/models/hgnet_v2/configuration_hgnet_v2.py +1 -0
  613. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -24
  614. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -24
  615. transformers/models/hiera/configuration_hiera.py +1 -0
  616. transformers/models/hiera/modeling_hiera.py +66 -72
  617. transformers/models/hubert/configuration_hubert.py +2 -4
  618. transformers/models/hubert/modeling_hubert.py +37 -42
  619. transformers/models/hubert/modular_hubert.py +11 -13
  620. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +31 -26
  621. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +38 -35
  622. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +6 -4
  623. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  624. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +36 -31
  625. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +42 -47
  626. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +9 -9
  627. transformers/models/ibert/configuration_ibert.py +2 -4
  628. transformers/models/ibert/modeling_ibert.py +62 -82
  629. transformers/models/ibert/quant_modules.py +1 -0
  630. transformers/models/idefics/configuration_idefics.py +8 -5
  631. transformers/models/idefics/image_processing_idefics.py +15 -13
  632. transformers/models/idefics/modeling_idefics.py +82 -75
  633. transformers/models/idefics/perceiver.py +3 -1
  634. transformers/models/idefics/processing_idefics.py +48 -32
  635. transformers/models/idefics/vision.py +25 -24
  636. transformers/models/idefics2/configuration_idefics2.py +3 -1
  637. transformers/models/idefics2/image_processing_idefics2.py +32 -31
  638. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  639. transformers/models/idefics2/modeling_idefics2.py +101 -127
  640. transformers/models/idefics2/processing_idefics2.py +68 -10
  641. transformers/models/idefics3/configuration_idefics3.py +4 -1
  642. transformers/models/idefics3/image_processing_idefics3.py +43 -42
  643. transformers/models/idefics3/image_processing_idefics3_fast.py +15 -40
  644. transformers/models/idefics3/modeling_idefics3.py +90 -115
  645. transformers/models/idefics3/processing_idefics3.py +69 -15
  646. transformers/models/ijepa/configuration_ijepa.py +1 -0
  647. transformers/models/ijepa/modeling_ijepa.py +11 -10
  648. transformers/models/ijepa/modular_ijepa.py +7 -5
  649. transformers/models/imagegpt/configuration_imagegpt.py +2 -9
  650. transformers/models/imagegpt/image_processing_imagegpt.py +18 -17
  651. transformers/models/imagegpt/image_processing_imagegpt_fast.py +16 -11
  652. transformers/models/imagegpt/modeling_imagegpt.py +65 -76
  653. transformers/models/informer/configuration_informer.py +9 -6
  654. transformers/models/informer/modeling_informer.py +86 -88
  655. transformers/models/informer/modular_informer.py +16 -14
  656. transformers/models/instructblip/configuration_instructblip.py +2 -2
  657. transformers/models/instructblip/modeling_instructblip.py +63 -103
  658. transformers/models/instructblip/processing_instructblip.py +36 -10
  659. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  660. transformers/models/instructblipvideo/modeling_instructblipvideo.py +139 -157
  661. transformers/models/instructblipvideo/modular_instructblipvideo.py +64 -73
  662. transformers/models/instructblipvideo/processing_instructblipvideo.py +33 -14
  663. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +8 -6
  664. transformers/models/internvl/configuration_internvl.py +1 -0
  665. transformers/models/internvl/modeling_internvl.py +106 -85
  666. transformers/models/internvl/modular_internvl.py +67 -47
  667. transformers/models/internvl/processing_internvl.py +45 -12
  668. transformers/models/internvl/video_processing_internvl.py +12 -10
  669. transformers/models/jamba/configuration_jamba.py +8 -5
  670. transformers/models/jamba/modeling_jamba.py +66 -68
  671. transformers/models/jamba/modular_jamba.py +55 -54
  672. transformers/models/janus/configuration_janus.py +1 -0
  673. transformers/models/janus/image_processing_janus.py +37 -35
  674. transformers/models/janus/image_processing_janus_fast.py +20 -18
  675. transformers/models/janus/modeling_janus.py +191 -115
  676. transformers/models/janus/modular_janus.py +84 -133
  677. transformers/models/janus/processing_janus.py +43 -17
  678. transformers/models/jetmoe/configuration_jetmoe.py +26 -24
  679. transformers/models/jetmoe/modeling_jetmoe.py +46 -43
  680. transformers/models/jetmoe/modular_jetmoe.py +33 -31
  681. transformers/models/kosmos2/configuration_kosmos2.py +9 -10
  682. transformers/models/kosmos2/modeling_kosmos2.py +173 -208
  683. transformers/models/kosmos2/processing_kosmos2.py +55 -40
  684. transformers/models/kosmos2_5/__init__.py +1 -0
  685. transformers/models/kosmos2_5/configuration_kosmos2_5.py +9 -8
  686. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +12 -10
  687. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +13 -4
  688. transformers/models/kosmos2_5/modeling_kosmos2_5.py +118 -132
  689. transformers/models/kosmos2_5/processing_kosmos2_5.py +29 -8
  690. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +28 -31
  691. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +14 -12
  692. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +100 -110
  693. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +22 -28
  694. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +8 -2
  695. transformers/models/layoutlm/configuration_layoutlm.py +2 -14
  696. transformers/models/layoutlm/modeling_layoutlm.py +72 -77
  697. transformers/models/layoutlmv2/configuration_layoutlmv2.py +17 -14
  698. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +21 -18
  699. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +9 -7
  700. transformers/models/layoutlmv2/modeling_layoutlmv2.py +50 -64
  701. transformers/models/layoutlmv2/processing_layoutlmv2.py +44 -14
  702. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +126 -73
  703. transformers/models/layoutlmv3/configuration_layoutlmv3.py +19 -16
  704. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +26 -24
  705. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +11 -9
  706. transformers/models/layoutlmv3/modeling_layoutlmv3.py +56 -82
  707. transformers/models/layoutlmv3/processing_layoutlmv3.py +46 -14
  708. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +134 -74
  709. transformers/models/layoutxlm/configuration_layoutxlm.py +17 -14
  710. transformers/models/layoutxlm/modular_layoutxlm.py +1 -0
  711. transformers/models/layoutxlm/processing_layoutxlm.py +44 -14
  712. transformers/models/layoutxlm/tokenization_layoutxlm.py +113 -77
  713. transformers/models/led/configuration_led.py +12 -8
  714. transformers/models/led/modeling_led.py +266 -124
  715. transformers/models/levit/configuration_levit.py +1 -0
  716. transformers/models/levit/image_processing_levit.py +21 -19
  717. transformers/models/levit/image_processing_levit_fast.py +5 -4
  718. transformers/models/levit/modeling_levit.py +19 -38
  719. transformers/models/lfm2/configuration_lfm2.py +30 -27
  720. transformers/models/lfm2/modeling_lfm2.py +50 -47
  721. transformers/models/lfm2/modular_lfm2.py +30 -29
  722. transformers/models/lfm2_moe/__init__.py +1 -0
  723. transformers/models/lfm2_moe/configuration_lfm2_moe.py +9 -6
  724. transformers/models/lfm2_moe/modeling_lfm2_moe.py +53 -61
  725. transformers/models/lfm2_moe/modular_lfm2_moe.py +37 -13
  726. transformers/models/lfm2_vl/configuration_lfm2_vl.py +1 -4
  727. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +12 -41
  728. transformers/models/lfm2_vl/modeling_lfm2_vl.py +66 -84
  729. transformers/models/lfm2_vl/modular_lfm2_vl.py +56 -70
  730. transformers/models/lfm2_vl/processing_lfm2_vl.py +76 -96
  731. transformers/models/lightglue/image_processing_lightglue.py +15 -16
  732. transformers/models/lightglue/image_processing_lightglue_fast.py +9 -9
  733. transformers/models/lightglue/modeling_lightglue.py +31 -31
  734. transformers/models/lightglue/modular_lightglue.py +28 -29
  735. transformers/models/lilt/configuration_lilt.py +2 -6
  736. transformers/models/lilt/modeling_lilt.py +70 -76
  737. transformers/models/llama/configuration_llama.py +31 -26
  738. transformers/models/llama/modeling_llama.py +39 -36
  739. transformers/models/llama/tokenization_llama.py +44 -14
  740. transformers/models/llama4/configuration_llama4.py +30 -27
  741. transformers/models/llama4/image_processing_llama4_fast.py +14 -12
  742. transformers/models/llama4/modeling_llama4.py +113 -120
  743. transformers/models/llama4/processing_llama4.py +57 -33
  744. transformers/models/llava/configuration_llava.py +1 -10
  745. transformers/models/llava/image_processing_llava.py +28 -25
  746. transformers/models/llava/image_processing_llava_fast.py +11 -9
  747. transformers/models/llava/modeling_llava.py +109 -85
  748. transformers/models/llava/processing_llava.py +51 -18
  749. transformers/models/llava_next/configuration_llava_next.py +2 -2
  750. transformers/models/llava_next/image_processing_llava_next.py +45 -43
  751. transformers/models/llava_next/image_processing_llava_next_fast.py +13 -11
  752. transformers/models/llava_next/modeling_llava_next.py +107 -110
  753. transformers/models/llava_next/processing_llava_next.py +47 -18
  754. transformers/models/llava_next_video/configuration_llava_next_video.py +7 -4
  755. transformers/models/llava_next_video/modeling_llava_next_video.py +158 -175
  756. transformers/models/llava_next_video/modular_llava_next_video.py +150 -155
  757. transformers/models/llava_next_video/processing_llava_next_video.py +63 -21
  758. transformers/models/llava_next_video/video_processing_llava_next_video.py +1 -0
  759. transformers/models/llava_onevision/configuration_llava_onevision.py +7 -4
  760. transformers/models/llava_onevision/image_processing_llava_onevision.py +42 -40
  761. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +15 -14
  762. transformers/models/llava_onevision/modeling_llava_onevision.py +169 -177
  763. transformers/models/llava_onevision/modular_llava_onevision.py +156 -163
  764. transformers/models/llava_onevision/processing_llava_onevision.py +53 -21
  765. transformers/models/llava_onevision/video_processing_llava_onevision.py +1 -0
  766. transformers/models/longcat_flash/__init__.py +1 -0
  767. transformers/models/longcat_flash/configuration_longcat_flash.py +42 -37
  768. transformers/models/longcat_flash/modeling_longcat_flash.py +36 -36
  769. transformers/models/longcat_flash/modular_longcat_flash.py +21 -21
  770. transformers/models/longformer/configuration_longformer.py +5 -5
  771. transformers/models/longformer/modeling_longformer.py +101 -105
  772. transformers/models/longt5/configuration_longt5.py +7 -9
  773. transformers/models/longt5/modeling_longt5.py +49 -49
  774. transformers/models/luke/configuration_luke.py +2 -8
  775. transformers/models/luke/modeling_luke.py +181 -188
  776. transformers/models/luke/tokenization_luke.py +140 -107
  777. transformers/models/lxmert/configuration_lxmert.py +1 -16
  778. transformers/models/lxmert/modeling_lxmert.py +74 -65
  779. transformers/models/m2m_100/configuration_m2m_100.py +9 -7
  780. transformers/models/m2m_100/modeling_m2m_100.py +71 -83
  781. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  782. transformers/models/mamba/configuration_mamba.py +2 -1
  783. transformers/models/mamba/modeling_mamba.py +66 -58
  784. transformers/models/mamba2/configuration_mamba2.py +8 -5
  785. transformers/models/mamba2/modeling_mamba2.py +69 -68
  786. transformers/models/marian/configuration_marian.py +5 -10
  787. transformers/models/marian/modeling_marian.py +87 -93
  788. transformers/models/marian/tokenization_marian.py +6 -6
  789. transformers/models/markuplm/configuration_markuplm.py +7 -4
  790. transformers/models/markuplm/feature_extraction_markuplm.py +2 -1
  791. transformers/models/markuplm/modeling_markuplm.py +70 -69
  792. transformers/models/markuplm/processing_markuplm.py +38 -31
  793. transformers/models/markuplm/tokenization_markuplm.py +136 -93
  794. transformers/models/mask2former/configuration_mask2former.py +8 -5
  795. transformers/models/mask2former/image_processing_mask2former.py +85 -84
  796. transformers/models/mask2former/image_processing_mask2former_fast.py +40 -37
  797. transformers/models/mask2former/modeling_mask2former.py +103 -118
  798. transformers/models/mask2former/modular_mask2former.py +8 -6
  799. transformers/models/maskformer/configuration_maskformer.py +9 -6
  800. transformers/models/maskformer/configuration_maskformer_swin.py +1 -0
  801. transformers/models/maskformer/image_processing_maskformer.py +85 -84
  802. transformers/models/maskformer/image_processing_maskformer_fast.py +40 -36
  803. transformers/models/maskformer/modeling_maskformer.py +65 -79
  804. transformers/models/maskformer/modeling_maskformer_swin.py +32 -36
  805. transformers/models/mbart/configuration_mbart.py +4 -9
  806. transformers/models/mbart/modeling_mbart.py +116 -131
  807. transformers/models/mbart/tokenization_mbart.py +54 -11
  808. transformers/models/mbart50/tokenization_mbart50.py +13 -8
  809. transformers/models/megatron_bert/configuration_megatron_bert.py +3 -13
  810. transformers/models/megatron_bert/modeling_megatron_bert.py +150 -148
  811. transformers/models/metaclip_2/configuration_metaclip_2.py +1 -4
  812. transformers/models/metaclip_2/modeling_metaclip_2.py +84 -91
  813. transformers/models/metaclip_2/modular_metaclip_2.py +45 -61
  814. transformers/models/mgp_str/configuration_mgp_str.py +1 -0
  815. transformers/models/mgp_str/modeling_mgp_str.py +18 -20
  816. transformers/models/mgp_str/processing_mgp_str.py +20 -3
  817. transformers/models/mgp_str/tokenization_mgp_str.py +3 -1
  818. transformers/models/mimi/configuration_mimi.py +40 -42
  819. transformers/models/mimi/modeling_mimi.py +113 -142
  820. transformers/models/minimax/__init__.py +1 -0
  821. transformers/models/minimax/configuration_minimax.py +43 -37
  822. transformers/models/minimax/modeling_minimax.py +51 -61
  823. transformers/models/minimax/modular_minimax.py +62 -68
  824. transformers/models/ministral/configuration_ministral.py +29 -25
  825. transformers/models/ministral/modeling_ministral.py +38 -36
  826. transformers/models/ministral/modular_ministral.py +37 -32
  827. transformers/models/ministral3/configuration_ministral3.py +27 -24
  828. transformers/models/ministral3/modeling_ministral3.py +37 -36
  829. transformers/models/ministral3/modular_ministral3.py +5 -4
  830. transformers/models/mistral/configuration_mistral.py +29 -24
  831. transformers/models/mistral/modeling_mistral.py +37 -36
  832. transformers/models/mistral/modular_mistral.py +12 -11
  833. transformers/models/mistral3/configuration_mistral3.py +1 -4
  834. transformers/models/mistral3/modeling_mistral3.py +86 -89
  835. transformers/models/mistral3/modular_mistral3.py +68 -69
  836. transformers/models/mixtral/configuration_mixtral.py +34 -29
  837. transformers/models/mixtral/modeling_mixtral.py +45 -50
  838. transformers/models/mixtral/modular_mixtral.py +31 -32
  839. transformers/models/mlcd/configuration_mlcd.py +1 -0
  840. transformers/models/mlcd/modeling_mlcd.py +14 -20
  841. transformers/models/mlcd/modular_mlcd.py +13 -17
  842. transformers/models/mllama/configuration_mllama.py +15 -10
  843. transformers/models/mllama/image_processing_mllama.py +25 -23
  844. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  845. transformers/models/mllama/modeling_mllama.py +94 -105
  846. transformers/models/mllama/processing_mllama.py +55 -6
  847. transformers/models/mluke/tokenization_mluke.py +107 -101
  848. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +3 -5
  849. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +140 -155
  850. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +3 -5
  851. transformers/models/mobilebert/configuration_mobilebert.py +2 -4
  852. transformers/models/mobilebert/modeling_mobilebert.py +85 -77
  853. transformers/models/mobilebert/tokenization_mobilebert.py +1 -0
  854. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +1 -0
  855. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +23 -20
  856. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +1 -0
  857. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +16 -15
  858. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +1 -0
  859. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +51 -48
  860. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +15 -13
  861. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +22 -24
  862. transformers/models/mobilevit/configuration_mobilevit.py +1 -0
  863. transformers/models/mobilevit/image_processing_mobilevit.py +49 -46
  864. transformers/models/mobilevit/image_processing_mobilevit_fast.py +14 -12
  865. transformers/models/mobilevit/modeling_mobilevit.py +21 -28
  866. transformers/models/mobilevitv2/configuration_mobilevitv2.py +1 -0
  867. transformers/models/mobilevitv2/modeling_mobilevitv2.py +22 -28
  868. transformers/models/modernbert/configuration_modernbert.py +42 -44
  869. transformers/models/modernbert/modeling_modernbert.py +133 -145
  870. transformers/models/modernbert/modular_modernbert.py +170 -186
  871. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +40 -40
  872. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +57 -62
  873. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +86 -94
  874. transformers/models/moonshine/configuration_moonshine.py +31 -34
  875. transformers/models/moonshine/modeling_moonshine.py +71 -71
  876. transformers/models/moonshine/modular_moonshine.py +83 -88
  877. transformers/models/moshi/configuration_moshi.py +23 -46
  878. transformers/models/moshi/modeling_moshi.py +187 -157
  879. transformers/models/mpnet/configuration_mpnet.py +2 -6
  880. transformers/models/mpnet/modeling_mpnet.py +57 -62
  881. transformers/models/mpnet/tokenization_mpnet.py +15 -4
  882. transformers/models/mpt/configuration_mpt.py +9 -5
  883. transformers/models/mpt/modeling_mpt.py +60 -60
  884. transformers/models/mra/configuration_mra.py +2 -8
  885. transformers/models/mra/modeling_mra.py +57 -64
  886. transformers/models/mt5/configuration_mt5.py +8 -10
  887. transformers/models/mt5/modeling_mt5.py +95 -87
  888. transformers/models/musicgen/configuration_musicgen.py +8 -12
  889. transformers/models/musicgen/modeling_musicgen.py +122 -118
  890. transformers/models/musicgen/processing_musicgen.py +21 -3
  891. transformers/models/musicgen_melody/configuration_musicgen_melody.py +8 -15
  892. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +9 -8
  893. transformers/models/musicgen_melody/modeling_musicgen_melody.py +123 -117
  894. transformers/models/musicgen_melody/processing_musicgen_melody.py +22 -3
  895. transformers/models/mvp/configuration_mvp.py +5 -8
  896. transformers/models/mvp/modeling_mvp.py +123 -135
  897. transformers/models/myt5/tokenization_myt5.py +10 -8
  898. transformers/models/nanochat/configuration_nanochat.py +8 -5
  899. transformers/models/nanochat/modeling_nanochat.py +40 -37
  900. transformers/models/nanochat/modular_nanochat.py +14 -12
  901. transformers/models/nemotron/configuration_nemotron.py +30 -25
  902. transformers/models/nemotron/modeling_nemotron.py +57 -56
  903. transformers/models/nllb/tokenization_nllb.py +28 -12
  904. transformers/models/nllb_moe/configuration_nllb_moe.py +9 -7
  905. transformers/models/nllb_moe/modeling_nllb_moe.py +69 -77
  906. transformers/models/nougat/image_processing_nougat.py +32 -29
  907. transformers/models/nougat/image_processing_nougat_fast.py +14 -12
  908. transformers/models/nougat/processing_nougat.py +39 -37
  909. transformers/models/nougat/tokenization_nougat.py +73 -18
  910. transformers/models/nystromformer/configuration_nystromformer.py +2 -8
  911. transformers/models/nystromformer/modeling_nystromformer.py +63 -74
  912. transformers/models/olmo/configuration_olmo.py +28 -23
  913. transformers/models/olmo/modeling_olmo.py +39 -36
  914. transformers/models/olmo/modular_olmo.py +11 -7
  915. transformers/models/olmo2/configuration_olmo2.py +28 -23
  916. transformers/models/olmo2/modeling_olmo2.py +41 -37
  917. transformers/models/olmo2/modular_olmo2.py +32 -29
  918. transformers/models/olmo3/__init__.py +1 -0
  919. transformers/models/olmo3/configuration_olmo3.py +30 -26
  920. transformers/models/olmo3/modeling_olmo3.py +39 -36
  921. transformers/models/olmo3/modular_olmo3.py +40 -37
  922. transformers/models/olmoe/configuration_olmoe.py +33 -29
  923. transformers/models/olmoe/modeling_olmoe.py +46 -52
  924. transformers/models/olmoe/modular_olmoe.py +15 -16
  925. transformers/models/omdet_turbo/configuration_omdet_turbo.py +4 -2
  926. transformers/models/omdet_turbo/modeling_omdet_turbo.py +47 -53
  927. transformers/models/omdet_turbo/processing_omdet_turbo.py +67 -19
  928. transformers/models/oneformer/configuration_oneformer.py +8 -5
  929. transformers/models/oneformer/image_processing_oneformer.py +84 -83
  930. transformers/models/oneformer/image_processing_oneformer_fast.py +42 -41
  931. transformers/models/oneformer/modeling_oneformer.py +171 -147
  932. transformers/models/oneformer/processing_oneformer.py +43 -28
  933. transformers/models/openai/configuration_openai.py +1 -16
  934. transformers/models/openai/modeling_openai.py +51 -65
  935. transformers/models/openai/tokenization_openai.py +47 -8
  936. transformers/models/opt/configuration_opt.py +7 -6
  937. transformers/models/opt/modeling_opt.py +76 -78
  938. transformers/models/ovis2/__init__.py +1 -0
  939. transformers/models/ovis2/configuration_ovis2.py +1 -0
  940. transformers/models/ovis2/image_processing_ovis2.py +24 -22
  941. transformers/models/ovis2/image_processing_ovis2_fast.py +11 -9
  942. transformers/models/ovis2/modeling_ovis2.py +142 -111
  943. transformers/models/ovis2/modular_ovis2.py +45 -90
  944. transformers/models/ovis2/processing_ovis2.py +40 -12
  945. transformers/models/owlv2/configuration_owlv2.py +2 -4
  946. transformers/models/owlv2/image_processing_owlv2.py +21 -20
  947. transformers/models/owlv2/image_processing_owlv2_fast.py +15 -12
  948. transformers/models/owlv2/modeling_owlv2.py +117 -133
  949. transformers/models/owlv2/modular_owlv2.py +14 -11
  950. transformers/models/owlv2/processing_owlv2.py +49 -20
  951. transformers/models/owlvit/configuration_owlvit.py +2 -4
  952. transformers/models/owlvit/image_processing_owlvit.py +22 -21
  953. transformers/models/owlvit/image_processing_owlvit_fast.py +3 -2
  954. transformers/models/owlvit/modeling_owlvit.py +116 -132
  955. transformers/models/owlvit/processing_owlvit.py +48 -20
  956. transformers/models/paligemma/configuration_paligemma.py +1 -4
  957. transformers/models/paligemma/modeling_paligemma.py +93 -103
  958. transformers/models/paligemma/processing_paligemma.py +66 -13
  959. transformers/models/parakeet/configuration_parakeet.py +14 -7
  960. transformers/models/parakeet/feature_extraction_parakeet.py +12 -10
  961. transformers/models/parakeet/modeling_parakeet.py +28 -32
  962. transformers/models/parakeet/modular_parakeet.py +20 -23
  963. transformers/models/parakeet/processing_parakeet.py +5 -13
  964. transformers/models/parakeet/{tokenization_parakeet.py → tokenization_parakeet_fast.py} +7 -5
  965. transformers/models/patchtsmixer/configuration_patchtsmixer.py +8 -5
  966. transformers/models/patchtsmixer/modeling_patchtsmixer.py +62 -70
  967. transformers/models/patchtst/configuration_patchtst.py +9 -6
  968. transformers/models/patchtst/modeling_patchtst.py +80 -97
  969. transformers/models/pegasus/configuration_pegasus.py +5 -8
  970. transformers/models/pegasus/modeling_pegasus.py +66 -72
  971. transformers/models/pegasus/tokenization_pegasus.py +45 -15
  972. transformers/models/pegasus_x/configuration_pegasus_x.py +4 -5
  973. transformers/models/pegasus_x/modeling_pegasus_x.py +52 -55
  974. transformers/models/perceiver/configuration_perceiver.py +1 -0
  975. transformers/models/perceiver/image_processing_perceiver.py +25 -22
  976. transformers/models/perceiver/image_processing_perceiver_fast.py +9 -7
  977. transformers/models/perceiver/modeling_perceiver.py +146 -165
  978. transformers/models/perceiver/tokenization_perceiver.py +6 -3
  979. transformers/models/perception_lm/configuration_perception_lm.py +1 -0
  980. transformers/models/perception_lm/image_processing_perception_lm_fast.py +10 -8
  981. transformers/models/perception_lm/modeling_perception_lm.py +70 -71
  982. transformers/models/perception_lm/modular_perception_lm.py +61 -65
  983. transformers/models/perception_lm/processing_perception_lm.py +47 -13
  984. transformers/models/perception_lm/video_processing_perception_lm.py +1 -0
  985. transformers/models/persimmon/configuration_persimmon.py +28 -23
  986. transformers/models/persimmon/modeling_persimmon.py +45 -43
  987. transformers/models/phi/configuration_phi.py +28 -23
  988. transformers/models/phi/modeling_phi.py +43 -40
  989. transformers/models/phi/modular_phi.py +24 -23
  990. transformers/models/phi3/configuration_phi3.py +33 -28
  991. transformers/models/phi3/modeling_phi3.py +38 -36
  992. transformers/models/phi3/modular_phi3.py +17 -13
  993. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +33 -30
  994. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +9 -7
  995. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  996. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +78 -95
  997. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +80 -98
  998. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +44 -7
  999. transformers/models/phimoe/configuration_phimoe.py +36 -31
  1000. transformers/models/phimoe/modeling_phimoe.py +45 -50
  1001. transformers/models/phimoe/modular_phimoe.py +4 -3
  1002. transformers/models/phobert/tokenization_phobert.py +6 -4
  1003. transformers/models/pix2struct/configuration_pix2struct.py +10 -12
  1004. transformers/models/pix2struct/image_processing_pix2struct.py +19 -15
  1005. transformers/models/pix2struct/image_processing_pix2struct_fast.py +15 -12
  1006. transformers/models/pix2struct/modeling_pix2struct.py +52 -58
  1007. transformers/models/pix2struct/processing_pix2struct.py +30 -5
  1008. transformers/models/pixtral/configuration_pixtral.py +14 -11
  1009. transformers/models/pixtral/image_processing_pixtral.py +28 -26
  1010. transformers/models/pixtral/image_processing_pixtral_fast.py +11 -10
  1011. transformers/models/pixtral/modeling_pixtral.py +34 -28
  1012. transformers/models/pixtral/processing_pixtral.py +53 -21
  1013. transformers/models/plbart/configuration_plbart.py +5 -8
  1014. transformers/models/plbart/modeling_plbart.py +106 -119
  1015. transformers/models/plbart/modular_plbart.py +33 -39
  1016. transformers/models/plbart/tokenization_plbart.py +7 -4
  1017. transformers/models/poolformer/configuration_poolformer.py +1 -0
  1018. transformers/models/poolformer/image_processing_poolformer.py +24 -21
  1019. transformers/models/poolformer/image_processing_poolformer_fast.py +15 -13
  1020. transformers/models/poolformer/modeling_poolformer.py +13 -23
  1021. transformers/models/pop2piano/configuration_pop2piano.py +8 -7
  1022. transformers/models/pop2piano/feature_extraction_pop2piano.py +9 -6
  1023. transformers/models/pop2piano/modeling_pop2piano.py +24 -26
  1024. transformers/models/pop2piano/processing_pop2piano.py +33 -25
  1025. transformers/models/pop2piano/tokenization_pop2piano.py +23 -15
  1026. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +3 -3
  1027. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1028. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +21 -20
  1029. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +13 -16
  1030. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +13 -16
  1031. transformers/models/prophetnet/configuration_prophetnet.py +38 -37
  1032. transformers/models/prophetnet/modeling_prophetnet.py +131 -114
  1033. transformers/models/prophetnet/tokenization_prophetnet.py +16 -14
  1034. transformers/models/pvt/configuration_pvt.py +1 -0
  1035. transformers/models/pvt/image_processing_pvt.py +27 -24
  1036. transformers/models/pvt/image_processing_pvt_fast.py +2 -1
  1037. transformers/models/pvt/modeling_pvt.py +21 -21
  1038. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -2
  1039. transformers/models/pvt_v2/modeling_pvt_v2.py +25 -28
  1040. transformers/models/qwen2/configuration_qwen2.py +25 -32
  1041. transformers/models/qwen2/modeling_qwen2.py +38 -36
  1042. transformers/models/qwen2/modular_qwen2.py +12 -11
  1043. transformers/models/qwen2/tokenization_qwen2.py +23 -12
  1044. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +26 -32
  1045. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +277 -340
  1046. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +211 -278
  1047. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +49 -41
  1048. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +35 -29
  1049. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +148 -203
  1050. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +118 -93
  1051. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +43 -7
  1052. transformers/models/qwen2_audio/configuration_qwen2_audio.py +1 -0
  1053. transformers/models/qwen2_audio/modeling_qwen2_audio.py +40 -40
  1054. transformers/models/qwen2_audio/processing_qwen2_audio.py +42 -13
  1055. transformers/models/qwen2_moe/configuration_qwen2_moe.py +35 -42
  1056. transformers/models/qwen2_moe/modeling_qwen2_moe.py +46 -51
  1057. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -7
  1058. transformers/models/qwen2_vl/configuration_qwen2_vl.py +34 -29
  1059. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +42 -41
  1060. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +15 -12
  1061. transformers/models/qwen2_vl/modeling_qwen2_vl.py +153 -199
  1062. transformers/models/qwen2_vl/processing_qwen2_vl.py +44 -7
  1063. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +18 -38
  1064. transformers/models/qwen3/configuration_qwen3.py +27 -34
  1065. transformers/models/qwen3/modeling_qwen3.py +39 -36
  1066. transformers/models/qwen3/modular_qwen3.py +6 -4
  1067. transformers/models/qwen3_moe/configuration_qwen3_moe.py +32 -39
  1068. transformers/models/qwen3_moe/modeling_qwen3_moe.py +46 -51
  1069. transformers/models/qwen3_moe/modular_qwen3_moe.py +13 -10
  1070. transformers/models/qwen3_next/configuration_qwen3_next.py +35 -45
  1071. transformers/models/qwen3_next/modeling_qwen3_next.py +51 -47
  1072. transformers/models/qwen3_next/modular_qwen3_next.py +35 -34
  1073. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +101 -135
  1074. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +252 -355
  1075. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +196 -250
  1076. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +48 -40
  1077. transformers/models/qwen3_vl/configuration_qwen3_vl.py +29 -27
  1078. transformers/models/qwen3_vl/modeling_qwen3_vl.py +155 -233
  1079. transformers/models/qwen3_vl/modular_qwen3_vl.py +179 -206
  1080. transformers/models/qwen3_vl/processing_qwen3_vl.py +42 -6
  1081. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +12 -10
  1082. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +30 -23
  1083. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +303 -358
  1084. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +124 -87
  1085. transformers/models/rag/configuration_rag.py +15 -6
  1086. transformers/models/rag/modeling_rag.py +130 -127
  1087. transformers/models/rag/retrieval_rag.py +5 -3
  1088. transformers/models/rag/tokenization_rag.py +50 -0
  1089. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +30 -29
  1090. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +42 -53
  1091. transformers/models/reformer/configuration_reformer.py +8 -7
  1092. transformers/models/reformer/modeling_reformer.py +69 -80
  1093. transformers/models/reformer/tokenization_reformer.py +31 -11
  1094. transformers/models/regnet/configuration_regnet.py +1 -0
  1095. transformers/models/regnet/modeling_regnet.py +8 -15
  1096. transformers/models/rembert/configuration_rembert.py +2 -8
  1097. transformers/models/rembert/modeling_rembert.py +111 -121
  1098. transformers/models/rembert/tokenization_rembert.py +12 -2
  1099. transformers/models/resnet/configuration_resnet.py +1 -0
  1100. transformers/models/resnet/modeling_resnet.py +13 -27
  1101. transformers/models/roberta/configuration_roberta.py +3 -11
  1102. transformers/models/roberta/modeling_roberta.py +93 -94
  1103. transformers/models/roberta/modular_roberta.py +58 -58
  1104. transformers/models/roberta/tokenization_roberta.py +29 -17
  1105. transformers/models/roberta/tokenization_roberta_old.py +4 -2
  1106. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +3 -11
  1107. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +93 -94
  1108. transformers/models/roc_bert/configuration_roc_bert.py +2 -8
  1109. transformers/models/roc_bert/modeling_roc_bert.py +121 -122
  1110. transformers/models/roc_bert/tokenization_roc_bert.py +94 -88
  1111. transformers/models/roformer/configuration_roformer.py +3 -13
  1112. transformers/models/roformer/modeling_roformer.py +81 -85
  1113. transformers/models/roformer/tokenization_roformer.py +412 -74
  1114. transformers/models/roformer/tokenization_roformer_fast.py +160 -0
  1115. transformers/models/roformer/tokenization_utils.py +1 -0
  1116. transformers/models/rt_detr/configuration_rt_detr.py +2 -1
  1117. transformers/models/rt_detr/configuration_rt_detr_resnet.py +1 -0
  1118. transformers/models/rt_detr/image_processing_rt_detr.py +55 -54
  1119. transformers/models/rt_detr/image_processing_rt_detr_fast.py +26 -26
  1120. transformers/models/rt_detr/modeling_rt_detr.py +90 -99
  1121. transformers/models/rt_detr/modeling_rt_detr_resnet.py +6 -13
  1122. transformers/models/rt_detr/modular_rt_detr.py +16 -16
  1123. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +4 -6
  1124. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +90 -101
  1125. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +12 -19
  1126. transformers/models/rwkv/configuration_rwkv.py +4 -2
  1127. transformers/models/rwkv/modeling_rwkv.py +32 -31
  1128. transformers/models/sam/configuration_sam.py +1 -3
  1129. transformers/models/sam/image_processing_sam.py +60 -59
  1130. transformers/models/sam/image_processing_sam_fast.py +27 -25
  1131. transformers/models/sam/modeling_sam.py +41 -47
  1132. transformers/models/sam/processing_sam.py +27 -39
  1133. transformers/models/sam2/configuration_sam2.py +3 -2
  1134. transformers/models/sam2/image_processing_sam2_fast.py +15 -14
  1135. transformers/models/sam2/modeling_sam2.py +90 -96
  1136. transformers/models/sam2/modular_sam2.py +91 -86
  1137. transformers/models/sam2/processing_sam2.py +47 -31
  1138. transformers/models/sam2_video/configuration_sam2_video.py +1 -0
  1139. transformers/models/sam2_video/modeling_sam2_video.py +144 -151
  1140. transformers/models/sam2_video/modular_sam2_video.py +104 -101
  1141. transformers/models/sam2_video/processing_sam2_video.py +66 -49
  1142. transformers/models/sam2_video/video_processing_sam2_video.py +4 -1
  1143. transformers/models/sam3/configuration_sam3.py +2 -21
  1144. transformers/models/sam3/image_processing_sam3_fast.py +20 -17
  1145. transformers/models/sam3/modeling_sam3.py +170 -184
  1146. transformers/models/sam3/modular_sam3.py +8 -3
  1147. transformers/models/sam3/processing_sam3.py +52 -37
  1148. transformers/models/sam3_tracker/__init__.py +1 -0
  1149. transformers/models/sam3_tracker/configuration_sam3_tracker.py +3 -1
  1150. transformers/models/sam3_tracker/modeling_sam3_tracker.py +77 -82
  1151. transformers/models/sam3_tracker/modular_sam3_tracker.py +3 -8
  1152. transformers/models/sam3_tracker/processing_sam3_tracker.py +48 -31
  1153. transformers/models/sam3_tracker_video/__init__.py +1 -0
  1154. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +1 -25
  1155. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +122 -135
  1156. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +26 -35
  1157. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +66 -50
  1158. transformers/models/sam3_video/configuration_sam3_video.py +1 -14
  1159. transformers/models/sam3_video/modeling_sam3_video.py +34 -33
  1160. transformers/models/sam3_video/processing_sam3_video.py +46 -26
  1161. transformers/models/sam_hq/__init__.py +1 -1
  1162. transformers/models/sam_hq/configuration_sam_hq.py +1 -3
  1163. transformers/models/sam_hq/modeling_sam_hq.py +69 -74
  1164. transformers/models/sam_hq/modular_sam_hq.py +25 -23
  1165. transformers/models/sam_hq/{processing_sam_hq.py → processing_samhq.py} +29 -41
  1166. transformers/models/seamless_m4t/configuration_seamless_m4t.py +10 -8
  1167. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +11 -8
  1168. transformers/models/seamless_m4t/modeling_seamless_m4t.py +194 -212
  1169. transformers/models/seamless_m4t/processing_seamless_m4t.py +39 -18
  1170. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +77 -40
  1171. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +10 -8
  1172. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +196 -204
  1173. transformers/models/seed_oss/configuration_seed_oss.py +32 -28
  1174. transformers/models/seed_oss/modeling_seed_oss.py +35 -33
  1175. transformers/models/seed_oss/modular_seed_oss.py +4 -3
  1176. transformers/models/segformer/configuration_segformer.py +10 -0
  1177. transformers/models/segformer/image_processing_segformer.py +42 -39
  1178. transformers/models/segformer/image_processing_segformer_fast.py +12 -10
  1179. transformers/models/segformer/modeling_segformer.py +31 -34
  1180. transformers/models/segformer/modular_segformer.py +10 -8
  1181. transformers/models/seggpt/configuration_seggpt.py +1 -0
  1182. transformers/models/seggpt/image_processing_seggpt.py +41 -38
  1183. transformers/models/seggpt/modeling_seggpt.py +38 -50
  1184. transformers/models/sew/configuration_sew.py +2 -4
  1185. transformers/models/sew/modeling_sew.py +36 -38
  1186. transformers/models/sew/modular_sew.py +13 -13
  1187. transformers/models/sew_d/configuration_sew_d.py +2 -4
  1188. transformers/models/sew_d/modeling_sew_d.py +30 -31
  1189. transformers/models/shieldgemma2/configuration_shieldgemma2.py +1 -0
  1190. transformers/models/shieldgemma2/modeling_shieldgemma2.py +17 -16
  1191. transformers/models/shieldgemma2/processing_shieldgemma2.py +5 -3
  1192. transformers/models/siglip/configuration_siglip.py +2 -4
  1193. transformers/models/siglip/image_processing_siglip.py +20 -17
  1194. transformers/models/siglip/image_processing_siglip_fast.py +1 -0
  1195. transformers/models/siglip/modeling_siglip.py +75 -84
  1196. transformers/models/siglip/processing_siglip.py +14 -2
  1197. transformers/models/siglip/tokenization_siglip.py +7 -6
  1198. transformers/models/siglip2/configuration_siglip2.py +2 -5
  1199. transformers/models/siglip2/image_processing_siglip2.py +16 -15
  1200. transformers/models/siglip2/image_processing_siglip2_fast.py +7 -6
  1201. transformers/models/siglip2/modeling_siglip2.py +129 -143
  1202. transformers/models/siglip2/modular_siglip2.py +46 -47
  1203. transformers/models/siglip2/processing_siglip2.py +14 -2
  1204. transformers/models/smollm3/configuration_smollm3.py +32 -29
  1205. transformers/models/smollm3/modeling_smollm3.py +39 -36
  1206. transformers/models/smollm3/modular_smollm3.py +35 -33
  1207. transformers/models/smolvlm/configuration_smolvlm.py +4 -2
  1208. transformers/models/smolvlm/image_processing_smolvlm.py +43 -42
  1209. transformers/models/smolvlm/image_processing_smolvlm_fast.py +15 -41
  1210. transformers/models/smolvlm/modeling_smolvlm.py +94 -126
  1211. transformers/models/smolvlm/modular_smolvlm.py +39 -50
  1212. transformers/models/smolvlm/processing_smolvlm.py +83 -15
  1213. transformers/models/smolvlm/video_processing_smolvlm.py +18 -16
  1214. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +1 -0
  1215. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +27 -26
  1216. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1217. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +13 -10
  1218. transformers/models/speech_to_text/modeling_speech_to_text.py +54 -66
  1219. transformers/models/speech_to_text/processing_speech_to_text.py +30 -4
  1220. transformers/models/speech_to_text/tokenization_speech_to_text.py +6 -5
  1221. transformers/models/speecht5/configuration_speecht5.py +9 -7
  1222. transformers/models/speecht5/feature_extraction_speecht5.py +37 -16
  1223. transformers/models/speecht5/modeling_speecht5.py +175 -213
  1224. transformers/models/speecht5/number_normalizer.py +1 -0
  1225. transformers/models/speecht5/processing_speecht5.py +37 -3
  1226. transformers/models/speecht5/tokenization_speecht5.py +5 -4
  1227. transformers/models/splinter/configuration_splinter.py +7 -6
  1228. transformers/models/splinter/modeling_splinter.py +59 -71
  1229. transformers/models/splinter/tokenization_splinter.py +30 -9
  1230. transformers/models/squeezebert/configuration_squeezebert.py +2 -14
  1231. transformers/models/squeezebert/modeling_squeezebert.py +62 -68
  1232. transformers/models/squeezebert/tokenization_squeezebert.py +1 -0
  1233. transformers/models/stablelm/configuration_stablelm.py +29 -24
  1234. transformers/models/stablelm/modeling_stablelm.py +45 -44
  1235. transformers/models/starcoder2/configuration_starcoder2.py +27 -30
  1236. transformers/models/starcoder2/modeling_starcoder2.py +41 -39
  1237. transformers/models/starcoder2/modular_starcoder2.py +16 -14
  1238. transformers/models/superglue/configuration_superglue.py +3 -7
  1239. transformers/models/superglue/image_processing_superglue.py +15 -15
  1240. transformers/models/superglue/image_processing_superglue_fast.py +10 -9
  1241. transformers/models/superglue/modeling_superglue.py +37 -42
  1242. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1243. transformers/models/superpoint/image_processing_superpoint_fast.py +11 -8
  1244. transformers/models/superpoint/modeling_superpoint.py +16 -18
  1245. transformers/models/swiftformer/configuration_swiftformer.py +1 -0
  1246. transformers/models/swiftformer/modeling_swiftformer.py +14 -18
  1247. transformers/models/swin/configuration_swin.py +1 -0
  1248. transformers/models/swin/modeling_swin.py +86 -86
  1249. transformers/models/swin2sr/configuration_swin2sr.py +1 -0
  1250. transformers/models/swin2sr/image_processing_swin2sr.py +13 -10
  1251. transformers/models/swin2sr/image_processing_swin2sr_fast.py +8 -4
  1252. transformers/models/swin2sr/modeling_swin2sr.py +63 -81
  1253. transformers/models/swinv2/configuration_swinv2.py +1 -0
  1254. transformers/models/swinv2/modeling_swinv2.py +104 -108
  1255. transformers/models/switch_transformers/configuration_switch_transformers.py +7 -11
  1256. transformers/models/switch_transformers/modeling_switch_transformers.py +44 -37
  1257. transformers/models/switch_transformers/modular_switch_transformers.py +41 -34
  1258. transformers/models/t5/configuration_t5.py +8 -14
  1259. transformers/models/t5/modeling_t5.py +92 -88
  1260. transformers/models/t5/tokenization_t5.py +9 -3
  1261. transformers/models/t5gemma/configuration_t5gemma.py +41 -43
  1262. transformers/models/t5gemma/modeling_t5gemma.py +107 -104
  1263. transformers/models/t5gemma/modular_t5gemma.py +120 -124
  1264. transformers/models/t5gemma2/configuration_t5gemma2.py +120 -80
  1265. transformers/models/t5gemma2/modeling_t5gemma2.py +125 -141
  1266. transformers/models/t5gemma2/modular_t5gemma2.py +104 -393
  1267. transformers/models/table_transformer/configuration_table_transformer.py +2 -1
  1268. transformers/models/table_transformer/modeling_table_transformer.py +49 -51
  1269. transformers/models/tapas/configuration_tapas.py +2 -12
  1270. transformers/models/tapas/modeling_tapas.py +67 -68
  1271. transformers/models/tapas/tokenization_tapas.py +153 -115
  1272. transformers/models/textnet/configuration_textnet.py +1 -0
  1273. transformers/models/textnet/image_processing_textnet.py +25 -22
  1274. transformers/models/textnet/image_processing_textnet_fast.py +10 -8
  1275. transformers/models/textnet/modeling_textnet.py +16 -28
  1276. transformers/models/time_series_transformer/configuration_time_series_transformer.py +8 -5
  1277. transformers/models/time_series_transformer/modeling_time_series_transformer.py +81 -83
  1278. transformers/models/timesfm/configuration_timesfm.py +1 -0
  1279. transformers/models/timesfm/modeling_timesfm.py +22 -33
  1280. transformers/models/timesfm/modular_timesfm.py +21 -32
  1281. transformers/models/timesformer/configuration_timesformer.py +1 -0
  1282. transformers/models/timesformer/modeling_timesformer.py +16 -15
  1283. transformers/models/timm_backbone/configuration_timm_backbone.py +1 -0
  1284. transformers/models/timm_backbone/modeling_timm_backbone.py +15 -17
  1285. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -5
  1286. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +5 -4
  1287. transformers/models/timm_wrapper/modeling_timm_wrapper.py +29 -34
  1288. transformers/models/trocr/configuration_trocr.py +8 -11
  1289. transformers/models/trocr/modeling_trocr.py +44 -45
  1290. transformers/models/trocr/processing_trocr.py +25 -5
  1291. transformers/models/tvp/configuration_tvp.py +2 -5
  1292. transformers/models/tvp/image_processing_tvp.py +52 -50
  1293. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1294. transformers/models/tvp/modeling_tvp.py +27 -27
  1295. transformers/models/tvp/processing_tvp.py +14 -2
  1296. transformers/models/udop/configuration_udop.py +7 -16
  1297. transformers/models/udop/modeling_udop.py +73 -71
  1298. transformers/models/udop/processing_udop.py +26 -7
  1299. transformers/models/udop/tokenization_udop.py +105 -84
  1300. transformers/models/umt5/configuration_umt5.py +7 -8
  1301. transformers/models/umt5/modeling_umt5.py +90 -94
  1302. transformers/models/unispeech/configuration_unispeech.py +2 -4
  1303. transformers/models/unispeech/modeling_unispeech.py +49 -51
  1304. transformers/models/unispeech/modular_unispeech.py +22 -22
  1305. transformers/models/unispeech_sat/configuration_unispeech_sat.py +2 -4
  1306. transformers/models/unispeech_sat/modeling_unispeech_sat.py +65 -69
  1307. transformers/models/unispeech_sat/modular_unispeech_sat.py +23 -23
  1308. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1309. transformers/models/univnet/modeling_univnet.py +8 -8
  1310. transformers/models/upernet/configuration_upernet.py +1 -0
  1311. transformers/models/upernet/modeling_upernet.py +13 -11
  1312. transformers/models/vaultgemma/__init__.py +1 -0
  1313. transformers/models/vaultgemma/configuration_vaultgemma.py +33 -29
  1314. transformers/models/vaultgemma/modeling_vaultgemma.py +41 -39
  1315. transformers/models/vaultgemma/modular_vaultgemma.py +31 -29
  1316. transformers/models/video_llama_3/configuration_video_llama_3.py +0 -4
  1317. transformers/models/video_llama_3/image_processing_video_llama_3.py +42 -43
  1318. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +14 -12
  1319. transformers/models/video_llama_3/modeling_video_llama_3.py +109 -157
  1320. transformers/models/video_llama_3/modular_video_llama_3.py +146 -155
  1321. transformers/models/video_llama_3/processing_video_llama_3.py +39 -5
  1322. transformers/models/video_llama_3/video_processing_video_llama_3.py +23 -42
  1323. transformers/models/video_llava/configuration_video_llava.py +1 -4
  1324. transformers/models/video_llava/image_processing_video_llava.py +38 -35
  1325. transformers/models/video_llava/modeling_video_llava.py +146 -146
  1326. transformers/models/video_llava/processing_video_llava.py +78 -38
  1327. transformers/models/video_llava/video_processing_video_llava.py +1 -0
  1328. transformers/models/videomae/configuration_videomae.py +1 -0
  1329. transformers/models/videomae/image_processing_videomae.py +34 -31
  1330. transformers/models/videomae/modeling_videomae.py +17 -14
  1331. transformers/models/videomae/video_processing_videomae.py +1 -0
  1332. transformers/models/vilt/configuration_vilt.py +4 -6
  1333. transformers/models/vilt/image_processing_vilt.py +30 -29
  1334. transformers/models/vilt/image_processing_vilt_fast.py +16 -15
  1335. transformers/models/vilt/modeling_vilt.py +90 -116
  1336. transformers/models/vilt/processing_vilt.py +14 -2
  1337. transformers/models/vipllava/configuration_vipllava.py +1 -4
  1338. transformers/models/vipllava/modeling_vipllava.py +70 -99
  1339. transformers/models/vipllava/modular_vipllava.py +54 -78
  1340. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +1 -0
  1341. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +27 -28
  1342. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +1 -0
  1343. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +41 -46
  1344. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +16 -2
  1345. transformers/models/visual_bert/configuration_visual_bert.py +2 -6
  1346. transformers/models/visual_bert/modeling_visual_bert.py +92 -98
  1347. transformers/models/vit/configuration_vit.py +1 -0
  1348. transformers/models/vit/image_processing_vit.py +22 -19
  1349. transformers/models/vit/image_processing_vit_fast.py +1 -0
  1350. transformers/models/vit/modeling_vit.py +17 -17
  1351. transformers/models/vit_mae/configuration_vit_mae.py +1 -0
  1352. transformers/models/vit_mae/modeling_vit_mae.py +27 -29
  1353. transformers/models/vit_msn/configuration_vit_msn.py +1 -0
  1354. transformers/models/vit_msn/modeling_vit_msn.py +16 -18
  1355. transformers/models/vitdet/configuration_vitdet.py +1 -0
  1356. transformers/models/vitdet/modeling_vitdet.py +14 -14
  1357. transformers/models/vitmatte/configuration_vitmatte.py +5 -2
  1358. transformers/models/vitmatte/image_processing_vitmatte.py +18 -15
  1359. transformers/models/vitmatte/image_processing_vitmatte_fast.py +18 -16
  1360. transformers/models/vitmatte/modeling_vitmatte.py +11 -14
  1361. transformers/models/vitpose/configuration_vitpose.py +7 -4
  1362. transformers/models/vitpose/image_processing_vitpose.py +25 -24
  1363. transformers/models/vitpose/image_processing_vitpose_fast.py +11 -9
  1364. transformers/models/vitpose/modeling_vitpose.py +14 -14
  1365. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +1 -0
  1366. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +10 -8
  1367. transformers/models/vits/configuration_vits.py +1 -4
  1368. transformers/models/vits/modeling_vits.py +42 -44
  1369. transformers/models/vits/tokenization_vits.py +4 -3
  1370. transformers/models/vivit/configuration_vivit.py +1 -0
  1371. transformers/models/vivit/image_processing_vivit.py +39 -36
  1372. transformers/models/vivit/modeling_vivit.py +8 -6
  1373. transformers/models/vjepa2/__init__.py +1 -0
  1374. transformers/models/vjepa2/configuration_vjepa2.py +1 -0
  1375. transformers/models/vjepa2/modeling_vjepa2.py +32 -31
  1376. transformers/models/vjepa2/video_processing_vjepa2.py +1 -0
  1377. transformers/models/voxtral/__init__.py +1 -0
  1378. transformers/models/voxtral/configuration_voxtral.py +2 -0
  1379. transformers/models/voxtral/modeling_voxtral.py +47 -40
  1380. transformers/models/voxtral/modular_voxtral.py +40 -37
  1381. transformers/models/voxtral/processing_voxtral.py +48 -25
  1382. transformers/models/wav2vec2/configuration_wav2vec2.py +2 -4
  1383. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +10 -7
  1384. transformers/models/wav2vec2/modeling_wav2vec2.py +121 -73
  1385. transformers/models/wav2vec2/processing_wav2vec2.py +35 -6
  1386. transformers/models/wav2vec2/tokenization_wav2vec2.py +332 -20
  1387. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +2 -4
  1388. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +62 -70
  1389. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +48 -57
  1390. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +35 -6
  1391. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +2 -4
  1392. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +77 -90
  1393. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +30 -37
  1394. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +17 -16
  1395. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +55 -36
  1396. transformers/models/wavlm/configuration_wavlm.py +2 -4
  1397. transformers/models/wavlm/modeling_wavlm.py +48 -50
  1398. transformers/models/wavlm/modular_wavlm.py +5 -4
  1399. transformers/models/whisper/configuration_whisper.py +5 -6
  1400. transformers/models/whisper/english_normalizer.py +4 -3
  1401. transformers/models/whisper/feature_extraction_whisper.py +24 -9
  1402. transformers/models/whisper/generation_whisper.py +48 -26
  1403. transformers/models/whisper/modeling_whisper.py +73 -79
  1404. transformers/models/whisper/processing_whisper.py +20 -3
  1405. transformers/models/whisper/tokenization_whisper.py +43 -11
  1406. transformers/models/x_clip/configuration_x_clip.py +2 -4
  1407. transformers/models/x_clip/modeling_x_clip.py +93 -96
  1408. transformers/models/x_clip/processing_x_clip.py +14 -2
  1409. transformers/models/xcodec/configuration_xcodec.py +6 -4
  1410. transformers/models/xcodec/modeling_xcodec.py +17 -20
  1411. transformers/models/xglm/configuration_xglm.py +8 -9
  1412. transformers/models/xglm/modeling_xglm.py +55 -60
  1413. transformers/models/xglm/tokenization_xglm.py +11 -3
  1414. transformers/models/xlm/configuration_xlm.py +8 -10
  1415. transformers/models/xlm/modeling_xlm.py +144 -144
  1416. transformers/models/xlm/tokenization_xlm.py +5 -3
  1417. transformers/models/xlm_roberta/configuration_xlm_roberta.py +3 -11
  1418. transformers/models/xlm_roberta/modeling_xlm_roberta.py +194 -195
  1419. transformers/models/xlm_roberta/modular_xlm_roberta.py +53 -50
  1420. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +18 -8
  1421. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +2 -10
  1422. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +93 -94
  1423. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +70 -67
  1424. transformers/models/xlnet/configuration_xlnet.py +12 -3
  1425. transformers/models/xlnet/modeling_xlnet.py +163 -152
  1426. transformers/models/xlnet/tokenization_xlnet.py +9 -2
  1427. transformers/models/xlstm/configuration_xlstm.py +12 -8
  1428. transformers/models/xlstm/modeling_xlstm.py +65 -62
  1429. transformers/models/xmod/configuration_xmod.py +3 -11
  1430. transformers/models/xmod/modeling_xmod.py +110 -108
  1431. transformers/models/yolos/configuration_yolos.py +1 -0
  1432. transformers/models/yolos/image_processing_yolos.py +62 -60
  1433. transformers/models/yolos/image_processing_yolos_fast.py +45 -42
  1434. transformers/models/yolos/modeling_yolos.py +16 -16
  1435. transformers/models/yolos/modular_yolos.py +19 -17
  1436. transformers/models/yoso/configuration_yoso.py +2 -8
  1437. transformers/models/yoso/modeling_yoso.py +63 -70
  1438. transformers/models/zamba/configuration_zamba.py +8 -5
  1439. transformers/models/zamba/modeling_zamba.py +78 -81
  1440. transformers/models/zamba2/configuration_zamba2.py +50 -44
  1441. transformers/models/zamba2/modeling_zamba2.py +97 -97
  1442. transformers/models/zamba2/modular_zamba2.py +48 -46
  1443. transformers/models/zoedepth/configuration_zoedepth.py +2 -1
  1444. transformers/models/zoedepth/image_processing_zoedepth.py +29 -28
  1445. transformers/models/zoedepth/image_processing_zoedepth_fast.py +24 -21
  1446. transformers/models/zoedepth/modeling_zoedepth.py +18 -26
  1447. transformers/pipelines/__init__.py +114 -57
  1448. transformers/pipelines/any_to_any.py +22 -14
  1449. transformers/pipelines/audio_utils.py +2 -1
  1450. transformers/pipelines/automatic_speech_recognition.py +12 -20
  1451. transformers/pipelines/base.py +27 -15
  1452. transformers/{models/pe_audio/processing_pe_audio.py → pipelines/deprecated/__init__.py} +3 -10
  1453. transformers/pipelines/deprecated/text2text_generation.py +408 -0
  1454. transformers/pipelines/document_question_answering.py +2 -4
  1455. transformers/pipelines/image_text_to_text.py +1 -0
  1456. transformers/pipelines/image_to_text.py +229 -0
  1457. transformers/pipelines/question_answering.py +44 -5
  1458. transformers/pipelines/text_classification.py +14 -1
  1459. transformers/pipelines/text_generation.py +1 -1
  1460. transformers/pipelines/text_to_audio.py +2 -2
  1461. transformers/pipelines/token_classification.py +22 -1
  1462. transformers/pipelines/video_classification.py +9 -1
  1463. transformers/pipelines/zero_shot_audio_classification.py +1 -0
  1464. transformers/pipelines/zero_shot_classification.py +6 -0
  1465. transformers/pipelines/zero_shot_image_classification.py +7 -0
  1466. transformers/processing_utils.py +145 -230
  1467. transformers/quantizers/auto.py +4 -2
  1468. transformers/quantizers/base.py +173 -53
  1469. transformers/quantizers/quantizer_aqlm.py +23 -2
  1470. transformers/quantizers/quantizer_auto_round.py +12 -2
  1471. transformers/quantizers/quantizer_awq.py +89 -20
  1472. transformers/quantizers/quantizer_bitnet.py +14 -4
  1473. transformers/quantizers/quantizer_bnb_4bit.py +155 -18
  1474. transformers/quantizers/quantizer_bnb_8bit.py +110 -24
  1475. transformers/quantizers/quantizer_compressed_tensors.py +9 -2
  1476. transformers/quantizers/quantizer_eetq.py +74 -16
  1477. transformers/quantizers/quantizer_fbgemm_fp8.py +138 -38
  1478. transformers/quantizers/quantizer_finegrained_fp8.py +113 -26
  1479. transformers/quantizers/quantizer_fp_quant.py +82 -52
  1480. transformers/quantizers/quantizer_gptq.py +28 -8
  1481. transformers/quantizers/quantizer_higgs.py +60 -42
  1482. transformers/quantizers/quantizer_hqq.py +153 -144
  1483. transformers/quantizers/quantizer_mxfp4.py +194 -14
  1484. transformers/quantizers/quantizer_quanto.py +79 -35
  1485. transformers/quantizers/quantizer_quark.py +18 -36
  1486. transformers/quantizers/quantizer_spqr.py +12 -4
  1487. transformers/quantizers/quantizer_torchao.py +325 -50
  1488. transformers/quantizers/quantizer_vptq.py +27 -4
  1489. transformers/quantizers/quantizers_utils.py +0 -20
  1490. transformers/safetensors_conversion.py +3 -9
  1491. transformers/testing_utils.py +82 -326
  1492. transformers/tokenization_mistral_common.py +903 -568
  1493. transformers/tokenization_utils_base.py +340 -220
  1494. transformers/tokenization_utils_sentencepiece.py +6 -5
  1495. transformers/tokenization_utils_tokenizers.py +113 -226
  1496. transformers/trainer.py +53 -60
  1497. transformers/trainer_callback.py +0 -8
  1498. transformers/trainer_seq2seq.py +1 -5
  1499. transformers/trainer_utils.py +1 -1
  1500. transformers/training_args.py +41 -77
  1501. transformers/utils/__init__.py +4 -8
  1502. transformers/utils/attention_visualizer.py +5 -5
  1503. transformers/utils/auto_docstring.py +37 -599
  1504. transformers/utils/doc.py +36 -4
  1505. transformers/utils/dummy_pt_objects.py +42 -0
  1506. transformers/utils/generic.py +28 -111
  1507. transformers/utils/hub.py +15 -5
  1508. transformers/utils/import_utils.py +32 -165
  1509. transformers/utils/kernel_config.py +19 -74
  1510. transformers/utils/loading_report.py +15 -25
  1511. transformers/utils/quantization_config.py +241 -72
  1512. transformers/video_processing_utils.py +39 -41
  1513. transformers/video_utils.py +22 -18
  1514. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/METADATA +236 -284
  1515. transformers-5.0.0rc0.dist-info/RECORD +1987 -0
  1516. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/WHEEL +1 -1
  1517. transformers/integrations/moe.py +0 -360
  1518. transformers/integrations/quark.py +0 -53
  1519. transformers/loss/loss_lw_detr.py +0 -356
  1520. transformers/models/ernie4_5_vl_moe/__init__.py +0 -31
  1521. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +0 -340
  1522. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +0 -455
  1523. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +0 -231
  1524. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +0 -1936
  1525. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +0 -1925
  1526. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +0 -249
  1527. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +0 -593
  1528. transformers/models/fast_vlm/__init__.py +0 -27
  1529. transformers/models/fast_vlm/configuration_fast_vlm.py +0 -137
  1530. transformers/models/fast_vlm/modeling_fast_vlm.py +0 -432
  1531. transformers/models/fast_vlm/modular_fast_vlm.py +0 -373
  1532. transformers/models/glm4_moe_lite/__init__.py +0 -28
  1533. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +0 -233
  1534. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +0 -740
  1535. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +0 -302
  1536. transformers/models/glm_image/__init__.py +0 -31
  1537. transformers/models/glm_image/configuration_glm_image.py +0 -351
  1538. transformers/models/glm_image/image_processing_glm_image.py +0 -503
  1539. transformers/models/glm_image/image_processing_glm_image_fast.py +0 -294
  1540. transformers/models/glm_image/modeling_glm_image.py +0 -1642
  1541. transformers/models/glm_image/modular_glm_image.py +0 -1531
  1542. transformers/models/glm_image/processing_glm_image.py +0 -217
  1543. transformers/models/glmasr/__init__.py +0 -29
  1544. transformers/models/glmasr/configuration_glmasr.py +0 -196
  1545. transformers/models/glmasr/modeling_glmasr.py +0 -517
  1546. transformers/models/glmasr/modular_glmasr.py +0 -443
  1547. transformers/models/glmasr/processing_glmasr.py +0 -331
  1548. transformers/models/jais2/__init__.py +0 -27
  1549. transformers/models/jais2/configuration_jais2.py +0 -148
  1550. transformers/models/jais2/modeling_jais2.py +0 -484
  1551. transformers/models/jais2/modular_jais2.py +0 -194
  1552. transformers/models/lasr/__init__.py +0 -29
  1553. transformers/models/lasr/configuration_lasr.py +0 -244
  1554. transformers/models/lasr/feature_extraction_lasr.py +0 -275
  1555. transformers/models/lasr/modeling_lasr.py +0 -727
  1556. transformers/models/lasr/modular_lasr.py +0 -574
  1557. transformers/models/lasr/processing_lasr.py +0 -100
  1558. transformers/models/lasr/tokenization_lasr.py +0 -184
  1559. transformers/models/lighton_ocr/__init__.py +0 -28
  1560. transformers/models/lighton_ocr/configuration_lighton_ocr.py +0 -128
  1561. transformers/models/lighton_ocr/modeling_lighton_ocr.py +0 -463
  1562. transformers/models/lighton_ocr/modular_lighton_ocr.py +0 -404
  1563. transformers/models/lighton_ocr/processing_lighton_ocr.py +0 -229
  1564. transformers/models/lw_detr/__init__.py +0 -27
  1565. transformers/models/lw_detr/configuration_lw_detr.py +0 -374
  1566. transformers/models/lw_detr/modeling_lw_detr.py +0 -1702
  1567. transformers/models/lw_detr/modular_lw_detr.py +0 -1615
  1568. transformers/models/minimax_m2/__init__.py +0 -28
  1569. transformers/models/minimax_m2/configuration_minimax_m2.py +0 -188
  1570. transformers/models/minimax_m2/modeling_minimax_m2.py +0 -704
  1571. transformers/models/minimax_m2/modular_minimax_m2.py +0 -346
  1572. transformers/models/paddleocr_vl/__init__.py +0 -31
  1573. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +0 -335
  1574. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +0 -503
  1575. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +0 -209
  1576. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +0 -1683
  1577. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +0 -1380
  1578. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +0 -133
  1579. transformers/models/pe_audio/__init__.py +0 -29
  1580. transformers/models/pe_audio/configuration_pe_audio.py +0 -204
  1581. transformers/models/pe_audio/feature_extraction_pe_audio.py +0 -160
  1582. transformers/models/pe_audio/modeling_pe_audio.py +0 -819
  1583. transformers/models/pe_audio/modular_pe_audio.py +0 -298
  1584. transformers/models/pe_audio_video/__init__.py +0 -28
  1585. transformers/models/pe_audio_video/configuration_pe_audio_video.py +0 -223
  1586. transformers/models/pe_audio_video/modeling_pe_audio_video.py +0 -971
  1587. transformers/models/pe_audio_video/modular_pe_audio_video.py +0 -763
  1588. transformers/models/pe_video/__init__.py +0 -29
  1589. transformers/models/pe_video/configuration_pe_video.py +0 -209
  1590. transformers/models/pe_video/modeling_pe_video.py +0 -647
  1591. transformers/models/pe_video/modular_pe_video.py +0 -231
  1592. transformers/models/pe_video/processing_pe_video.py +0 -10
  1593. transformers/models/pe_video/video_processing_pe_video.py +0 -64
  1594. transformers/models/pixio/__init__.py +0 -29
  1595. transformers/models/pixio/configuration_pixio.py +0 -150
  1596. transformers/models/pixio/modeling_pixio.py +0 -507
  1597. transformers/models/pixio/modular_pixio.py +0 -403
  1598. transformers/models/solar_open/__init__.py +0 -27
  1599. transformers/models/solar_open/configuration_solar_open.py +0 -184
  1600. transformers/models/solar_open/modeling_solar_open.py +0 -642
  1601. transformers/models/solar_open/modular_solar_open.py +0 -224
  1602. transformers/trainer_jit_checkpoint.py +0 -125
  1603. transformers-5.0.0.dist-info/RECORD +0 -2068
  1604. {transformers-5.0.0.dist-info/licenses → transformers-5.0.0rc0.dist-info}/LICENSE +0 -0
  1605. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/entry_points.txt +0 -0
  1606. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
@@ -15,6 +16,7 @@
 import collections
 import copy
 import functools
+import gc
 import importlib.metadata
 import inspect
 import json
@@ -24,18 +26,17 @@ import sys
 import warnings
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Iterator
+from collections.abc import Callable, Sequence
 from contextlib import contextmanager
-from dataclasses import dataclass, field, replace
 from enum import Enum
 from functools import partial, wraps
 from itertools import cycle
 from threading import Thread
-from typing import Optional, TypeVar, get_type_hints
+from typing import Optional, TypeVar, Union, get_type_hints
 from zipfile import is_zipfile

 import torch
-from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from packaging import version
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
@@ -62,8 +63,7 @@ from .integrations.accelerate import (
     accelerate_dispatch,
     check_and_set_device_map,
     expand_device_map,
-    get_device,
-    load_offloaded_parameter,
+    init_empty_weights,
 )
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward
@@ -85,8 +85,7 @@ from .integrations.tensor_parallel import (
     verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
-from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
-from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from .modeling_flash_attention_utils import lazy_import_flash_attention
 from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer
@@ -94,6 +93,7 @@ from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
     DUMMY_INPUTS,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -107,12 +107,10 @@ from .utils import (
     copy_func,
     has_file,
     is_accelerate_available,
-    is_bitsandbytes_available,
-    is_env_variable_true,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
-    is_grouped_mm_available,
     is_kernels_available,
+    is_offline_mode,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
     is_torch_mlu_available,
@@ -120,7 +118,7 @@ from .utils import (
     is_torch_xpu_available,
     logging,
 )
-from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder, is_flash_attention_requested
+from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
 from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     is_huggingface_hub_greater_or_equal,
@@ -134,6 +132,7 @@ from .utils.quantization_config import QuantizationMethod
 if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import extract_model_from_parallel
+    from accelerate.utils.modeling import get_state_dict_from_offload


 _torch_distributed_available = torch.distributed.is_available()
@@ -155,63 +154,62 @@ logger = logging.get_logger(__name__)
155
154
  XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
156
155
  XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
157
156
  SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
157
+ _init_weights = True
158
158
  _is_quantized = False
159
159
  _is_ds_init_called = False
160
160
 
161
- # Mapping from flash attention implementations to their kernel fallback repositories
162
- FLASH_ATTN_KERNEL_FALLBACK = {
163
- "flash_attention_2": "kernels-community/flash-attn2",
164
- "flash_attention_3": "kernels-community/vllm-flash-attn3",
165
- }
166
-
167
161
 
168
- @dataclass(frozen=True)
169
- class LoadStateDictConfig:
170
- """
171
- Config for loading weights. This allows bundling arguments that are just
172
- passed around.
173
- """
162
+ def is_local_dist_rank_0():
163
+ return (
164
+ torch.distributed.is_available()
165
+ and torch.distributed.is_initialized()
166
+ and int(os.environ.get("LOCAL_RANK", "-1")) == 0
167
+ )
174
168
 
175
- pretrained_model_name_or_path: str | None = None
176
- download_kwargs: DownloadKwargs | None = field(default_factory=DownloadKwargs)
177
- use_safetensors: bool = True
178
- ignore_mismatched_sizes: bool = False
179
- sharded_metadata: dict | None = None
180
- device_map: dict | None = None
181
- disk_offload_folder: str | None = None
182
- offload_buffers: bool = False
183
- dtype: torch.dtype | None = None
184
- hf_quantizer: HfQuantizer | None = None
185
- device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
186
- weights_only: bool = True
187
- weight_mapping: list[WeightConverter | WeightRenaming] | None = None
188
169
 
189
- @property
190
- def is_quantized(self) -> bool:
191
- return self.hf_quantizer is not None
170
+ TORCH_INIT_FUNCTIONS = {
171
+ "uniform_": nn.init.uniform_,
172
+ "normal_": nn.init.normal_,
173
+ "trunc_normal_": nn.init.trunc_normal_,
174
+ "constant_": nn.init.constant_,
175
+ "xavier_uniform_": nn.init.xavier_uniform_,
176
+ "xavier_normal_": nn.init.xavier_normal_,
177
+ "kaiming_uniform_": nn.init.kaiming_uniform_,
178
+ "kaiming_normal_": nn.init.kaiming_normal_,
179
+ "uniform": nn.init.uniform,
180
+ "normal": nn.init.normal,
181
+ "xavier_uniform": nn.init.xavier_uniform,
182
+ "xavier_normal": nn.init.xavier_normal,
183
+ "kaiming_uniform": nn.init.kaiming_uniform,
184
+ "kaiming_normal": nn.init.kaiming_normal,
185
+ "orthogonal_": nn.init.orthogonal_,
186
+ }
192
187
 
193
188
 
194
- @dataclass
195
- class LoadStateDictInfo:
189
+ @contextmanager
190
+ def no_init_weights():
196
191
  """
197
- Return container for state-dict loading results and diagnostics.
198
- This simplifies the code a bit.
192
+ Context manager to globally disable weight initialization to speed up loading large models.
199
193
  """
194
+ global _init_weights
195
+ old_init_weights = _init_weights
200
196
 
201
- missing_keys: set[str]
202
- unexpected_keys: set[str]
203
- mismatched_keys: set[tuple[str, torch.Size]]
204
- disk_offload_index: dict[str, str] | None
205
- error_msgs: list[str]
206
- conversion_errors: set[str]
197
+ _init_weights = False
207
198
 
199
+ def _skip_init(*args, **kwargs):
200
+ pass
208
201
 
209
- def is_local_dist_rank_0():
210
- return (
211
- torch.distributed.is_available()
212
- and torch.distributed.is_initialized()
213
- and int(os.environ.get("LOCAL_RANK", "-1")) == 0
214
- )
202
+ # Replace the init functions with a no-op (the originals are kept in TORCH_INIT_FUNCTIONS)
203
+ for name, init_func in TORCH_INIT_FUNCTIONS.items():
204
+ setattr(torch.nn.init, name, _skip_init)
205
+
206
+ try:
207
+ yield
208
+ finally:
209
+ _init_weights = old_init_weights
210
+ # Restore the original initialization functions
211
+ for name, init_func in TORCH_INIT_FUNCTIONS.items():
212
+ setattr(torch.nn.init, name, init_func)
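# Illustrative sketch of the context manager above: inside `no_init_weights`, the torch init
# functions listed in TORCH_INIT_FUNCTIONS become no-ops, so constructing a module skips the
# expensive random initialization. `DemoBlock` is a made-up module used purely for illustration.
class DemoBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(1024, 1024)
        nn.init.normal_(self.proj.weight, std=0.02)  # a no-op while the context is active

with no_init_weights():
    block = DemoBlock()  # weights are left as allocated; real values are expected to be loaded afterwards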
215
213
 
216
214
 
217
215
  @contextmanager
@@ -237,28 +235,23 @@ def set_zero3_state():
237
235
  _is_ds_init_called = False
238
236
 
239
237
 
240
- @contextmanager
241
- def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
238
+ def restore_default_dtype(func):
242
239
  """
243
- Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
244
- If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
240
+ Decorator to restore the default torch dtype at the end of the function.
241
+ Serves as a backup in case calling the function raises an error after
242
+ the function has changed the default dtype but before it could
243
+ restore it.
245
244
  """
246
- # Just a more helping error before we set `torch.set_default_dtype` later on which would crash in this case
247
- if not dtype.is_floating_point:
248
- if model_class_name is not None:
249
- error_message = (
250
- f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
251
- )
252
- else:
253
- error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
254
- raise ValueError(error_message)
255
245
 
256
- original_dtype = torch.get_default_dtype()
257
- try:
258
- torch.set_default_dtype(dtype)
259
- yield
260
- finally:
261
- torch.set_default_dtype(original_dtype)
246
+ @wraps(func)
247
+ def _wrapper(*args, **kwargs):
248
+ old_dtype = torch.get_default_dtype()
249
+ try:
250
+ return func(*args, **kwargs)
251
+ finally:
252
+ torch.set_default_dtype(old_dtype)
253
+
254
+ return _wrapper
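# Minimal sketch of the decorator above: `build_in_half_precision` is a hypothetical helper that
# switches the default dtype; even if it raised midway, the decorator would restore the old default.
@restore_default_dtype
def build_in_half_precision():
    torch.set_default_dtype(torch.float16)
    return nn.Linear(8, 8)  # allocated in fp16 because of the temporary default dtype

layer = build_in_half_precision()
# torch.get_default_dtype() is back to whatever it was before the call (usually torch.float32)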
262
255
 
263
256
 
264
257
  def get_torch_context_manager_or_global_device():
@@ -286,9 +279,7 @@ def get_state_dict_dtype(state_dict):
286
279
  return t.dtype
287
280
 
288
281
  # if no floating dtype was found return whatever the first dtype is
289
- if len(state_dict) == 0:
290
- return torch.float32
291
- return next(iter(state_dict.values())).dtype
282
+ return next(iter(state_dict.values())).dtype
292
283
 
293
284
 
294
285
  str_to_torch_dtype = {
@@ -314,7 +305,7 @@ if is_torch_greater_or_equal("2.3.0"):
314
305
 
315
306
 
316
307
  def load_state_dict(
317
- checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
308
+ checkpoint_file: Union[str, os.PathLike], map_location: Union[str, torch.device] = "cpu", weights_only: bool = True
318
309
  ) -> dict[str, torch.Tensor]:
319
310
  """
320
311
  Reads a `safetensors` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
@@ -414,97 +405,14 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]
414
405
  return shared_tensors, identical
415
406
 
416
407
 
417
- def remove_tied_weights_from_state_dict(
418
- state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
419
- ) -> dict[str, torch.Tensor]:
420
- """
421
- Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
422
- will expect when reloading (even if we know tie weights symmetrically, it's better to keep the intended one).
423
- This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
424
- """
425
- # To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
426
- # of the Tensors themselves -> we are guaranteed to find all the actual tied weights
427
- ptrs = collections.defaultdict(list)
428
- for name, tensor in state_dict.items():
429
- if not isinstance(tensor, torch.Tensor):
430
- # Sometimes in the state_dict we have non-tensor objects.
431
- # e.g. in bitsandbytes we have some `str` objects in the state_dict
432
- # In the non-tensor case, fall back to the pointer of the object itself
433
- ptrs[id(tensor)].append(name)
434
-
435
- elif tensor.device.type == "meta":
436
- # In offloaded cases, there may be meta tensors in the state_dict.
437
- # For these cases, key by the pointer of the original tensor object
438
- # (state_dict tensors are detached and therefore no longer shared)
439
- tensor = model.get_parameter(name)
440
- ptrs[id(tensor)].append(name)
441
-
442
- else:
443
- ptrs[id_tensor_storage(tensor)].append(name)
444
-
445
- shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
446
-
447
- # Recursively descend to find tied weight keys
448
- all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
449
- error_names = []
450
- to_delete_names = set()
451
- # Remove the keys that are declared as known duplicates on load. This ensures that the name which is
453
- # kept stays consistent
453
- if all_potential_tied_weights_keys is not None:
454
- for names in shared_ptrs.values():
455
- found = 0
456
- for name in sorted(names):
457
- matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
458
- if matches_pattern and name in state_dict:
459
- found += 1
460
- if found < len(names):
461
- to_delete_names.add(name)
462
- # We are entering a place where the weights and the transformers configuration do NOT match.
463
- shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
464
- # Those are actually tensor sharing but disjoint from each other, we can safely clone them
465
- # The reloaded model won't have the same property, but it shouldn't matter in any meaningful way.
466
- for name in disjoint_names:
467
- state_dict[name] = state_dict[name].clone()
468
-
469
- # When not all duplicates have been cleaned, still remove those keys, but emit a clear warning.
470
- # If the tensors were tied at runtime, `from_pretrained` will not get the removed key back,
471
- # leading to a randomly initialized tensor. A warning will be shown during reload (if applicable),
472
- # but since the file is not necessarily compatible with the config, it is better to also warn
473
- # clearly here.
474
- shared_names, identical_names = _find_identical(shared_names, state_dict)
475
- # delete tensors that have identical storage
476
- for inames in identical_names:
477
- known = inames.intersection(to_delete_names)
478
- for name in known:
479
- del state_dict[name]
480
- unknown = inames.difference(to_delete_names)
481
- if len(unknown) > 1:
482
- error_names.append(unknown)
483
-
484
- if shared_names:
485
- error_names.extend(shared_names)
486
-
487
- if len(error_names) > 0:
488
- raise RuntimeError(
489
- f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
490
- f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
491
- "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
492
- )
493
-
494
- return state_dict
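# Rough illustration of the aliasing handled above: both names point at the same storage, and
# safetensors refuses to serialize aliased tensors, so only one of the two keys can be written
# out (the other one is re-tied when the checkpoint is loaded back). The names are made up.
shared = torch.empty(10, 4)
toy_state_dict = {"model.embed_tokens.weight": shared, "lm_head.weight": shared}
assert id_tensor_storage(toy_state_dict["lm_head.weight"]) == id_tensor_storage(
    toy_state_dict["model.embed_tokens.weight"]
)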
495
-
496
-
497
408
  def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
498
- """Cast a single parameter or buffer `param_name` into the `model`, with value `tensor`."""
499
- parent, param_type = get_module_from_name(model, param_name)
500
- if param_type in parent._parameters and not isinstance(tensor, nn.Parameter):
501
- tensor = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
502
- # We need to use setattr here, as we set non-persistent buffers as well with this function (`load_state_dict`
503
- # does not allow to do it)
504
- setattr(parent, param_type, tensor)
409
+ """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
410
+ module, param_type = get_module_from_name(model, param_name)
411
+ # This will catch a potential shape mismatch if the check was skipped earlier
412
+ module.load_state_dict({param_type: tensor}, strict=False, assign=True)
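# Quick sketch with a bare nn.Module standing in for a PreTrainedModel: the dotted name is resolved
# to the owning sub-module, and the tensor is assigned in place of the existing parameter.
toy = nn.Sequential(nn.Linear(4, 4))
_load_parameter_into_model(toy, "0.weight", torch.zeros(4, 4))
assert torch.count_nonzero(toy[0].weight) == 0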
505
413
 
506
414
 
507
- def _add_variant(weights_name: str, variant: str | None = None) -> str:
415
+ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
508
416
  if variant is not None:
509
417
  path, name = weights_name.rsplit(".", 1)
510
418
  weights_name = f"{path}.{variant}.{name}"
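        # e.g. _add_variant("model.safetensors", variant="fp16") -> "model.fp16.safetensors"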
@@ -512,20 +420,19 @@ def _add_variant(weights_name: str, variant: str | None = None) -> str:
512
420
 
513
421
 
514
422
  def _get_resolved_checkpoint_files(
515
- pretrained_model_name_or_path: str | os.PathLike | None,
516
- variant: str | None,
517
- gguf_file: str | None,
518
- use_safetensors: bool | None,
519
- user_agent: dict | None,
423
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
424
+ variant: Optional[str],
425
+ gguf_file: Optional[str],
426
+ use_safetensors: Optional[bool],
427
+ download_kwargs: DownloadKwargs,
428
+ user_agent: dict,
520
429
  is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in
521
- transformers_explicit_filename: str | None = None,
522
- download_kwargs: DownloadKwargs | None = None,
523
- ) -> tuple[list[str] | None, dict | None]:
430
+ transformers_explicit_filename: Optional[str] = None,
431
+ ) -> tuple[Optional[list[str]], Optional[dict]]:
524
432
  """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
525
433
  checkpoints are sharded.
526
434
  This function will download the data if necessary.
527
435
  """
528
- download_kwargs = download_kwargs or DownloadKwargs()
529
436
  cache_dir = download_kwargs.get("cache_dir")
530
437
  force_download = download_kwargs.get("force_download", False)
531
438
  proxies = download_kwargs.get("proxies")
@@ -538,19 +445,17 @@ def _get_resolved_checkpoint_files(
538
445
  if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
539
446
  ".safetensors.index.json"
540
447
  ):
541
- if transformers_explicit_filename != "adapter_model.bin":
542
- raise ValueError(
543
- "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
544
- "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
545
- f"{transformers_explicit_filename}"
546
- )
448
+ raise ValueError(
449
+ "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
450
+ "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
451
+ f"{transformers_explicit_filename}"
452
+ )
547
453
 
548
454
  is_sharded = False
549
455
 
550
456
  if pretrained_model_name_or_path is not None and gguf_file is None:
551
457
  pretrained_model_name_or_path = str(pretrained_model_name_or_path)
552
458
  is_local = os.path.isdir(pretrained_model_name_or_path)
553
- # If the path is a local folder (but not in the HF_HOME cache, even if it's technically local)
554
459
  if is_local:
555
460
  if transformers_explicit_filename is not None:
556
461
  # If the filename is explicitly defined, load this by default.
@@ -609,38 +514,25 @@ def _get_resolved_checkpoint_files(
609
514
  else:
610
515
  filename = _add_variant(WEIGHTS_NAME, variant)
611
516
 
612
- # Prepare set of kwargs for hub functions
613
- has_file_kwargs = {
614
- "revision": revision,
615
- "proxies": proxies,
616
- "token": token,
617
- "cache_dir": cache_dir,
618
- "local_files_only": local_files_only,
619
- }
620
- cached_file_kwargs = {
621
- "force_download": force_download,
622
- "user_agent": user_agent,
623
- "subfolder": subfolder,
624
- "_raise_exceptions_for_gated_repo": False,
625
- "_raise_exceptions_for_missing_entries": False,
626
- "_commit_hash": commit_hash,
627
- **has_file_kwargs,
628
- }
629
- can_auto_convert = (
630
- not is_offline_mode() # for obvious reasons
631
- # If we are in a CI environment or in a pytest run, we prevent the conversion
632
- and not is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
633
- and not is_remote_code # converter bot does not work on remote code
634
- and subfolder == "" # converter bot does not work on subfolders
635
- )
636
-
637
517
  try:
638
518
  # Load from URL or cache if already cached
639
- # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
640
- # result when internet is up, the repo and revision exist, but the file does not.
519
+ cached_file_kwargs = {
520
+ "cache_dir": cache_dir,
521
+ "force_download": force_download,
522
+ "proxies": proxies,
523
+ "local_files_only": local_files_only,
524
+ "token": token,
525
+ "user_agent": user_agent,
526
+ "revision": revision,
527
+ "subfolder": subfolder,
528
+ "_raise_exceptions_for_gated_repo": False,
529
+ "_raise_exceptions_for_missing_entries": False,
530
+ "_commit_hash": commit_hash,
531
+ }
641
532
  resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
642
533
 
643
- # Try safetensors files first if not already found
534
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
535
+ # result when internet is up, the repo and revision exist, but the file does not.
644
536
  if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
645
537
  # Maybe the checkpoint is sharded, we try to grab the index name in this case.
646
538
  resolved_archive_file = cached_file(
@@ -651,7 +543,7 @@ def _get_resolved_checkpoint_files(
651
543
  if resolved_archive_file is not None:
652
544
  is_sharded = True
653
545
  elif use_safetensors:
654
- if revision == "main" and can_auto_convert:
546
+ if revision == "main" and not is_offline_mode():
655
547
  resolved_archive_file, revision, is_sharded = auto_conversion(
656
548
  pretrained_model_name_or_path, **cached_file_kwargs
657
549
  )
@@ -660,7 +552,8 @@ def _get_resolved_checkpoint_files(
660
552
  raise OSError(
661
553
  f"{pretrained_model_name_or_path} does not appear to have a file named"
662
554
  f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
663
- "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
555
+ "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
556
+ "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
664
557
  )
665
558
  else:
666
559
  # This repo has no safetensors file of any kind, we switch to PyTorch.
@@ -668,8 +561,6 @@ def _get_resolved_checkpoint_files(
668
561
  resolved_archive_file = cached_file(
669
562
  pretrained_model_name_or_path, filename, **cached_file_kwargs
670
563
  )
671
-
672
- # Then try `.bin` files
673
564
  if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
674
565
  # Maybe the checkpoint is sharded, we try to grab the index name in this case.
675
566
  resolved_archive_file = cached_file(
@@ -679,38 +570,67 @@ def _get_resolved_checkpoint_files(
679
570
  )
680
571
  if resolved_archive_file is not None:
681
572
  is_sharded = True
682
-
683
- # If we have a match, but it's `.bin` format, try to launch safetensors conversion for next time
684
- if resolved_archive_file is not None:
685
- safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
686
- if (
687
- filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
688
- and not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
689
- and can_auto_convert
690
- ):
691
- Thread(
692
- target=auto_conversion,
693
- args=(pretrained_model_name_or_path,),
694
- kwargs={"ignore_errors_during_conversion": False, **cached_file_kwargs},
695
- name="Thread-auto_conversion",
696
- ).start()
697
-
698
- # If no match, raise appropriate errors
699
- else:
700
- # Otherwise, no PyTorch file was found
701
- if variant is not None and has_file(
702
- pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
703
- ):
704
- raise OSError(
705
- f"{pretrained_model_name_or_path} does not appear to have a file named"
706
- f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
707
- f" {variant}. Use `variant=None` to load this model from those weights."
708
- )
573
+ if not local_files_only and not is_offline_mode():
574
+ if resolved_archive_file is not None:
575
+ # In a CI environment (CircleCI / GitHub Actions workflow runs) or in a pytest run,
576
+ # we set `DISABLE_SAFETENSORS_CONVERSION=true` to prevent the conversion.
577
+ if (
578
+ filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
579
+ and os.getenv("DISABLE_SAFETENSORS_CONVERSION", None) != "true"
580
+ ):
581
+ # If the PyTorch file was found, check if there is a safetensors file on the repository
582
+ # If there is no safetensors file on the repositories, start an auto conversion
583
+ safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
584
+ has_file_kwargs = {
585
+ "revision": revision,
586
+ "proxies": proxies,
587
+ "token": token,
588
+ "cache_dir": cache_dir,
589
+ "local_files_only": local_files_only,
590
+ }
591
+ cached_file_kwargs = {
592
+ "cache_dir": cache_dir,
593
+ "force_download": force_download,
594
+ "local_files_only": local_files_only,
595
+ "user_agent": user_agent,
596
+ "subfolder": subfolder,
597
+ "_raise_exceptions_for_gated_repo": False,
598
+ "_raise_exceptions_for_missing_entries": False,
599
+ "_commit_hash": commit_hash,
600
+ **has_file_kwargs,
601
+ }
602
+ if (
603
+ not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
604
+ and not is_remote_code
605
+ ):
606
+ Thread(
607
+ target=auto_conversion,
608
+ args=(pretrained_model_name_or_path,),
609
+ kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
610
+ name="Thread-auto_conversion",
611
+ ).start()
709
612
  else:
710
- raise OSError(
711
- f"{pretrained_model_name_or_path} does not appear to have a file named"
712
- f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
713
- )
613
+ # Otherwise, no PyTorch file was found
614
+ has_file_kwargs = {
615
+ "revision": revision,
616
+ "proxies": proxies,
617
+ "token": token,
618
+ "cache_dir": cache_dir,
619
+ "local_files_only": local_files_only,
620
+ }
621
+ if variant is not None and has_file(
622
+ pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
623
+ ):
624
+ raise OSError(
625
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
626
+ f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
627
+ f" {variant}. Use `variant=None` to load this model from those weights."
628
+ )
629
+ else:
630
+ raise OSError(
631
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
632
+ f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
633
+ )
714
634
 
715
635
  except OSError:
716
636
  # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
@@ -777,20 +697,22 @@ def _get_resolved_checkpoint_files(
777
697
 
778
698
 
779
699
  def _get_dtype(
780
- dtype: str | torch.dtype | dict | None,
781
- checkpoint_files: list[str] | None,
700
+ cls,
701
+ dtype: Optional[Union[str, torch.dtype, dict]],
702
+ checkpoint_files: Optional[list[str]],
782
703
  config: PreTrainedConfig,
783
- sharded_metadata: dict | None,
784
- state_dict: dict | None,
704
+ sharded_metadata: Optional[dict],
705
+ state_dict: Optional[dict],
785
706
  weights_only: bool,
786
- hf_quantizer: HfQuantizer | None = None,
787
- ) -> tuple[PreTrainedConfig, torch.dtype]:
707
+ ) -> tuple[PreTrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
788
708
  """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
789
709
  inferred dtype. We do the following:
790
- 1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
791
- its first weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
792
- 2. Else, use the dtype provided as a dict or str
710
+ 1. If dtype is not None, we use that dtype
711
+ 2. If dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
712
+ weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
713
+ We may also have config.dtype available, but we won't rely on it until v5.
793
714
  """
715
+ dtype_orig = None
794
716
  is_sharded = sharded_metadata is not None
795
717
 
796
718
  if dtype is not None:
@@ -815,46 +737,43 @@ def _get_dtype(
815
737
  )
816
738
  elif hasattr(torch, dtype):
817
739
  dtype = getattr(torch, dtype)
818
- else:
819
- raise ValueError(
820
- "`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
821
- )
822
-
823
- # cast it to a proper `torch.dtype` object
824
- dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
825
- elif not isinstance(dtype, (dict, torch.dtype)):
740
+ config.dtype = dtype
741
+ for sub_config_key in config.sub_configs:
742
+ if (sub_config := getattr(config, sub_config_key)) is not None:
743
+ sub_config.dtype = dtype
744
+ elif isinstance(dtype, torch.dtype):
745
+ config.dtype = dtype
746
+ for sub_config_key in config.sub_configs:
747
+ if (sub_config := getattr(config, sub_config_key)) is not None:
748
+ sub_config.dtype = dtype
749
+ elif isinstance(dtype, dict):
750
+ for key, curr_dtype in dtype.items():
751
+ if hasattr(config, key):
752
+ value = getattr(config, key)
753
+ curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
754
+ value.dtype = curr_dtype
755
+ # main torch dtype for modules that aren't part of any sub-config
756
+ dtype = dtype.get("")
757
+ dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
758
+ config.dtype = dtype
759
+ if dtype is None:
760
+ dtype = torch.float32
761
+ else:
826
762
  raise ValueError(
827
763
  f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
828
764
  f"for each sub-config in composite configs, but received {dtype}"
829
765
  )
830
- else:
831
- # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
832
- dtype = torch.get_default_dtype()
833
-
834
- if hf_quantizer is not None:
835
- hf_quantizer.update_dtype(dtype)
836
-
837
- # Get the main dtype
838
- if isinstance(dtype, dict):
839
- main_dtype = dtype.get("", torch.get_default_dtype())
840
- main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
841
-
842
- logger.warning_once(
843
- "Using different dtypes per module is deprecated and will be removed in future versions "
844
- "Setting different dtypes per backbone model might cause device errors downstream, therefore "
845
- f"setting the dtype={main_dtype} for all modules."
846
- )
847
766
 
767
+ dtype_orig = cls._set_default_dtype(dtype)
848
768
  else:
849
- main_dtype = dtype
850
-
851
- # Set it on the config and subconfigs
852
- config.dtype = main_dtype
853
- for sub_config_key in config.sub_configs:
854
- if (sub_config := getattr(config, sub_config_key)) is not None:
855
- sub_config.dtype = main_dtype
769
+ # set fp32 as the default dtype for BC
770
+ default_dtype = torch.get_default_dtype()
771
+ config.dtype = default_dtype
772
+ for key in config.sub_configs:
773
+ if (sub_config := getattr(config, key)) is not None:
774
+ sub_config.dtype = default_dtype
856
775
 
857
- return config, main_dtype
776
+ return config, dtype, dtype_orig
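# Hedged sketch of the dtype options resolved above (the checkpoint name is an assumption):
# model = AutoModelForCausalLM.from_pretrained("org/checkpoint", dtype="auto")           # infer from the weights
# model = AutoModelForCausalLM.from_pretrained("org/checkpoint", dtype=torch.bfloat16)   # explicit dtype
# model = AutoModelForCausalLM.from_pretrained(
#     "org/checkpoint", dtype={"text_config": torch.bfloat16, "": torch.float32}
# )  # per sub-config dict, "" being the fallback for modules outside any sub-config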
858
777
 
859
778
 
860
779
  class PipelineParallel(Enum):
@@ -905,8 +824,13 @@ class ModuleUtilsMixin:
905
824
  return encoder_extended_attention_mask
906
825
 
907
826
  @staticmethod
908
- def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
909
- device = attention_mask.device
827
+ def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
828
+ if device is not None:
829
+ warnings.warn(
830
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
831
+ )
832
+ else:
833
+ device = attention_mask.device
910
834
  batch_size, seq_length = input_shape
911
835
  seq_ids = torch.arange(seq_length, device=device)
912
836
  causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -930,7 +854,8 @@ class ModuleUtilsMixin:
930
854
  self,
931
855
  attention_mask: Tensor,
932
856
  input_shape: tuple[int, ...],
933
- dtype: torch.dtype | None = None,
857
+ device: Optional[torch.device] = None,
858
+ dtype: Optional[torch.dtype] = None,
934
859
  ) -> Tensor:
935
860
  """
936
861
  Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -947,6 +872,12 @@ class ModuleUtilsMixin:
947
872
  if dtype is None:
948
873
  dtype = self.dtype
949
874
 
875
+ if not (attention_mask.dim() == 2 and self.config.is_decoder):
876
+ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
877
+ if device is not None:
878
+ warnings.warn(
879
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
880
+ )
950
881
  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
951
882
  # ourselves in which case we just need to make it broadcastable to all heads.
952
883
  if attention_mask.dim() == 3:
@@ -955,9 +886,9 @@ class ModuleUtilsMixin:
955
886
  # Provided a padding mask of dimensions [batch_size, seq_length]
956
887
  # - if the model is a decoder, apply a causal mask in addition to the padding mask
957
888
  # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
958
- if getattr(self.config, "is_decoder", None):
889
+ if self.config.is_decoder:
959
890
  extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
960
- input_shape, attention_mask
891
+ input_shape, attention_mask, device
961
892
  )
962
893
  else:
963
894
  extended_attention_mask = attention_mask[:, None, None, :]
@@ -1038,52 +969,54 @@ class EmbeddingAccessMixin:
1038
969
  `nn.Module`: A torch module mapping vocabulary to hidden states.
1039
970
  """
1040
971
 
972
+ # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
973
+ # for most NLP models), and if so, return it.
974
+
1041
975
  name = getattr(self, "_input_embed_layer", "embed_tokens")
1042
976
 
1043
- # 1) Direct attribute (most NLP models).
1044
977
  if (default_embedding := getattr(self, name, None)) is not None:
1045
978
  return default_embedding
1046
- # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
1047
- if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
1048
- return getattr(self.embeddings, name)
1049
- # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
1050
- if hasattr(self, "model") and hasattr(self.model, name):
1051
- return getattr(self.model, name)
979
+ # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1052
980
 
1053
- if hasattr(self, "base_model"):
1054
- base_model = self.base_model
1055
- if base_model is not None and base_model is not self:
1056
- return base_model.get_input_embeddings()
981
+ if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
982
+ return self.model.embed_tokens
1057
983
 
1058
- raise NotImplementedError(
1059
- f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
1060
- )
984
+ # 3) vanilla decoder‑only architectures
985
+ elif hasattr(self, "embed_tokens"):
986
+ return self.embed_tokens
987
+ else:
988
+ base_model = getattr(self, "base_model_prefix", None)
989
+ if base_model is not None:
990
+ base_model = getattr(self, base_model, None)
991
+ if base_model is not None and base_model is not self:
992
+ return base_model.get_input_embeddings()
993
+ raise NotImplementedError(
994
+ f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
995
+ "please override in the subclass."
996
+ )
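# Hedged illustration of the fallback order above: a made-up decoder-only model that simply defines
# `self.embed_tokens` is resolved by the direct-attribute check, with no override needed.
# class TinyDecoder(PreTrainedModel):
#     def __init__(self, config):
#         super().__init__(config)
#         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
#
# TinyDecoder(config).get_input_embeddings()  # -> the nn.Embedding above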
1061
997
 
1062
998
  def set_input_embeddings(self, value: nn.Module):
1063
999
  """Fallback setter that handles **~70%** of models in the code-base.
1064
1000
 
1065
1001
  Order of attempts:
1066
- 1. `self.<_input_embed_layer>` (direct attribute)
1067
- 2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
1068
- 3. `self.model.<_input_embed_layer>` (encoder/decoder models)
1069
- 4. delegate to the *base model* if one exists
1070
- 5. otherwise raise `NotImplementedError` so subclasses still can (and
1002
+ 1. `self.model.embed_tokens`
1003
+ 2. `self.embed_tokens`
1004
+ 3. delegate to the *base model* if one exists
1005
+ 4. otherwise raise `NotImplementedError` so subclasses still can (and
1071
1006
  should) override for exotic layouts.
1072
1007
  """
1073
1008
 
1009
+ # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1074
1010
  name = getattr(self, "_input_embed_layer", "embed_tokens")
1075
- # 1) Direct attribute (most NLP models)
1076
- if hasattr(self, name):
1077
- setattr(self, name, value)
1078
- # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
1079
- elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
1080
- setattr(self.embeddings, name, value)
1081
- # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1082
- elif hasattr(self, "model") and hasattr(self.model, name):
1011
+ if hasattr(self, "model") and hasattr(self.model, name):
1083
1012
  setattr(self.model, name, value)
1084
- # 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
1085
- elif hasattr(self, "base_model") and self.base_model is not self:
1086
- self.base_model.set_input_embeddings(value)
1013
+ # 2) as well as vanilla decoder‑only architectures
1014
+ elif hasattr(self, name):
1015
+ setattr(self, name, value)
1016
+ # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
1017
+ elif getattr(self, self.base_model_prefix, self) is not self:
1018
+ base_model = getattr(self, self.base_model_prefix, self)
1019
+ base_model.set_input_embeddings(value)
1087
1020
  else:
1088
1021
  raise NotImplementedError(
1089
1022
  f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
@@ -1144,7 +1077,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1144
1077
  # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
1145
1078
  _keep_in_fp32_modules_strict = None
1146
1079
 
1147
- dtype_plan: dict[str, torch.dtype] | None = None
1080
+ dtype_plan: Optional[dict[str, torch.dtype]] = None
1148
1081
 
1149
1082
  # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
1150
1083
  # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -1204,7 +1137,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1204
1137
 
1205
1138
  # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
1206
1139
  # Possible values are: text, image, video, audio and time
1207
- input_modalities: str | list[str] = "text" # most models are text
1140
+ input_modalities: Union[str, list[str]] = "text" # most models are text
1208
1141
 
1209
1142
  @property
1210
1143
  @torch._dynamo.allow_in_graph
@@ -1295,11 +1228,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1295
1228
  self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
1296
1229
  self.config._attn_implementation, is_init_check=True
1297
1230
  )
1298
- # Check the experts implementation is supported, or set it if not yet set (on the internal attr, to avoid
1299
- # setting it recursively)
1300
- self.config._experts_implementation_internal = self._check_and_adjust_experts_implementation(
1301
- self.config._experts_implementation
1302
- )
1303
1231
  if self.can_generate():
1304
1232
  self.generation_config = GenerationConfig.from_model_config(config)
1305
1233
 
@@ -1415,7 +1343,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1415
1343
  def pp_plan(self, plan: dict[str, tuple[str, str]]):
1416
1344
  self._pp_plan = plan
1417
1345
 
1418
- def dequantize(self, dtype=None):
1346
+ def dequantize(self):
1419
1347
  """
1420
1348
  Potentially dequantize the model in case it has been quantized by a quantization method that support
1421
1349
  dequantization.
@@ -1425,7 +1353,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1425
1353
  if hf_quantizer is None:
1426
1354
  raise ValueError("You need to first quantize your model in order to dequantize it")
1427
1355
 
1428
- return hf_quantizer.dequantize(self, dtype=dtype)
1356
+ return hf_quantizer.dequantize(self)
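# Sketch (checkpoint and quantization config are assumptions): only quantizers that implement
# dequantization support this call.
# model = AutoModelForCausalLM.from_pretrained("org/quantized-checkpoint", quantization_config=...)
# model = model.dequantize()  # back to a plain floating-point model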
1429
1357
 
1430
1358
  def _backward_compatibility_gradient_checkpointing(self):
1431
1359
  if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
@@ -1433,7 +1361,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1433
1361
  # Remove the attribute now that is has been consumed, so it's no saved in the config.
1434
1362
  delattr(self.config, "gradient_checkpointing")
1435
1363
 
1436
- def add_model_tags(self, tags: list[str] | str) -> None:
1364
+ def add_model_tags(self, tags: Union[list[str], str]) -> None:
1437
1365
  r"""
1438
1366
  Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
1439
1367
  not overwrite existing tags in the model.
@@ -1466,6 +1394,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1466
1394
  self.model_tags.append(tag)
1467
1395
 
1468
1396
  @classmethod
1397
+ @restore_default_dtype
1469
1398
  def _from_config(cls, config, **kwargs):
1470
1399
  """
1471
1400
  All context managers that the model should be initialized under go here.
@@ -1474,6 +1403,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1474
1403
  dtype (`torch.dtype`, *optional*):
1475
1404
  Override the default `dtype` and load the model under this dtype.
1476
1405
  """
1406
+ # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
1407
+ # a warning is raised that dtype should be fp16. Since we never pass dtype from within
1408
+ # modeling code, we can try to infer it here same way as done in `from_pretrained`
1477
1409
  # For BC on the old `torch_dtype`
1478
1410
  dtype = kwargs.pop("dtype", config.dtype)
1479
1411
  if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1483,32 +1415,61 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1483
1415
  if isinstance(dtype, str):
1484
1416
  dtype = getattr(torch, dtype)
1485
1417
 
1418
+ # override default dtype if needed
1419
+ dtype_orig = None
1420
+ if dtype is not None:
1421
+ dtype_orig = cls._set_default_dtype(dtype)
1422
+
1486
1423
  # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
1487
1424
  if "attn_implementation" in kwargs:
1488
1425
  config._attn_implementation = kwargs.pop("attn_implementation")
1489
1426
 
1490
- # If passing `experts_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
1491
- if "experts_implementation" in kwargs:
1492
- config._experts_implementation = kwargs.pop("experts_implementation")
1493
-
1494
- init_contexts = []
1495
- if dtype is not None:
1496
- init_contexts.append(local_torch_dtype(dtype, cls.__name__))
1497
-
1498
1427
  if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
1499
1428
  logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
1500
1429
  # this immediately partitions the model across all gpus, to avoid the overhead in time
1501
1430
  # and memory copying it on CPU or each GPU first
1502
1431
  import deepspeed
1503
1432
 
1504
- init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
1433
+ init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
1434
+ with ContextManagers(init_contexts):
1435
+ model = cls(config, **kwargs)
1505
1436
 
1506
- # Instantiate the model
1507
- with ContextManagers(init_contexts):
1437
+ else:
1508
1438
  model = cls(config, **kwargs)
1509
1439
 
1440
+ # restore default dtype if it was modified
1441
+ if dtype_orig is not None:
1442
+ torch.set_default_dtype(dtype_orig)
1443
+
1510
1444
  return model
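# Illustrative usage through the public entry point (model and config names are assumptions):
# config = AutoConfig.from_pretrained("org/checkpoint")
# model = AutoModelForCausalLM.from_config(config, dtype=torch.bfloat16, attn_implementation="sdpa")
# # parameters are allocated in bf16 because the default dtype is switched for the duration of __init__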
1511
1445
 
1446
+ @classmethod
1447
+ def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
1448
+ """
1449
+ Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
1450
+ under specific dtype.
1451
+
1452
+ Args:
1453
+ dtype (`torch.dtype`):
1454
+ a floating dtype to set to.
1455
+
1456
+ Returns:
1457
+ `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
1458
+ modified. If it wasn't, returns `None`.
1459
+
1460
+ Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
1461
+ `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
1462
+ """
1463
+ if not dtype.is_floating_point:
1464
+ raise ValueError(
1465
+ f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
1466
+ )
1467
+
1468
+ logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
1469
+ dtype_orig = torch.get_default_dtype()
1470
+ torch.set_default_dtype(dtype)
1471
+ return dtype_orig
1472
+
1512
1473
  @property
1513
1474
  def base_model(self) -> nn.Module:
1514
1475
  """
@@ -1585,9 +1546,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1585
1546
  return True
1586
1547
 
1587
1548
  if is_torch_xpu_available():
1588
- logger.info(
1589
- f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
1590
- )
1549
+ logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
1591
1550
  return True
1592
1551
 
1593
1552
  if importlib.util.find_spec("flash_attn") is None:
@@ -1756,22 +1715,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1756
1715
 
1757
1716
  return True
1758
1717
 
1759
- def _grouped_mm_can_dispatch(self) -> bool:
1760
- """
1761
- Check the availability of Grouped MM for a given model.
1762
- """
1763
-
1764
- if not self._can_set_experts_implementation():
1765
- raise ValueError(f"{self.__class__.__name__} does not support setting experts implementation.")
1766
-
1767
- if not is_grouped_mm_available():
1768
- raise ImportError(
1769
- "PyTorch Grouped MM requirements in Transformers are not met. Please install torch>=2.9.0."
1770
- )
1771
-
1772
- # If no error raised by this point, we can return `True`
1773
- return True
1774
-
1775
1718
  def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
1776
1719
  """
1777
1720
  Check the availability of Flex Attention for a given model.
@@ -1800,7 +1743,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1800
1743
  return True
1801
1744
 
1802
1745
  def _check_and_adjust_attn_implementation(
1803
- self, attn_implementation: str | None, is_init_check: bool = False
1746
+ self, attn_implementation: Optional[str], is_init_check: bool = False
1804
1747
  ) -> str:
1805
1748
  """
1806
1749
  Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
@@ -1821,12 +1764,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1821
1764
  """
1822
1765
  applicable_attn_implementation = attn_implementation
1823
1766
 
1824
- is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
1825
-
1826
1767
  # If FA not installed, do not fail but use kernels instead
1827
1768
  requested_original_flash_attn = attn_implementation is not None and (
1828
- attn_implementation.removeprefix("paged|") == "flash_attention_2"
1829
- or attn_implementation.removeprefix("paged|") == "flash_attention_3"
1769
+ attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"
1830
1770
  )
1831
1771
  if (
1832
1772
  requested_original_flash_attn
@@ -1835,23 +1775,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1835
1775
  and is_kernels_available()
1836
1776
  and not is_torch_npu_available()
1837
1777
  ):
1838
- applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]
1839
-
1840
- if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
1841
- # On XPU, kernels library is the native implementation
1842
- # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
1843
- requested_original_flash_attn = False
1844
-
1845
- if is_paged:
1846
- applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
1778
+ if attn_implementation.endswith("2"):
1779
+ applicable_attn_implementation = "kernels-community/flash-attn2"
1780
+ if is_torch_xpu_available():
1781
+ # On XPU, kernels library is the native implementation
1782
+ # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
1783
+ requested_original_flash_attn = False
1784
+ else:
1785
+ applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
1847
1786
 
1848
1787
  if is_kernel(applicable_attn_implementation):
1849
1788
  try:
1850
1789
  # preload flash attention here to allow compile with fullgraph
1851
- if is_paged:
1852
- lazy_import_paged_flash_attention(applicable_attn_implementation)
1853
- else:
1854
- lazy_import_flash_attention(applicable_attn_implementation)
1790
+ lazy_import_flash_attention(applicable_attn_implementation)
1855
1791
 
1856
1792
  # log that we used kernel fallback if successful
1857
1793
  if requested_original_flash_attn:
@@ -1875,25 +1811,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1875
1811
  )
1876
1812
 
1877
1813
  # preload flash attention here to allow compile with fullgraph
1878
- if is_flash_attention_requested(requested_attention_implementation=applicable_attn_implementation):
1814
+ if "flash" in applicable_attn_implementation:
1879
1815
  lazy_import_flash_attention(applicable_attn_implementation)
1880
1816
 
1881
1817
  return applicable_attn_implementation
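# Sketch of the kernel fallback described above (checkpoint name is an assumption): requesting flash
# attention without the `flash_attn` package installed resolves to a hub kernel when `kernels` is available.
# model = AutoModelForCausalLM.from_pretrained("org/checkpoint", attn_implementation="flash_attention_2")
# # -> mapped to the "kernels-community/flash-attn2" kernel if flash-attn itself is not importable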
1882
1818
 
1883
- def _check_and_adjust_experts_implementation(self, experts_implementation: str | None) -> str:
1884
- """
1885
- Check that the `experts_implementation` exists and is supported by the models.
1886
-
1887
- Args:
1888
- experts_implementation (`str` or `None`):
1889
- The experts implementation to check for existence/validity.
1890
- Returns:
1891
- `str`: The final experts implementation to use.
1892
- """
1893
- applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
1894
- return applicable_experts_implementation
1895
-
1896
- def get_correct_attn_implementation(self, requested_attention: str | None, is_init_check: bool = False) -> str:
1819
+ def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
1897
1820
  applicable_attention = "sdpa" if requested_attention is None else requested_attention
1898
1821
  if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
1899
1822
  message = (
@@ -1927,33 +1850,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1927
1850
 
1928
1851
  return applicable_attention
1929
1852
 
1930
- def get_correct_experts_implementation(self, requested_experts: str | None) -> str:
1931
- applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
1932
- if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
1933
- message = (
1934
- f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are '
1935
- '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"` and `"experts_implementation=batched_mm"`.'
1936
- )
1937
- raise ValueError(message)
1938
-
1939
- # Perform relevant checks
1940
- if applicable_experts == "grouped_mm":
1941
- try:
1942
- self._grouped_mm_can_dispatch()
1943
- except (ValueError, ImportError) as e:
1944
- if requested_experts == "grouped_mm":
1945
- raise e
1946
- applicable_experts = "eager"
1947
-
1948
- return applicable_experts
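# Sketch of the dispatch above (only meaningful for MoE models; the checkpoint name is an assumption):
# leaving `experts_implementation` unset resolves to "grouped_mm" with a silent fallback to "eager",
# while requesting "grouped_mm" explicitly raises if grouped MM is unavailable.
# model = AutoModelForCausalLM.from_pretrained("org/moe-checkpoint", experts_implementation="eager")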
1949
-
1950
1853
  @classmethod
1951
1854
  def _can_set_attn_implementation(cls) -> bool:
1952
1855
  """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
1953
1856
  opening the file, but avoids maintaining yet another property flag.
1954
1857
  """
1955
1858
  class_file = sys.modules[cls.__module__].__file__
1956
- with open(class_file, "r", encoding="utf-8") as f:
1859
+ with open(class_file, "r") as f:
1957
1860
  code = f.read()
1958
1861
  # heuristic -> if we find those patterns, the model uses the correct interface
1959
1862
  if re.search(r"class \w+Attention\(nn.Module\)", code):
@@ -1965,18 +1868,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1965
1868
  # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
1966
1869
  return True
1967
1870
 
1968
- @classmethod
1969
- def _can_set_experts_implementation(cls) -> bool:
1970
- """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
1971
- opening the file, but avoids maintaining yet another property flag.
1972
- """
1973
- class_file = sys.modules[cls.__module__].__file__
1974
- with open(class_file, "r", encoding="utf-8") as f:
1975
- code = f.read()
1976
- # heuristic -> if we the use_experts_implementation decorator is used, then we can set it
1977
- return "@use_experts_implementation" in code
1978
-
1979
- def set_attn_implementation(self, attn_implementation: str | dict):
1871
+ def set_attn_implementation(self, attn_implementation: Union[str, dict]):
1980
1872
  """
1981
1873
  Set the requested `attn_implementation` for this model.
1982
1874
 
@@ -2075,50 +1967,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2075
1967
  if hasattr(subconfig, "_attn_was_changed"):
2076
1968
  del subconfig._attn_was_changed
2077
1969
 
2078
- def set_experts_implementation(self, experts_implementation: str | dict):
2079
- """
2080
- Set the requested `experts_implementation` for this model.
2081
-
2082
- Args:
2083
- experts_implementation (`str` or `dict`):
2084
- The experts implementation to set for this model. It can be either a `str`, in which case it will be
2085
- dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
2086
- submodel will dispatch the corresponding value.
2087
- """
2088
- requested_implementation = (
2089
- experts_implementation
2090
- if not isinstance(experts_implementation, dict)
2091
- else experts_implementation.get("", self.config._experts_implementation)
2092
- )
2093
-
2094
- if requested_implementation != self.config._experts_implementation:
2095
- requested_implementation = self._check_and_adjust_experts_implementation(requested_implementation)
2096
- # Apply the change (on the internal attr, to avoid setting it recursively)
2097
- self.config._experts_implementation_internal = requested_implementation
2098
-
2099
- # Apply it to all submodels as well
2100
- for submodule in self.modules():
2101
- # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
2102
- # e.g. ForCausalLM has a Model inside, but no need to check it again)
2103
- if (
2104
- submodule is not self
2105
- and isinstance(submodule, PreTrainedModel)
2106
- and submodule.config.__class__ != self.config.__class__
2107
- ):
2108
- # Set the experts on the submodule
2109
- sub_implementation = requested_implementation
2110
- if isinstance(experts_implementation, dict):
2111
- for subconfig_key in self.config.sub_configs:
2112
- # We need to check for exact object match here, with `is`
2113
- if getattr(self.config, subconfig_key) is submodule.config:
2114
- sub_implementation = experts_implementation.get(
2115
- subconfig_key, submodule.config._experts_implementation
2116
- )
2117
- break
2118
- # Check that the submodule can use it correctly; raise an error if the requested experts implementation can't be set for the submodule
2119
- sub_implementation = submodule.get_correct_experts_implementation(sub_implementation)
2120
- submodule.config._experts_implementation_internal = sub_implementation
2121
-
2122
1970
  def enable_input_require_grads(self):
2123
1971
  """
2124
1972
  Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
@@ -2130,18 +1978,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2130
1978
 
2131
1979
  hooks = []
2132
1980
  seen_modules = set()
2133
- found_embeddings = False
2134
1981
 
2135
1982
  for module in self.modules():
2136
1983
  if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
2137
1984
  continue
2138
1985
 
2139
- try:
2140
- input_embeddings = module.get_input_embeddings()
2141
- except NotImplementedError:
2142
- continue
1986
+ input_embeddings = module.get_input_embeddings()
2143
1987
 
2144
- if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
1988
+ if input_embeddings is None:
2145
1989
  continue
2146
1990
 
2147
1991
  embedding_id = id(input_embeddings)
@@ -2150,18 +1994,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2150
1994
 
2151
1995
  seen_modules.add(embedding_id)
2152
1996
  hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
2153
- found_embeddings = True
2154
1997
 
2155
1998
  self._require_grads_hooks = hooks
2156
1999
  if hooks:
2157
2000
  # for BC
2158
2001
  self._require_grads_hook = hooks[0]
2159
- if not found_embeddings:
2160
- logger.warning_once(
2161
- f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
2162
- "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
2163
- "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
2164
- )
2165
2002
 
2166
2003
  def disable_input_require_grads(self):
2167
2004
  """
@@ -2178,7 +2015,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2178
2015
  if hasattr(self, "_require_grads_hook"):
2179
2016
  del self._require_grads_hook
2180
2017
 
2181
- def get_encoder(self, modality: str | None = None):
2018
+ def get_encoder(self, modality: Optional[str] = None):
2182
2019
  """
2183
2020
  Best-effort lookup of the *encoder* module. If provided with `modality` argument,
2184
2021
  it looks for a modality-specific encoder in multimodal models (e.g. "image_encoder")
@@ -2210,7 +2047,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2210
2047
  # If this is a base transformer model (no encoder/model attributes), return self
2211
2048
  return self
2212
2049
 
2213
- def set_encoder(self, encoder, modality: str | None = None):
2050
+ def set_encoder(self, encoder, modality: Optional[str] = None):
2214
2051
  """
2215
2052
  Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
2216
2053
  """
@@ -2267,6 +2104,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2267
2104
  possible_module_names = ["language_model", "text_model", "decoder"]
2268
2105
  for name in possible_module_names:
2269
2106
  if hasattr(self, name):
2107
+ print(name)
2270
2108
  setattr(self, name, decoder)
2271
2109
  return
2272
2110
 
@@ -2296,13 +2134,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2296
2134
  if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
2297
2135
  if getattr(module, "weight", None) is not None:
2298
2136
  init.normal_(module.weight, mean=0.0, std=std)
2299
- if module.bias is not None:
2137
+ if getattr(module, "bias", None) is not None:
2300
2138
  init.zeros_(module.bias)
2301
2139
  elif isinstance(module, nn.Embedding):
2302
- init.normal_(module.weight, mean=0.0, std=std)
2303
- # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
2304
- if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
2305
- init.zeros_(module.weight[module.padding_idx])
2140
+ if getattr(module, "weight", None) is not None:
2141
+ init.normal_(module.weight, mean=0.0, std=std)
2142
+ # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
2143
+ if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
2144
+ init.zeros_(module.weight[module.padding_idx])
2306
2145
  elif isinstance(module, nn.MultiheadAttention):
2307
2146
  # This uses torch's original init
2308
2147
  module._reset_parameters()
@@ -2314,25 +2153,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2314
2153
  or "RMSNorm" in module.__class__.__name__
2315
2154
  ):
2316
2155
  # Norms can exist without weights (in which case they are None from torch primitives)
2317
- if getattr(module, "weight", None) is not None:
2156
+ if hasattr(module, "weight") and module.weight is not None:
2318
2157
  init.ones_(module.weight)
2319
- if getattr(module, "bias", None) is not None:
2158
+ if hasattr(module, "bias") and module.bias is not None:
2320
2159
  init.zeros_(module.bias)
2321
- # And the potential buffers for the BatchNorms
2322
- if getattr(module, "running_mean", None) is not None:
2323
- init.zeros_(module.running_mean)
2324
- init.ones_(module.running_var)
2325
- init.zeros_(module.num_batches_tracked)
2326
- # This matches all the usual RotaryEmbeddings modules
2327
- elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
2328
- rope_fn = (
2329
- ROPE_INIT_FUNCTIONS[module.rope_type]
2330
- if module.rope_type != "default"
2331
- else module.compute_default_rope_parameters
2332
- )
2333
- buffer_value, _ = rope_fn(module.config)
2334
- init.copy_(module.inv_freq, buffer_value)
2335
- init.copy_(module.original_inv_freq, buffer_value)
2336
2160
 
2337
2161
  def _initialize_weights(self, module):
2338
2162
  """
@@ -2437,10 +2261,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2437
2261
 
2438
2262
  tied_mapping = self._tied_weights_keys
2439
2263
  # If the config does not specify any tying, return empty dict
2440
- # NOTE: not all modules have `tie_word_embeddings` attr, for example vision-only
2441
- # modules do not have any word embeddings!
2442
- tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
2443
- if not tie_word_embeddings:
2264
+ if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
2444
2265
  return {}
2445
2266
  # If None, return empty dict
2446
2267
  elif tied_mapping is None:
@@ -2486,7 +2307,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return expanded_tied_weights
 
-    def tie_weights(self, missing_keys: set[str] | None = None, recompute_mapping: bool = True):
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
         """
         Tie the model weights. If `recompute_mapping=False` (default when called internally), it will rely on the
         `model.all_tied_weights_keys` attribute, containing the `{target: source}` mapping for the tied params.
@@ -2506,26 +2327,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2506
2327
 
2507
2328
  tied_keys = list(tied_keys.items())
2508
2329
  for i, (target_param_name, source_param_name) in enumerate(tied_keys):
2330
+ # Usually we tie a single target to a single source, but when both are missing we may later tie
2331
+ # both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
2332
+ # a list here
2333
+ target_param_names = [target_param_name]
2334
+
2509
2335
  # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
2510
2336
  if missing_keys is not None:
2511
2337
  remove_from_missing = True
2512
2338
  source_is_there = source_param_name not in missing_keys
2513
2339
  target_is_there = target_param_name not in missing_keys
2514
2340
  # Both are already present -> it means the config is wrong and do not reflect the actual
2515
- # checkpoint -> let's raise a warning and NOT tie them
2341
+ # checkpoint -> let's raise a warning and do nothing
2516
2342
  if source_is_there and target_is_there:
2517
2343
  logger.warning(
2518
2344
  f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
2519
2345
  f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
2520
2346
  "You should update the config with `tie_word_embeddings=False` to silence this warning"
2521
2347
  )
2522
- # Remove from internal attribute to correctly reflect actual tied weights
2523
- self.all_tied_weights_keys.pop(target_param_name)
2524
2348
  # Skip to next iteration
2525
2349
  continue
2526
2350
  # We're missing the source but we have the target -> we swap them, tying the parameter that exists
2527
2351
  elif not source_is_there and target_is_there:
2528
2352
  target_param_name, source_param_name = source_param_name, target_param_name
2353
+ target_param_names = [target_param_name]
2529
2354
  # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
2530
2355
  elif not source_is_there and not target_is_there:
2531
2356
  for target_backup, source_backup in tied_keys[i + 1 :]:
@@ -2534,10 +2359,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2534
2359
  if source_backup == source_param_name:
2535
2360
  target_backup_is_there = target_backup not in missing_keys
2536
2361
  # If the target is present, we found the correct weight to tie into (we know the source is missing)
2537
- # Note here that we do not tie the missing source right now as well, as it will be done anyway when
2538
- # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
2539
2362
  if target_backup_is_there:
2540
2363
  source_param_name = target_backup
2364
+ # Append the source as well, since both are missing we'll tie both
2365
+ target_param_names.append(source_param_name)
2541
2366
  break
2542
2367
  # If we did not break from the loop, it was impossible to find a source key -> let's raise
2543
2368
  else:
@@ -2553,18 +2378,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
             # Perform the actual tying
             source_param = self.get_parameter_or_buffer(source_param_name)
-            if "." in target_param_name:
-                parent_name, name = target_param_name.rsplit(".", 1)
-                parent = self.get_submodule(parent_name)
-            else:
-                name = target_param_name
-                parent = self
-            # Tie the weights
-            setattr(parent, name, source_param)
-            self._adjust_bias(parent, source_param)
-            # Remove from missing if necessary
-            if missing_keys is not None and remove_from_missing:
-                missing_keys.discard(target_param_name)
+            for target_param_name in target_param_names:
+                if "." in target_param_name:
+                    parent_name, name = target_param_name.rsplit(".", 1)
+                    parent = self.get_submodule(parent_name)
+                else:
+                    name = target_param_name
+                    parent = self
+                # Tie the weights
+                setattr(parent, name, source_param)
+                self._adjust_bias(parent, source_param)
+                # Remove from missing if necessary
+                if missing_keys is not None and remove_from_missing:
+                    missing_keys.discard(target_param_name)
 
     def _adjust_bias(self, output_embeddings, input_embeddings):
         if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
@@ -2609,8 +2435,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
     def resize_token_embeddings(
         self,
-        new_num_tokens: int | None = None,
-        pad_to_multiple_of: int | None = None,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
         mean_resizing: bool = True,
     ) -> nn.Embedding:
         """
@@ -2690,7 +2516,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         new_num_tokens = new_embeddings.weight.shape[0]
 
         # if word embeddings are not tied, make sure that lm head is resized as well
-        if self.get_output_embeddings() is not None:
+        if (
+            self.get_output_embeddings() is not None
+            and not self.config.get_text_config(decoder=True).tie_word_embeddings
+        ):
             old_lm_head = self.get_output_embeddings()
             if isinstance(old_lm_head, torch.nn.Embedding):
                 new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
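
In user code this path is reached through the public `resize_token_embeddings` helper, typically after adding tokens (checkpoint and token names illustrative):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    tokenizer.add_special_tokens({"additional_special_tokens": ["<tool>"]})
    # Grows the input embeddings (and the LM head when it is not tied);
    # pad_to_multiple_of rounds the new vocab size up for tensor-core friendliness.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
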
@@ -2708,8 +2537,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2708
2537
  def _get_resized_embeddings(
2709
2538
  self,
2710
2539
  old_embeddings: nn.Embedding,
2711
- new_num_tokens: int | None = None,
2712
- pad_to_multiple_of: int | None = None,
2540
+ new_num_tokens: Optional[int] = None,
2541
+ pad_to_multiple_of: Optional[int] = None,
2713
2542
  mean_resizing: bool = True,
2714
2543
  ) -> nn.Embedding:
2715
2544
  """
@@ -2866,7 +2695,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2866
2695
  def _get_resized_lm_head(
2867
2696
  self,
2868
2697
  old_lm_head: nn.Linear,
2869
- new_num_tokens: int | None = None,
2698
+ new_num_tokens: Optional[int] = None,
2870
2699
  transposed: bool = False,
2871
2700
  mean_resizing: bool = True,
2872
2701
  ) -> nn.Linear:
@@ -3063,7 +2892,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3063
2892
  f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
3064
2893
  )
3065
2894
 
3066
- def get_position_embeddings(self) -> nn.Embedding | tuple[nn.Embedding]:
2895
+ def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
3067
2896
  raise NotImplementedError(
3068
2897
  f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
3069
2898
  f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
@@ -3074,8 +2903,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
         initialization logic in `_init_weights`.
         """
-        # If we are initializing on meta device, there is no point in trying to run inits
-        if get_torch_context_manager_or_global_device() != torch.device("meta"):
+        if _init_weights:
             # Initialize weights
             self.initialize_weights()
             # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
@@ -3096,7 +2924,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
 
         if gradient_checkpointing_kwargs is None:
-            gradient_checkpointing_kwargs = {"use_reentrant": False}
+            gradient_checkpointing_kwargs = {"use_reentrant": True}
 
         gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
 
@@ -3113,10 +2941,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
             )
 
-        needs_embedding_grads = self.main_input_name == "input_ids"
-        # we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
-        enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
-        if enable_input_grads:
+        if getattr(self, "_hf_peft_config_loaded", False):
            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
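
Both branches above are driven from `gradient_checkpointing_enable`; passing the kwargs explicitly avoids relying on the version-dependent `use_reentrant` default seen in the previous hunk (checkpoint name illustrative):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    # When training adapters, also make sure gradients can reach the frozen embeddings:
    model.enable_input_require_grads()
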
@@ -3174,13 +2999,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
     def save_pretrained(
         self,
-        save_directory: str | os.PathLike,
+        save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        state_dict: dict | None = None,
+        state_dict: Optional[dict] = None,
+        save_function: Callable = torch.save,
         push_to_hub: bool = False,
-        max_shard_size: int | str = "50GB",
-        variant: str | None = None,
-        token: str | bool | None = None,
+        max_shard_size: Union[int, str] = "5GB",
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+        token: Optional[Union[str, bool]] = None,
         save_peft_format: bool = True,
         save_original_format: bool = True,
         **kwargs,
@@ -3200,13 +3027,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
                 save parts of the model or if special precautions need to be taken when recovering the state dictionary
                 of a model (like when using model parallelism).
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+                needs to replace `torch.save` by another method.
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-            max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
+            max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
                 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+                We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
+                without CPU OOM issues.
 
                 <Tip warning={true}>
 
@@ -3215,8 +3047,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
                 </Tip>
 
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             variant (`str`, *optional*):
-                If specified, weights are saved in the format model.<variant>.safetensors.
+                If specified, weights are saved in the format pytorch_model.<variant>.bin.
             token (`str` or `bool`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                 the token generated when running `hf auth login` (stored in `~/.huggingface`).
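
A typical call covering the documented arguments; the exact keyword set (`safe_serialization`, `save_function`) differs between the two versions in this diff, so treat this as a sketch against the rc0 signature (paths illustrative):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.save_pretrained(
        "./gpt2-local",
        max_shard_size="2GB",      # shard anything larger than 2GB
        safe_serialization=True,   # write .safetensors rather than pickle-based .bin
        variant="fp16",            # filenames gain a .fp16 suffix, per the docstring above
    )
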
@@ -3238,7 +3072,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3238
3072
 
3239
3073
  hf_quantizer = getattr(self, "hf_quantizer", None)
3240
3074
  quantization_serializable = (
3241
- hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
3075
+ hf_quantizer is not None
3076
+ and isinstance(hf_quantizer, HfQuantizer)
3077
+ and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
3242
3078
  )
3243
3079
 
3244
3080
  if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3247,6 +3083,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 " the logger on the traceback to understand the reason why the quantized model is not serializable."
             )
 
+        if "save_config" in kwargs:
+            warnings.warn(
+                "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
+            )
+            is_main_process = kwargs.pop("save_config")
+
         # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
         if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
             raise ImportError(
@@ -3268,7 +3110,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3268
3110
 
3269
3111
  metadata = {}
3270
3112
  if hf_quantizer is not None:
3271
- state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
3113
+ state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
3272
3114
  metadata["format"] = "pt"
3273
3115
 
3274
3116
  # Only save the model itself if we are using distributed training
@@ -3321,22 +3163,28 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3321
3163
  current_peft_config = self.peft_config[active_adapter]
3322
3164
  current_peft_config.save_pretrained(save_directory)
3323
3165
 
3324
- # Get the model state_dict
3166
+ # for offloaded modules
3167
+ module_map = {}
3168
+
3169
+ # Save the model
3325
3170
  if state_dict is None:
3326
- state_dict = model_to_save.state_dict()
3171
+ # if any model parameters are offloaded, make module map
3172
+ if (
3173
+ hasattr(self, "hf_device_map")
3174
+ and len(set(self.hf_device_map.values())) > 1
3175
+ and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
3176
+ ):
3177
+ warnings.warn(
3178
+ "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
3179
+ )
3180
+ for name, module in model_to_save.named_modules():
3181
+ if name == "":
3182
+ continue
3183
+ module_state_dict = module.state_dict()
3327
3184
 
3328
- # if any model parameters are offloaded, we need to know it for later
3329
- is_offloaded = False
3330
- if (
3331
- hasattr(self, "hf_device_map")
3332
- and len(set(self.hf_device_map.values())) > 1
3333
- and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
3334
- ):
3335
- is_offloaded = True
3336
- warnings.warn(
3337
- "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
3338
- "exceeds the `shard_size` (50GB default)"
3339
- )
3185
+ for key in module_state_dict:
3186
+ module_map[name + f".{key}"] = module
3187
+ state_dict = model_to_save.state_dict()
3340
3188
 
3341
3189
  # Translate state_dict from smp to hf if saving with smp >= 1.10
3342
3190
  if IS_SAGEMAKER_MP_POST_1_10:
@@ -3354,19 +3202,86 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3354
3202
  if self._tp_size is not None:
3355
3203
  state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
3356
3204
 
3357
- # Remove tied weights as safetensors do not handle them
3358
- state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)
3205
+ if safe_serialization:
3206
+ # TODO: fix safe_serialization for tied weights
3207
+ # Safetensors does not allow tensor aliasing.
3208
+ # We're going to remove aliases before saving
3209
+ ptrs = collections.defaultdict(list)
3210
+ for name, tensor in state_dict.items():
3211
+ if not isinstance(tensor, torch.Tensor):
3212
+ # Sometimes in the state_dict we have non-tensor objects.
3213
+ # e.g. in bitsandbytes we have some `str` objects in the state_dict
3214
+ # In the non-tensor case, fall back to the pointer of the object itself
3215
+ ptrs[id(tensor)].append(name)
3216
+
3217
+ elif tensor.device.type == "meta":
3218
+ # In offloaded cases, there may be meta tensors in the state_dict.
3219
+ # For these cases, key by the pointer of the original tensor object
3220
+ # (state_dict tensors are detached and therefore no longer shared)
3221
+ tensor = self.get_parameter(name)
3222
+ ptrs[id(tensor)].append(name)
3223
+
3224
+ else:
3225
+ ptrs[id_tensor_storage(tensor)].append(name)
3226
+
3227
+ shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
3228
+
3229
+ # Recursively descend to find tied weight keys
3230
+ _tied_weights_keys = set(_get_tied_weight_keys(self))
3231
+ error_names = []
3232
+ to_delete_names = set()
3233
+ for names in shared_ptrs.values():
3234
+ # Removing the keys which are declared as known duplicates on
3235
+ # load. This allows to make sure the name which is kept is consistent.
3236
+ if _tied_weights_keys is not None:
3237
+ found = 0
3238
+ for name in sorted(names):
3239
+ matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
3240
+ if matches_pattern and name in state_dict:
3241
+ found += 1
3242
+ if found < len(names):
3243
+ to_delete_names.add(name)
3244
+ # We are entering a place where the weights and the transformers configuration do NOT match.
3245
+ shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
3246
+ # Those are actually tensor sharing but disjoint from each other, we can safely clone them
3247
+ # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
3248
+ for name in disjoint_names:
3249
+ state_dict[name] = state_dict[name].clone()
3250
+
3251
+ # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
3252
+ # If the link between tensors was done at runtime then `from_pretrained` will not get
3253
+ # the key back leading to random tensor. A proper warning will be shown
3254
+ # during reload (if applicable), but since the file is not necessarily compatible with
3255
+ # the config, better show a proper warning.
3256
+ shared_names, identical_names = _find_identical(shared_names, state_dict)
3257
+ # delete tensors that have identical storage
3258
+ for inames in identical_names:
3259
+ known = inames.intersection(to_delete_names)
3260
+ for name in known:
3261
+ del state_dict[name]
3262
+ unknown = inames.difference(to_delete_names)
3263
+ if len(unknown) > 1:
3264
+ error_names.append(unknown)
3265
+
3266
+ if shared_names:
3267
+ error_names.extend(shared_names)
3268
+
3269
+ if len(error_names) > 0:
3270
+ raise RuntimeError(
3271
+ f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
3272
+ "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
3273
+ )
3359
3274
 
3360
3275
  # Revert all renaming and/or weight operations
3361
- if save_original_format and not _hf_peft_config_loaded:
3362
- state_dict = revert_weight_conversion(model_to_save, state_dict)
3276
+ if save_original_format:
3277
+ state_dict = revert_weight_conversion(self, state_dict)
3363
3278
 
3364
3279
  # Shard the model if it is too big.
3365
3280
  if not _hf_peft_config_loaded:
3366
- weights_name = SAFE_WEIGHTS_NAME
3281
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
3367
3282
  weights_name = _add_variant(weights_name, variant)
3368
3283
  else:
3369
- weights_name = ADAPTER_SAFE_WEIGHTS_NAME
3284
+ weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
3370
3285
 
3371
3286
  filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
3372
3287
  state_dict_split = split_torch_state_dict_into_shards(
@@ -3399,45 +3314,57 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3399
3314
  and reg.fullmatch(filename_no_suffix) is not None
3400
3315
  ):
3401
3316
  os.remove(full_filename)
3402
-
3403
3317
  # Save the model
3404
- for shard_file, tensor_names in logging.tqdm(
3405
- state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
3406
- ):
3407
- filename = os.path.join(save_directory, shard_file)
3408
- shard_state_dict = {}
3409
- for tensor_name in tensor_names:
3410
- # Get the tensor, and remove it from state_dict to avoid keeping the ref
3411
- tensor = state_dict.pop(tensor_name)
3412
-
3413
- # In case of TP, get the full parameter back
3414
- if _is_dtensor_available and isinstance(tensor, DTensor):
3415
- tensor = tensor.full_tensor()
3318
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
3319
+ if module_map:
3320
+ filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
3321
+ for shard_file, tensors in filename_to_tensors:
3322
+ shard = {}
3323
+ for tensor in tensors:
3324
+ if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
3325
+ full_tensor = state_dict[tensor].full_tensor()
3416
3326
  # to get the correctly ordered tensor we need to repack if packed
3417
- if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
3418
- tensor = repack_weights(tensor, -1, self._tp_size, 2)
3419
-
3420
- # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
3421
- # but it would otherwise not be contained in the saved shard if we were to simply move the file
3422
- # or something
3423
- if is_offloaded and tensor.device.type == "meta":
3424
- tensor = load_offloaded_parameter(model_to_save, tensor_name)
3425
-
3426
- # only do contiguous after it's permuted correctly in case of TP
3427
- shard_state_dict[tensor_name] = tensor.contiguous()
3327
+ if _get_parameter_tp_plan(tensor, self._tp_plan) == "local_packed_rowwise":
3328
+ full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
3329
+ shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
3330
+ else:
3331
+ shard[tensor] = state_dict[tensor].contiguous()
3332
+ # delete reference, see https://github.com/huggingface/transformers/pull/34890
3333
+ del state_dict[tensor]
3334
+
3335
+ # remake shard with onloaded parameters if necessary
3336
+ if module_map:
3337
+ # init state_dict for this shard
3338
+ shard_state_dict = dict.fromkeys(shard, "")
3339
+ for module_name in shard:
3340
+ # note that get_state_dict_from_offload can update with meta tensors
3341
+ # if both a parent module and its descendant are offloaded
3342
+ tensor = shard_state_dict[module_name]
3343
+ if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
3344
+ # update state dict with onloaded parameters
3345
+ module = module_map[module_name]
3346
+ shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
3347
+
3348
+ # assign shard to be the completed state dict
3349
+ shard = shard_state_dict
3350
+ del shard_state_dict
3351
+ gc.collect()
3352
+
3353
+ if safe_serialization:
3354
+ # At some point we will need to deal better with save_function (used for TPU and other distributed
3355
+ # joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting
3356
+ # too much before scheduling the next write when its in a different file
3357
+ safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
3358
+ else:
3359
+ save_function(shard, os.path.join(save_directory, shard_file))
3428
3360
 
3429
- # TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
3430
- # so it's not possible for now....
3431
- # Write the shard to disk
3432
- safe_save_file(shard_state_dict, filename, metadata=metadata)
3433
- # Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
3434
- del shard_state_dict
3361
+ del state_dict
3435
3362
 
3436
3363
  if index is None:
3437
3364
  path_to_weights = os.path.join(save_directory, weights_name)
3438
3365
  logger.info(f"Model weights saved in {path_to_weights}")
3439
3366
  else:
3440
- save_index_file = SAFE_WEIGHTS_INDEX_NAME
3367
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
3441
3368
  save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
3442
3369
  # Save the index as well
3443
3370
  with open(save_index_file, "w", encoding="utf-8") as f:
@@ -3574,9 +3501,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 " desired `dtype` by passing the correct `dtype` argument."
             )
 
-        if getattr(self, "is_loaded_in_8bit", False) and not is_bitsandbytes_available("0.48"):
+        if getattr(self, "is_loaded_in_8bit", False):
             raise ValueError(
-                "You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()."
+                "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and casted to the correct `dtype`."
             )
         elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
             if dtype_present_in_args:
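
Loading an 8-bit bitsandbytes model therefore means picking device placement at load time rather than with `.to()`; a sketch with an illustrative checkpoint (requires the `bitsandbytes` package and a CUDA device):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",
        quantization_config=quant_config,
        device_map="auto",  # placement happens here; .to() afterwards is rejected or version-gated as shown above
    )
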
@@ -3607,38 +3535,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         return super().float(*args)
 
     @classmethod
-    def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
-        # Need to instantiate with correct dtype
-        init_contexts = [local_torch_dtype(dtype, cls.__name__)]
+    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
+            init_contexts = [no_init_weights()]
             # We cannot initialize the model on meta device with deepspeed when not quantized
             if not is_quantized and not _is_ds_init_called:
                 logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-                init_contexts.extend(
-                    [
-                        init.no_init_weights(),
-                        deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
-                        set_zero3_state(),
-                    ]
-                )
+                init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
             elif is_quantized:
-                init_contexts.extend([torch.device("meta"), set_quantized_state()])
+                init_contexts.extend([init_empty_weights(), set_quantized_state()])
         else:
-            init_contexts.append(torch.device("meta"))
+            init_contexts = [no_init_weights(), init_empty_weights()]
 
         return init_contexts
 
-    def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None):
-        """
-        Set whether or not to use the `kernels` library to kernelize some layers of the model.
-        Args:
-            use_kernels (`bool`):
-                Whether or not to use the `kernels` library to kernelize some layers of the model.
-            kernel_config (`KernelConfig`, *optional*):
-                The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
-        """
+    def set_use_kernels(self, use_kernels, kernel_config):
         if use_kernels:
             if not is_kernels_available():
                 raise ValueError(
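
The meta-device and empty-weights contexts referenced in `get_init_context` can be reproduced with plain PyTorch; a rough sketch (the Linear module is just a stand-in for a model skeleton):

    import torch
    from torch import nn

    # Parameters are allocated on the meta device: shapes and dtypes only, no storage.
    with torch.device("meta"):
        skeleton = nn.Linear(4096, 4096)
    print(skeleton.weight.device)  # meta
    # accelerate's init_empty_weights() context manager, used on the rc0 side above,
    # achieves the same effect while leaving buffers materialized by default.
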
@@ -3659,9 +3572,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3659
3572
 
3660
3573
  # This is a context manager to override the default kernel mapping
3661
3574
  # We are calling kernelize inside this context manager using the use_kernels setter
3662
- # Param inherit_mapping should be False to avoid still loading kernel from remote
3663
- inherit_mapping = not kernel_config.use_local_kernel
3664
- with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
3575
+ with use_kernel_mapping(kernel_config.kernel_mapping):
3665
3576
  self.use_kernels = True
3666
3577
  # We use the default kernel mapping in .integrations.hub_kernels
3667
3578
  else:
@@ -3670,18 +3581,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             self.use_kernels = False
 
     @classmethod
+    @restore_default_dtype
     def from_pretrained(
         cls: type[SpecificPreTrainedModelType],
-        pretrained_model_name_or_path: str | os.PathLike | None,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
         *model_args,
-        config: PreTrainedConfig | str | os.PathLike | None = None,
-        cache_dir: str | os.PathLike | None = None,
+        config: Optional[Union[PreTrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
-        token: str | bool | None = None,
+        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
-        use_safetensors: bool | None = None,
+        use_safetensors: Optional[bool] = True,
        weights_only: bool = True,
        **kwargs,
    ) -> SpecificPreTrainedModelType:
@@ -3778,18 +3690,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 "org/model@main"
                 "org/model:custom_kernel"
                 "org/model@v1.2.3:custom_kernel"
-            experts_implementation (`str`, *optional*):
-                The experts implementation to use in the model (if relevant). Can be any of:
-
-                - `"eager"` (sequential implementation of the experts matrix multiplications).
-                - `"batched_mm"` (using [`torch.bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html)).
-                - `"grouped_mm"` (using [`torch._grouped_mm`](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.grouped_mm.html)).
-
-                By default, if available, `grouped_mm` will be used for torch>=2.9.0. The default is otherwise the sequential `"eager"` implementation.
 
             > Parameters for big model inference
 
-            dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
+            dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
                 are:
 
@@ -3931,8 +3835,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # For BC on torch_dtype argument
         if torch_dtype is not None:
             dtype = dtype if dtype is not None else torch_dtype
-        if dtype is None:
-            dtype = "auto"
 
         if is_offline_mode() and not local_files_only:
             local_files_only = True
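
Whether an unspecified `dtype` resolves to `"auto"` (the checkpoint's own dtype) or to `float32` is exactly what the removed default controls; passing it explicitly is unambiguous on either side of the diff (checkpoint name illustrative):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2", dtype=torch.bfloat16)
    # Older call sites that still pass torch_dtype=... keep working thanks to the
    # backward-compatibility shim in the hunk above.
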
@@ -4009,11 +3911,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4009
3911
  if "attn_implementation" in kwargs:
4010
3912
  config._attn_implementation = kwargs.pop("attn_implementation")
4011
3913
 
4012
- if "experts_implementation" in kwargs:
4013
- config._experts_implementation = kwargs.pop("experts_implementation")
4014
-
4015
- hf_quantizer, config, device_map = get_hf_quantizer(
4016
- config, quantization_config, device_map, weights_only, user_agent
3914
+ hf_quantizer, config, dtype, device_map = get_hf_quantizer(
3915
+ config, quantization_config, dtype, device_map, weights_only, user_agent
4017
3916
  )
4018
3917
 
4019
3918
  if gguf_file:
@@ -4060,29 +3959,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4060
3959
  ]
4061
3960
 
4062
3961
  # Find the correct dtype based on current state
4063
- config, dtype = _get_dtype(
4064
- dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
3962
+ config, dtype, dtype_orig = _get_dtype(
3963
+ cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
4065
3964
  )
4066
3965
 
4067
3966
  config.name_or_path = pretrained_model_name_or_path
4068
- model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
3967
+ model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
4069
3968
  config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
4070
3969
  with ContextManagers(model_init_context):
4071
3970
  # Let's make sure we don't run the init function of buffer modules
4072
3971
  model = cls(config, *model_args, **model_kwargs)
4073
3972
 
4074
- if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
4075
- hf_quantizer.preprocess_model(
4076
- model=model,
4077
- dtype=dtype,
4078
- device_map=device_map,
4079
- checkpoint_files=checkpoint_files,
4080
- use_kernels=use_kernels,
4081
- )
4082
-
4083
3973
  # Obtain the weight conversion mapping for this model if any are registered
4084
3974
  weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
4085
3975
 
3976
+ # make sure we use the model's config since the __init__ call might have copied it
3977
+ config = model.config
3978
+
3979
+ if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
3980
+ hf_quantizer.preprocess_model(
3981
+ model=model,
3982
+ device_map=device_map,
3983
+ keep_in_fp32_modules=model._keep_in_fp32_modules, # TODO prob no longer needed?
3984
+ config=config,
3985
+ checkpoint_files=checkpoint_files,
3986
+ use_kernels=use_kernels,
3987
+ )
3988
+
4086
3989
  if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
4087
3990
  model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
4088
3991
 
@@ -4090,30 +3993,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4090
3993
  if device_map is not None:
4091
3994
  device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
4092
3995
 
3996
+ # restore default dtype
3997
+ if dtype_orig is not None:
3998
+ torch.set_default_dtype(dtype_orig)
3999
+
4093
4000
  # Finalize model weight initialization
4094
- load_config = LoadStateDictConfig(
4095
- pretrained_model_name_or_path=pretrained_model_name_or_path,
4001
+ model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
4002
+ model,
4003
+ state_dict,
4004
+ checkpoint_files,
4005
+ pretrained_model_name_or_path,
4096
4006
  ignore_mismatched_sizes=ignore_mismatched_sizes,
4097
4007
  sharded_metadata=sharded_metadata,
4098
4008
  device_map=device_map,
4099
4009
  disk_offload_folder=offload_folder,
4100
- offload_buffers=offload_buffers,
4101
4010
  dtype=dtype,
4102
4011
  hf_quantizer=hf_quantizer,
4103
4012
  device_mesh=device_mesh,
4104
4013
  weights_only=weights_only,
4105
4014
  weight_mapping=weight_conversions,
4106
- use_safetensors=use_safetensors,
4107
- download_kwargs=download_kwargs,
4108
4015
  )
4109
- load_info = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
4110
- load_info = cls._finalize_load_state_dict(model, load_config, load_info)
4111
- model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
4016
+
4017
+ model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
4112
4018
  model.set_use_kernels(use_kernels, kernel_config)
4113
4019
 
4114
4020
  # If it is a model with generation capabilities, attempt to load generation files (generation config,
4115
4021
  # custom generate function)
4116
- if model.can_generate() and hasattr(model, "adjust_generation_fn") and not gguf_file:
4022
+ if model.can_generate() and hasattr(model, "adjust_generation_fn"):
4117
4023
  model.adjust_generation_fn(
4118
4024
  generation_config,
4119
4025
  from_auto_class,
@@ -4124,34 +4030,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4124
4030
  **kwargs,
4125
4031
  )
4126
4032
 
4127
- # If the device_map has more than 1 device: dispatch model with hooks on all devices
4128
- if device_map is not None and len(set(device_map.values())) > 1:
4129
- accelerate_dispatch(
4130
- model, hf_quantizer, device_map, offload_folder, load_info.disk_offload_index, offload_buffers
4131
- )
4033
+ # for device_map="auto" : dispatch model with hooks on all devices if necessary
4034
+ if device_map is not None and device_mesh is None:
4035
+ accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
4132
4036
 
4133
4037
  if hf_quantizer is not None:
4134
4038
  model.hf_quantizer = hf_quantizer
4135
- hf_quantizer.postprocess_model(
4136
- model
4137
- ) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
4039
+ hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
4138
4040
 
4139
4041
  if _adapter_model_path is not None:
4140
- if token is not None:
4141
- adapter_kwargs["token"] = token
4142
- load_info = model.load_adapter(
4042
+ adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters
4043
+ model.load_adapter(
4143
4044
  _adapter_model_path,
4144
4045
  adapter_name=adapter_name,
4145
- load_config=load_config,
4046
+ token=token,
4146
4047
  adapter_kwargs=adapter_kwargs,
4147
4048
  )
4148
4049
 
4149
4050
  if output_loading_info:
4150
4051
  loading_info = {
4151
- "missing_keys": load_info.missing_keys,
4152
- "unexpected_keys": load_info.unexpected_keys,
4153
- "mismatched_keys": load_info.mismatched_keys,
4154
- "error_msgs": load_info.error_msgs,
4052
+ "missing_keys": missing_keys,
4053
+ "unexpected_keys": unexpected_keys,
4054
+ "mismatched_keys": mismatched_keys,
4055
+ "error_msgs": error_msgs,
4155
4056
  }
4156
4057
  return model, loading_info
4157
4058
  return model
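
The `loading_info` dictionary built above is what callers get back when opting in; a typical inspection of freshly initialized head weights (checkpoint illustrative):

    from transformers import AutoModelForSequenceClassification

    model, loading_info = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=3, output_loading_info=True
    )
    print(loading_info["missing_keys"])     # e.g. the untrained classification head
    print(loading_info["unexpected_keys"])  # checkpoint weights the model did not use
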
@@ -4160,65 +4061,74 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4160
4061
  def _load_pretrained_model(
4161
4062
  cls,
4162
4063
  model: "PreTrainedModel",
4163
- state_dict: dict | None,
4164
- checkpoint_files: list[str] | None,
4165
- load_config: LoadStateDictConfig,
4166
- ) -> LoadStateDictInfo:
4167
- is_quantized = load_config.is_quantized
4168
- is_hqq_or_quark = is_quantized and load_config.hf_quantizer.quantization_config.quant_method in {
4064
+ state_dict: Optional[dict],
4065
+ checkpoint_files: Optional[list[str]],
4066
+ pretrained_model_name_or_path: Optional[str],
4067
+ ignore_mismatched_sizes: bool = False,
4068
+ sharded_metadata: Optional[dict] = None,
4069
+ device_map: Optional[dict] = None,
4070
+ disk_offload_folder: Optional[str] = None,
4071
+ dtype: Optional[torch.dtype] = None,
4072
+ hf_quantizer: Optional[HfQuantizer] = None,
4073
+ device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
4074
+ weights_only: bool = True,
4075
+ weight_mapping: Optional[Sequence[WeightConverter | WeightRenaming]] = None,
4076
+ ):
4077
+ is_quantized = hf_quantizer is not None
4078
+ is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
4169
4079
  QuantizationMethod.HQQ,
4170
4080
  QuantizationMethod.QUARK,
4171
4081
  }
4172
4082
 
4173
4083
  # Model's definition arriving here is final (TP hooks added, quantized layers replaces)
4174
4084
  expected_keys = list(model.state_dict().keys())
4175
-
4176
4085
  if logger.level >= logging.WARNING:
4177
4086
  verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
4178
4087
 
4179
4088
  # This offload index if for params explicitly on the "disk" in the device_map
4180
4089
  disk_offload_index = None
4181
4090
  # Prepare parameters offloading if needed
4182
- if load_config.device_map is not None and "disk" in load_config.device_map.values():
4091
+ if device_map is not None and "disk" in device_map.values():
4183
4092
  disk_offload_index = accelerate_disk_offload(
4184
- model,
4185
- load_config.disk_offload_folder,
4093
+ disk_offload_folder,
4186
4094
  checkpoint_files,
4187
- load_config.device_map,
4188
- load_config.sharded_metadata,
4189
- load_config.dtype,
4190
- load_config.weight_mapping,
4095
+ device_map,
4096
+ expected_keys,
4097
+ sharded_metadata,
4098
+ dtype,
4099
+ weight_mapping,
4191
4100
  )
4192
4101
 
4193
4102
  # Warmup cuda to load the weights much faster on devices
4194
- if load_config.device_map is not None and not is_hqq_or_quark:
4195
- expanded_device_map = expand_device_map(load_config.device_map, expected_keys)
4196
- caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
4103
+ if device_map is not None and not is_hqq_or_quark:
4104
+ expanded_device_map = expand_device_map(device_map, expected_keys)
4105
+ caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
4197
4106
 
4107
+ tp_plan = getattr(model, "_tp_plan", None)
4198
4108
  error_msgs = []
4199
4109
 
4200
4110
  if is_deepspeed_zero3_enabled() and not is_quantized:
4201
4111
  if state_dict is None:
4202
4112
  merged_state_dict = {}
4203
4113
  for ckpt_file in checkpoint_files:
4204
- merged_state_dict.update(
4205
- load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
4206
- )
4114
+ merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
4207
4115
  state_dict = merged_state_dict
4208
- error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
4116
+ error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
4209
4117
  # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
4210
- unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set()
4118
+ missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
4211
4119
  else:
4212
4120
  all_pointer = set()
4213
- if state_dict is not None:
4214
- merged_state_dict = state_dict
4215
- elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
4121
+ # Checkpoints are safetensors
4122
+ if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
4216
4123
  merged_state_dict = {}
4217
4124
  for file in checkpoint_files:
4218
4125
  file_pointer = safe_open(file, framework="pt", device="cpu")
4219
4126
  all_pointer.add(file_pointer)
4220
4127
  for k in file_pointer.keys():
4221
4128
  merged_state_dict[k] = file_pointer.get_slice(k) # don't materialize yet
4129
+ # User passed an explicit state_dict
4130
+ elif state_dict is not None:
4131
+ merged_state_dict = state_dict
4222
4132
  # Checkpoints are .bin
4223
4133
  elif checkpoint_files is not None:
4224
4134
  merged_state_dict = {}
@@ -4227,14 +4137,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4227
4137
  else:
4228
4138
  raise ValueError("Neither a state dict nor checkpoint files were found.")
4229
4139
 
4230
- missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
4140
+ missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
4231
4141
  convert_and_load_state_dict_in_model(
4232
- model=model,
4233
- state_dict=merged_state_dict,
4234
- load_config=load_config,
4235
- tp_plan=model._tp_plan,
4236
- dtype_plan=model.dtype_plan,
4237
- disk_offload_index=disk_offload_index,
4142
+ model,
4143
+ merged_state_dict,
4144
+ weight_mapping,
4145
+ tp_plan,
4146
+ hf_quantizer,
4147
+ dtype,
4148
+ device_map,
4149
+ model.dtype_plan,
4150
+ device_mesh,
4151
+ disk_offload_index,
4152
+ disk_offload_folder,
4238
4153
  )
4239
4154
  )
4240
4155
 
@@ -4242,58 +4157,65 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4242
4157
  for k in all_pointer:
4243
4158
  k.__exit__(None, None, None)
4244
4159
 
4245
- return LoadStateDictInfo(
4246
- missing_keys=missing_keys,
4247
- unexpected_keys=unexpected_keys,
4248
- mismatched_keys=mismatched_keys,
4249
- disk_offload_index=disk_offload_index,
4250
- error_msgs=error_msgs,
4251
- conversion_errors=conversion_errors,
4252
- )
4253
-
4254
- @staticmethod
4255
- def _finalize_load_state_dict(
4256
- model,
4257
- load_config: LoadStateDictConfig,
4258
- load_info: LoadStateDictInfo,
4259
- ) -> LoadStateDictInfo:
4260
- # TODO @ArthurZucker this will be in a separate function to allows people not to run this
4261
- # for more granularity
4262
-
4263
4160
  # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
4264
4161
  model.mark_tied_weights_as_initialized()
4265
4162
 
4266
- # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
4267
- # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
4268
- missing_and_mismatched = load_info.missing_keys | {k[0] for k in load_info.mismatched_keys}
4269
- model._move_missing_keys_from_meta_to_device(
4270
- missing_and_mismatched, load_config.device_map, load_config.device_mesh, load_config.hf_quantizer
4271
- )
4163
+ # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
4164
+ # loading the weights as they are not in the loaded state dict)
4165
+ miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
4166
+ model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer)
4272
4167
 
4273
- # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
4274
- model._initialize_missing_keys(load_config.is_quantized)
4168
+ # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialzed` flag)
4169
+ model._initialize_missing_keys(is_quantized)
4275
4170
 
4276
4171
  # Tie the weights
4277
- model.tie_weights(missing_keys=load_info.missing_keys, recompute_mapping=False)
4172
+ model.tie_weights(missing_keys=missing_keys, recompute_mapping=False)
4278
4173
 
4279
4174
  # Adjust missing and unexpected keys
4280
- missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(
4281
- load_info.missing_keys, load_info.unexpected_keys
4282
- )
4175
+ missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)
4176
+
4177
+ # Post-processing for tensor parallelism
4178
+ if device_mesh is not None:
4179
+ # When using TP, the device map is a single device for all parameters
4180
+ tp_device = list(device_map.values())[0]
4181
+ # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
4182
+ # not part of the state_dict (persistent=False)
4183
+ for buffer in model.buffers(): # TODO to avaoid this buffer could be added to the ckpt
4184
+ if buffer.device != tp_device:
4185
+ buffer.data = buffer.to(tp_device)
4186
+
4187
+ # In this case, the top-most task module weights were not moved to device and parallelized as they
4188
+ # were not part of the loaded weights: do it now
4189
+ if missing_keys:
4190
+ state_dict = model.state_dict()
4191
+ for name in missing_keys:
4192
+ param = state_dict[name]
4193
+ # Shard the param
4194
+ shard_and_distribute_module(
4195
+ model,
4196
+ param.to(tp_device),
4197
+ param,
4198
+ name,
4199
+ None,
4200
+ False,
4201
+ device_mesh.get_local_rank(),
4202
+ device_mesh,
4203
+ )
4283
4204
 
4284
4205
  log_state_dict_report(
4285
4206
  model=model,
4286
- load_config=load_config,
4207
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
4287
4208
  logger=logger,
4288
- error_msgs=load_info.error_msgs,
4209
+ error_msgs=error_msgs,
4289
4210
  unexpected_keys=unexpected_keys,
4290
4211
  missing_keys=missing_keys,
4291
- mismatched_keys=load_info.mismatched_keys,
4292
- mismatched_shapes=load_info.mismatched_keys,
4293
- conversion_errors=load_info.conversion_errors,
4212
+ mismatched_keys=mismatched_keys,
4213
+ mismatched_shapes=mismatched_keys,
4214
+ misc=misc,
4215
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
4294
4216
  )
4295
4217
 
4296
- return replace(load_info, missing_keys=missing_keys, unexpected_keys=unexpected_keys)
4218
+ return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
4297
4219
 
4298
4220
  def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
4299
4221
  module_keys = {".".join(key.split(".")[:-1]) for key in names}
@@ -4362,17 +4284,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
         # attention_mask or not. In this case, we should still show a warning because this is a rare case.
-        # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
-        sep_token_id = getattr(self.config, "sep_token_id", None)
         if (
             (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
             or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
-            or (sep_token_id is not None and sep_token_id == self.config.pad_token_id)
+            or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
         ):
             warn_string += (
                 f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                 f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
-                f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
+                f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
             )
 
         logger.warning_once(warn_string)
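
The warning fires only when padding is ambiguous; forwarding the tokenizer's attention mask removes the ambiguity (model and tokenizer names illustrative):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # the common setup that triggers this check
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    batch = tokenizer(["short", "a longer example"], padding=True, return_tensors="pt")
    # attention_mask marks padded positions explicitly, so no guessing is needed.
    out = model.generate(**batch, max_new_tokens=8, pad_token_id=tokenizer.eos_token_id)
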
@@ -4457,7 +4377,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )
         self._use_kernels = False
 
-    def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable:
+    def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
         """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
         non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
         want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -4479,54 +4399,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  def is_backend_compatible(cls):
  return cls._supports_attention_backend
 
- def _move_missing_keys_from_meta_to_device(
- self,
- missing_keys: list[str],
- device_map: dict | None,
- device_mesh: "torch.distributed.device_mesh.DeviceMesh | None",
- hf_quantizer: HfQuantizer | None,
+ def _move_missing_keys_from_meta_to_cpu(
+ self, missing_keys: list[str], dtype: torch.dtype, hf_quantizer: Optional[HfQuantizer]
  ) -> None:
- """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
- back from meta device to their device according to the `device_map` if any, else cpu. Takes care of sharding those
- missing parameters if `device_mesh` is provided, i.e. we are using TP.
- All non-persistent buffers are also moved back to the correct device (they are not part of the state_dict, but are
- not missing either).
+ """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back
+ from meta device to cpu.
  """
  is_quantized = hf_quantizer is not None
- # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
- if is_deepspeed_zero3_enabled() and not is_quantized:
- return
 
  # In this case we need to move everything back
  if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
+ # We only do it for the parameters, as the buffers are not initialized on the meta device by default
  for key, param in self.named_parameters():
- value = torch.empty_like(param, device="cpu")
- _load_parameter_into_model(self, key, value)
- for key, buffer in self.named_buffers():
- value = torch.empty_like(buffer, device="cpu")
+ value = torch.empty_like(param, dtype=dtype, device="cpu")
  _load_parameter_into_model(self, key, value)
  return
 
+ model_state_dict = self.state_dict()
  # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
  # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
  # will be re-initialized for nothing (which can be quite long)
  for key in missing_keys - self.all_tied_weights_keys.keys():
- param = self.get_parameter_or_buffer(key)
- param_device = get_device(device_map, key, valid_torch_device=True)
- value = torch.empty_like(param, device=param_device)
- # For TP, we may need to shard the param
- if device_mesh is not None:
- shard_and_distribute_module(
- self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh
- )
- # Otherwise, just move it to device
- else:
- _load_parameter_into_model(self, key, value)
- # We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway
- for key, buffer in self.named_non_persistent_buffers():
- buffer_device = get_device(device_map, key, valid_torch_device=True)
- value = torch.empty_like(buffer, device=buffer_device)
- _load_parameter_into_model(self, key, value)
+ param = model_state_dict[key]
+ # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
+ if param.device == torch.device("meta"):
+ value = torch.empty_like(param, dtype=dtype, device="cpu")
+ if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
+ _load_parameter_into_model(self, key, value)
 
  def _initialize_missing_keys(self, is_quantized: bool) -> None:
  """
@@ -4554,6 +4453,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  ) -> tuple[set[str], set[str]]:
  """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
  raising unneeded warnings/errors.
+ Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to
+ be tied anyway.
  """
  # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
  # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
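The added docstring lines refer to flagging tied parameters so the init loop skips them. A hedged sketch of that flagging idea; the module and key set are illustrative, and the real flag is set elsewhere in this file:

```python
# Sketch of the "skip init for tied weights" idea mentioned in the docstring above.
import torch.nn as nn

model = nn.Linear(4, 4)
tied_keys = {"weight"}                    # stand-in for model.all_tied_weights_keys.keys()
for name, param in model.named_parameters():
    if name in tied_keys:
        # attribute name as used throughout transformers' loading code
        param._is_hf_initialized = True   # the tensor will be overwritten by weight tying anyway
```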
@@ -4612,19 +4513,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
  raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
 
- def named_non_persistent_buffers(
- self, recurse: bool = True, remove_duplicate: bool = True
- ) -> Iterator[tuple[str, torch.Tensor]]:
- """Similar to `named_buffers`, but only yield non-persistent ones. It is handy as it's not perfectly straightforward
- to know if they are persistent or not"""
- for name, tensor in self.named_buffers(recurse=recurse, remove_duplicate=remove_duplicate):
- # We have to grab the parent here, as the attribute `_non_persistent_buffers_set` is on the immediate
- # parent only
- parent, buf_name = name.rsplit(".", 1) if "." in name else ("", name)
- parent = self.get_submodule(parent)
- if buf_name in parent._non_persistent_buffers_set:
- yield name, tensor
-
  def train(self, mode: bool = True):
  out = super().train(mode)
  if self.use_kernels:
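The deleted helper relied on `_non_persistent_buffers_set`, which PyTorch keeps on the immediate parent module. The same lookup can be reproduced with plain PyTorch, e.g.:

```python
# Reproducing the removed helper's logic with plain PyTorch.
import torch
import torch.nn as nn

class WithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("persistent_buf", torch.zeros(1))
        self.register_buffer("transient_buf", torch.zeros(1), persistent=False)

model = nn.Sequential(WithBuffers())
for name, buf in model.named_buffers():
    parent_name, _, buf_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)
    if buf_name in parent._non_persistent_buffers_set:   # same private set the helper used
        print(name, "is non-persistent")                  # -> "0.transient_buf is non-persistent"
```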
@@ -4667,7 +4555,7 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
  return model
 
 
- def is_accelerator_device(device: str | int | torch.device) -> bool:
+ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
  """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
  a proper `torch.device`.
  """
@@ -4677,41 +4565,7 @@ def is_accelerator_device(device: str | int | torch.device) -> bool:
  return torch.device(device).type not in ["meta", "cpu"]
 
 
- def get_total_byte_count(
- model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: HfQuantizer | None = None
- ):
- """
- This utility function calculates the total byte count needed to load the model on each device.
- This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
- """
-
- total_byte_count = defaultdict(lambda: 0)
- tied_param_names = model.all_tied_weights_keys.keys()
- tp_plan = model._tp_plan if torch.distributed.is_available() and torch.distributed.is_initialized() else []
-
- for param_name, device in accelerator_device_map.items():
- # Skip if the parameter has already been accounted for (tied weights)
- if param_name in tied_param_names:
- continue
-
- param = model.get_parameter_or_buffer(param_name)
-
- if hf_quantizer is not None:
- dtype_size = hf_quantizer.param_element_size(model, param_name, param)
- else:
- dtype_size = param.element_size()
-
- param_byte_count = param.numel() * dtype_size
-
- if len(tp_plan) > 0:
- is_part_of_plan = _get_parameter_tp_plan(param_name, tp_plan, is_weight=True) is not None
- param_byte_count //= torch.distributed.get_world_size() if is_part_of_plan else 1
-
- total_byte_count[device] += param_byte_count
- return total_byte_count
-
-
- def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: HfQuantizer | None):
+ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
  """This function warms up the caching allocator based on the size of the model tensors that will reside on each
  device. It allows one large call to Malloc, instead of recursively calling it later when loading
  the model, which is actually the loading speed bottleneck.
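The warm-up trick the docstring describes can be reproduced in isolation: one large throwaway allocation makes the caching allocator reserve the memory up front, so later per-tensor copies reuse the cache instead of hitting cudaMalloc. A minimal sketch, with an arbitrary byte count and device:

```python
# Minimal reproduction of the warm-up idea: reserve ~N bytes once, then let the
# allocator recycle them. fp16 elements are 2 bytes, hence the // 2 (or the
# quantizer-provided factor in the new code further down).
import torch

if torch.cuda.is_available():
    byte_count = 512 * 1024 * 1024                      # pretend the model needs ~512 MiB on cuda:0
    _ = torch.empty(byte_count // 2, dtype=torch.float16, device="cuda:0", requires_grad=False)
    del _                                               # memory stays in the allocator's cache
    print(torch.cuda.memory_reserved(0))                # roughly byte_count bytes remain reserved
```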
@@ -4730,6 +4584,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
  However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
  """
+ factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()
+
  # Remove disk, cpu and meta devices, and cast to proper torch.device
  accelerator_device_map = {
  param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
@@ -4737,7 +4593,40 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  if not accelerator_device_map:
  return
 
- total_byte_count = get_total_byte_count(model, accelerator_device_map, hf_quantizer)
+ tp_plan = getattr(model, "_tp_plan", []) or []
+ tp_plan_regex = (
+ re.compile("|".join([re.escape(plan) for plan in tp_plan]))
+ if _torch_distributed_available and torch.distributed.is_initialized()
+ else None
+ )
+ total_byte_count = defaultdict(lambda: 0)
+ tied_param_names = model.all_tied_weights_keys.keys()
+ for param_name, device in accelerator_device_map.items():
+ # Skip if the parameter has already been accounted for (tied weights)
+ if param_name in tied_param_names:
+ continue
+
+ # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
+ # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
+ if hf_quantizer is not None:
+ param_name = hf_quantizer.get_param_name(param_name)
+
+ try:
+ param = model.get_parameter_or_buffer(param_name)
+ except AttributeError:
+ # TODO: for now let's skip if we can't find the parameters
+ if hf_quantizer is not None:
+ continue
+ raise AttributeError(f"Parameter {param_name} not found in model")
+
+ # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+ param_byte_count = param.numel() * param.element_size()
+
+ if tp_plan_regex is not None:
+ generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
+ param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
+
+ total_byte_count[device] += param_byte_count
 
  # This will kick off the caching allocator to avoid having to Malloc afterwards
  for device, byte_count in total_byte_count.items():
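The regex trick in the added lines maps a concrete parameter name onto its tensor-parallel plan entry by replacing layer indices with a literal `*`, which then matches the escaped plan keys. A small standalone check of that mapping; the plan entry and parameter name are illustrative:

```python
# Standalone check of the name matching used above: layer indices become a literal
# "*" so the escaped tp_plan keys (which contain "*") match directly.
import re

tp_plan = ["model.layers.*.self_attn.q_proj"]             # illustrative plan entry
tp_plan_regex = re.compile("|".join(re.escape(plan) for plan in tp_plan))

param_name = "model.layers.3.self_attn.q_proj.weight"
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)       # -> "model.layers.*.self_attn.q_proj.weight"
print(bool(tp_plan_regex.search(generic_name)))            # True -> byte count is divided by the world size
```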
@@ -4757,9 +4646,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  unused_memory = torch_accelerator_module.memory_reserved(
  index
  ) - torch_accelerator_module.memory_allocated(index)
- byte_count = int(max(0, byte_count - unused_memory))
- # We divide by 2 here as we allocate in fp16
- _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
+ byte_count = max(0, byte_count - unused_memory)
+ # Allocate memory
+ _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
 
 
  class AttentionInterface(GeneralInterface):