transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc3-py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
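Because the listing below only reports per-file line counts, here is a minimal, illustrative sketch (not part of this release diff) of how a comparable file-level comparison of the two wheels could be reproduced locally with the Python standard library. The local wheel paths and the `wheel_files` helper are assumptions for the example, not the tooling that generated this page.

```python
# Sketch: compare the file contents of two locally downloaded wheels.
# Assumes both .whl files are present in the working directory,
# e.g. fetched with `pip download transformers==5.0.0rc1 --no-deps`.
import zipfile

OLD_WHEEL = "transformers-5.0.0rc1-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "transformers-5.0.0rc3-py3-none-any.whl"  # assumed local path


def wheel_files(path: str) -> dict[str, bytes]:
    """Map each file inside the wheel (a zip archive) to its raw bytes."""
    with zipfile.ZipFile(path) as zf:
        return {name: zf.read(name) for name in zf.namelist() if not name.endswith("/")}


old, new = wheel_files(OLD_WHEEL), wheel_files(NEW_WHEEL)

added = sorted(new.keys() - old.keys())
removed = sorted(old.keys() - new.keys())
changed = sorted(name for name in old.keys() & new.keys() if old[name] != new[name])

print(f"Files changed: {len(added) + len(removed) + len(changed)}")
for name in changed:
    print("  M", name)  # modified; per-file +X -Y counts could be derived with difflib
for name in added:
    print("  A", name)  # present only in the new wheel
for name in removed:
    print("  D", name)  # present only in the old wheel
```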
Files changed (1584)
  1. transformers/__init__.py +27 -27
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +32 -33
  4. transformers/cache_utils.py +32 -139
  5. transformers/cli/chat.py +3 -3
  6. transformers/cli/serve.py +2 -2
  7. transformers/cli/transformers.py +2 -1
  8. transformers/configuration_utils.py +143 -101
  9. transformers/conversion_mapping.py +73 -6
  10. transformers/convert_slow_tokenizer.py +3 -8
  11. transformers/core_model_loading.py +215 -50
  12. transformers/data/processors/glue.py +0 -1
  13. transformers/data/processors/utils.py +0 -1
  14. transformers/data/processors/xnli.py +0 -1
  15. transformers/dependency_versions_table.py +5 -5
  16. transformers/distributed/configuration_utils.py +1 -2
  17. transformers/dynamic_module_utils.py +23 -23
  18. transformers/feature_extraction_sequence_utils.py +19 -23
  19. transformers/feature_extraction_utils.py +63 -31
  20. transformers/generation/candidate_generator.py +80 -33
  21. transformers/generation/configuration_utils.py +186 -131
  22. transformers/generation/continuous_batching/__init__.py +0 -1
  23. transformers/generation/continuous_batching/cache.py +81 -24
  24. transformers/generation/continuous_batching/cache_manager.py +155 -45
  25. transformers/generation/continuous_batching/continuous_api.py +152 -84
  26. transformers/generation/continuous_batching/requests.py +51 -3
  27. transformers/generation/continuous_batching/scheduler.py +127 -52
  28. transformers/generation/logits_process.py +0 -128
  29. transformers/generation/stopping_criteria.py +1 -1
  30. transformers/generation/streamers.py +0 -1
  31. transformers/generation/utils.py +107 -119
  32. transformers/generation/watermarking.py +8 -6
  33. transformers/hf_argparser.py +9 -13
  34. transformers/hyperparameter_search.py +1 -2
  35. transformers/image_processing_base.py +11 -21
  36. transformers/image_processing_utils.py +11 -12
  37. transformers/image_processing_utils_fast.py +68 -57
  38. transformers/image_transforms.py +29 -29
  39. transformers/image_utils.py +30 -32
  40. transformers/initialization.py +37 -0
  41. transformers/integrations/__init__.py +12 -0
  42. transformers/integrations/accelerate.py +44 -111
  43. transformers/integrations/aqlm.py +3 -5
  44. transformers/integrations/awq.py +3 -8
  45. transformers/integrations/bitnet.py +5 -8
  46. transformers/integrations/bitsandbytes.py +16 -15
  47. transformers/integrations/deepspeed.py +19 -4
  48. transformers/integrations/eetq.py +3 -6
  49. transformers/integrations/fbgemm_fp8.py +2 -3
  50. transformers/integrations/finegrained_fp8.py +14 -23
  51. transformers/integrations/flash_attention.py +2 -2
  52. transformers/integrations/flex_attention.py +1 -1
  53. transformers/integrations/fp_quant.py +4 -6
  54. transformers/integrations/ggml.py +0 -1
  55. transformers/integrations/higgs.py +2 -5
  56. transformers/integrations/hub_kernels.py +23 -5
  57. transformers/integrations/integration_utils.py +37 -3
  58. transformers/integrations/mistral.py +12 -0
  59. transformers/integrations/moe.py +240 -0
  60. transformers/integrations/mxfp4.py +9 -16
  61. transformers/integrations/peft.py +5 -0
  62. transformers/integrations/quanto.py +5 -2
  63. transformers/integrations/quark.py +2 -4
  64. transformers/integrations/spqr.py +3 -5
  65. transformers/integrations/tensor_parallel.py +167 -221
  66. transformers/integrations/torchao.py +4 -6
  67. transformers/integrations/vptq.py +3 -5
  68. transformers/loss/loss_lw_detr.py +356 -0
  69. transformers/loss/loss_utils.py +2 -0
  70. transformers/masking_utils.py +47 -51
  71. transformers/model_debugging_utils.py +4 -5
  72. transformers/modelcard.py +14 -192
  73. transformers/modeling_attn_mask_utils.py +19 -19
  74. transformers/modeling_flash_attention_utils.py +27 -27
  75. transformers/modeling_gguf_pytorch_utils.py +71 -24
  76. transformers/modeling_layers.py +21 -22
  77. transformers/modeling_outputs.py +242 -253
  78. transformers/modeling_rope_utils.py +110 -113
  79. transformers/modeling_utils.py +633 -576
  80. transformers/models/__init__.py +23 -0
  81. transformers/models/afmoe/configuration_afmoe.py +26 -29
  82. transformers/models/afmoe/modeling_afmoe.py +37 -49
  83. transformers/models/afmoe/modular_afmoe.py +21 -31
  84. transformers/models/aimv2/configuration_aimv2.py +2 -5
  85. transformers/models/aimv2/modeling_aimv2.py +24 -21
  86. transformers/models/aimv2/modular_aimv2.py +11 -9
  87. transformers/models/albert/configuration_albert.py +0 -1
  88. transformers/models/albert/modeling_albert.py +70 -69
  89. transformers/models/albert/tokenization_albert.py +1 -4
  90. transformers/models/align/configuration_align.py +0 -1
  91. transformers/models/align/modeling_align.py +73 -68
  92. transformers/models/align/processing_align.py +2 -30
  93. transformers/models/altclip/configuration_altclip.py +0 -1
  94. transformers/models/altclip/modeling_altclip.py +83 -80
  95. transformers/models/altclip/processing_altclip.py +2 -15
  96. transformers/models/apertus/__init__.py +0 -1
  97. transformers/models/apertus/configuration_apertus.py +18 -21
  98. transformers/models/apertus/modeling_apertus.py +35 -36
  99. transformers/models/apertus/modular_apertus.py +32 -31
  100. transformers/models/arcee/configuration_arcee.py +20 -23
  101. transformers/models/arcee/modeling_arcee.py +32 -35
  102. transformers/models/arcee/modular_arcee.py +20 -23
  103. transformers/models/aria/configuration_aria.py +20 -23
  104. transformers/models/aria/image_processing_aria.py +25 -27
  105. transformers/models/aria/modeling_aria.py +71 -70
  106. transformers/models/aria/modular_aria.py +85 -88
  107. transformers/models/aria/processing_aria.py +28 -35
  108. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  109. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  110. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +6 -8
  111. transformers/models/audioflamingo3/__init__.py +0 -1
  112. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  113. transformers/models/audioflamingo3/modeling_audioflamingo3.py +22 -23
  114. transformers/models/audioflamingo3/modular_audioflamingo3.py +12 -17
  115. transformers/models/audioflamingo3/processing_audioflamingo3.py +33 -30
  116. transformers/models/auto/auto_factory.py +5 -6
  117. transformers/models/auto/configuration_auto.py +53 -5
  118. transformers/models/auto/feature_extraction_auto.py +12 -10
  119. transformers/models/auto/image_processing_auto.py +17 -28
  120. transformers/models/auto/modeling_auto.py +38 -188
  121. transformers/models/auto/processing_auto.py +6 -1
  122. transformers/models/auto/tokenization_auto.py +147 -169
  123. transformers/models/auto/video_processing_auto.py +12 -10
  124. transformers/models/autoformer/configuration_autoformer.py +4 -7
  125. transformers/models/autoformer/modeling_autoformer.py +98 -100
  126. transformers/models/aya_vision/configuration_aya_vision.py +0 -1
  127. transformers/models/aya_vision/modeling_aya_vision.py +42 -40
  128. transformers/models/aya_vision/modular_aya_vision.py +26 -29
  129. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  130. transformers/models/bamba/configuration_bamba.py +29 -32
  131. transformers/models/bamba/modeling_bamba.py +78 -83
  132. transformers/models/bamba/modular_bamba.py +68 -71
  133. transformers/models/bark/configuration_bark.py +4 -7
  134. transformers/models/bark/generation_configuration_bark.py +3 -5
  135. transformers/models/bark/modeling_bark.py +49 -55
  136. transformers/models/bark/processing_bark.py +19 -41
  137. transformers/models/bart/configuration_bart.py +0 -2
  138. transformers/models/bart/modeling_bart.py +122 -117
  139. transformers/models/barthez/tokenization_barthez.py +1 -4
  140. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  141. transformers/models/beit/configuration_beit.py +0 -11
  142. transformers/models/beit/image_processing_beit.py +53 -56
  143. transformers/models/beit/image_processing_beit_fast.py +8 -10
  144. transformers/models/beit/modeling_beit.py +51 -53
  145. transformers/models/bert/configuration_bert.py +0 -1
  146. transformers/models/bert/modeling_bert.py +114 -122
  147. transformers/models/bert/tokenization_bert.py +2 -4
  148. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  149. transformers/models/bert_generation/configuration_bert_generation.py +0 -1
  150. transformers/models/bert_generation/modeling_bert_generation.py +49 -49
  151. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  152. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  153. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  154. transformers/models/big_bird/configuration_big_bird.py +0 -1
  155. transformers/models/big_bird/modeling_big_bird.py +110 -109
  156. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  157. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +0 -1
  158. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +116 -111
  159. transformers/models/biogpt/configuration_biogpt.py +0 -1
  160. transformers/models/biogpt/modeling_biogpt.py +69 -71
  161. transformers/models/biogpt/modular_biogpt.py +59 -61
  162. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  163. transformers/models/bit/configuration_bit.py +0 -1
  164. transformers/models/bit/image_processing_bit.py +21 -24
  165. transformers/models/bit/image_processing_bit_fast.py +0 -1
  166. transformers/models/bit/modeling_bit.py +14 -12
  167. transformers/models/bitnet/configuration_bitnet.py +18 -21
  168. transformers/models/bitnet/modeling_bitnet.py +32 -35
  169. transformers/models/bitnet/modular_bitnet.py +4 -6
  170. transformers/models/blenderbot/configuration_blenderbot.py +0 -1
  171. transformers/models/blenderbot/modeling_blenderbot.py +71 -95
  172. transformers/models/blenderbot/tokenization_blenderbot.py +6 -8
  173. transformers/models/blenderbot_small/configuration_blenderbot_small.py +0 -1
  174. transformers/models/blenderbot_small/modeling_blenderbot_small.py +73 -68
  175. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  176. transformers/models/blip/configuration_blip.py +0 -1
  177. transformers/models/blip/image_processing_blip.py +17 -20
  178. transformers/models/blip/image_processing_blip_fast.py +0 -1
  179. transformers/models/blip/modeling_blip.py +62 -71
  180. transformers/models/blip/modeling_blip_text.py +71 -65
  181. transformers/models/blip/processing_blip.py +5 -36
  182. transformers/models/blip_2/configuration_blip_2.py +0 -1
  183. transformers/models/blip_2/modeling_blip_2.py +72 -71
  184. transformers/models/blip_2/processing_blip_2.py +8 -38
  185. transformers/models/bloom/configuration_bloom.py +0 -1
  186. transformers/models/bloom/modeling_bloom.py +71 -103
  187. transformers/models/blt/configuration_blt.py +71 -74
  188. transformers/models/blt/modeling_blt.py +235 -78
  189. transformers/models/blt/modular_blt.py +225 -62
  190. transformers/models/bridgetower/configuration_bridgetower.py +0 -1
  191. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  192. transformers/models/bridgetower/image_processing_bridgetower_fast.py +7 -10
  193. transformers/models/bridgetower/modeling_bridgetower.py +113 -109
  194. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  195. transformers/models/bros/configuration_bros.py +0 -1
  196. transformers/models/bros/modeling_bros.py +86 -80
  197. transformers/models/bros/processing_bros.py +2 -12
  198. transformers/models/byt5/tokenization_byt5.py +4 -6
  199. transformers/models/camembert/configuration_camembert.py +0 -1
  200. transformers/models/camembert/modeling_camembert.py +196 -195
  201. transformers/models/camembert/modular_camembert.py +51 -54
  202. transformers/models/camembert/tokenization_camembert.py +1 -4
  203. transformers/models/canine/configuration_canine.py +0 -1
  204. transformers/models/canine/modeling_canine.py +79 -75
  205. transformers/models/canine/tokenization_canine.py +2 -1
  206. transformers/models/chameleon/configuration_chameleon.py +24 -27
  207. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  208. transformers/models/chameleon/image_processing_chameleon_fast.py +0 -1
  209. transformers/models/chameleon/modeling_chameleon.py +62 -60
  210. transformers/models/chameleon/processing_chameleon.py +16 -41
  211. transformers/models/chinese_clip/configuration_chinese_clip.py +0 -1
  212. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  213. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  214. transformers/models/chinese_clip/modeling_chinese_clip.py +71 -69
  215. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  216. transformers/models/clap/configuration_clap.py +0 -1
  217. transformers/models/clap/feature_extraction_clap.py +11 -12
  218. transformers/models/clap/modeling_clap.py +113 -104
  219. transformers/models/clap/processing_clap.py +2 -15
  220. transformers/models/clip/configuration_clip.py +0 -1
  221. transformers/models/clip/image_processing_clip.py +21 -24
  222. transformers/models/clip/image_processing_clip_fast.py +0 -1
  223. transformers/models/clip/modeling_clip.py +47 -46
  224. transformers/models/clip/processing_clip.py +2 -14
  225. transformers/models/clip/tokenization_clip.py +2 -5
  226. transformers/models/clipseg/configuration_clipseg.py +0 -1
  227. transformers/models/clipseg/modeling_clipseg.py +90 -87
  228. transformers/models/clipseg/processing_clipseg.py +8 -39
  229. transformers/models/clvp/configuration_clvp.py +1 -3
  230. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  231. transformers/models/clvp/modeling_clvp.py +133 -118
  232. transformers/models/clvp/number_normalizer.py +1 -2
  233. transformers/models/clvp/processing_clvp.py +3 -20
  234. transformers/models/clvp/tokenization_clvp.py +0 -1
  235. transformers/models/code_llama/tokenization_code_llama.py +4 -7
  236. transformers/models/codegen/configuration_codegen.py +0 -1
  237. transformers/models/codegen/modeling_codegen.py +61 -52
  238. transformers/models/codegen/tokenization_codegen.py +5 -6
  239. transformers/models/cohere/configuration_cohere.py +20 -23
  240. transformers/models/cohere/modeling_cohere.py +36 -39
  241. transformers/models/cohere/modular_cohere.py +24 -28
  242. transformers/models/cohere/tokenization_cohere.py +5 -6
  243. transformers/models/cohere2/configuration_cohere2.py +21 -24
  244. transformers/models/cohere2/modeling_cohere2.py +35 -38
  245. transformers/models/cohere2/modular_cohere2.py +39 -41
  246. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +6 -8
  247. transformers/models/cohere2_vision/modeling_cohere2_vision.py +35 -33
  248. transformers/models/cohere2_vision/modular_cohere2_vision.py +21 -23
  249. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  250. transformers/models/colpali/configuration_colpali.py +0 -1
  251. transformers/models/colpali/modeling_colpali.py +14 -16
  252. transformers/models/colpali/modular_colpali.py +11 -51
  253. transformers/models/colpali/processing_colpali.py +14 -52
  254. transformers/models/colqwen2/modeling_colqwen2.py +20 -22
  255. transformers/models/colqwen2/modular_colqwen2.py +29 -68
  256. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  257. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -2
  258. transformers/models/conditional_detr/image_processing_conditional_detr.py +64 -66
  259. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +22 -22
  260. transformers/models/conditional_detr/modeling_conditional_detr.py +82 -81
  261. transformers/models/conditional_detr/modular_conditional_detr.py +1 -3
  262. transformers/models/convbert/configuration_convbert.py +0 -1
  263. transformers/models/convbert/modeling_convbert.py +88 -87
  264. transformers/models/convbert/tokenization_convbert.py +0 -1
  265. transformers/models/convnext/configuration_convnext.py +0 -1
  266. transformers/models/convnext/image_processing_convnext.py +20 -23
  267. transformers/models/convnext/image_processing_convnext_fast.py +14 -19
  268. transformers/models/convnext/modeling_convnext.py +5 -8
  269. transformers/models/convnextv2/configuration_convnextv2.py +0 -1
  270. transformers/models/convnextv2/modeling_convnextv2.py +5 -8
  271. transformers/models/cpm/tokenization_cpm.py +6 -7
  272. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  273. transformers/models/cpmant/configuration_cpmant.py +0 -1
  274. transformers/models/cpmant/modeling_cpmant.py +38 -40
  275. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  276. transformers/models/csm/configuration_csm.py +49 -51
  277. transformers/models/csm/generation_csm.py +31 -35
  278. transformers/models/csm/modeling_csm.py +81 -82
  279. transformers/models/csm/modular_csm.py +58 -58
  280. transformers/models/csm/processing_csm.py +25 -68
  281. transformers/models/ctrl/configuration_ctrl.py +0 -1
  282. transformers/models/ctrl/modeling_ctrl.py +52 -43
  283. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  284. transformers/models/cvt/configuration_cvt.py +0 -1
  285. transformers/models/cvt/modeling_cvt.py +18 -16
  286. transformers/models/cwm/__init__.py +0 -1
  287. transformers/models/cwm/configuration_cwm.py +3 -5
  288. transformers/models/cwm/modeling_cwm.py +33 -35
  289. transformers/models/cwm/modular_cwm.py +10 -12
  290. transformers/models/d_fine/configuration_d_fine.py +3 -5
  291. transformers/models/d_fine/modeling_d_fine.py +127 -121
  292. transformers/models/d_fine/modular_d_fine.py +23 -13
  293. transformers/models/dab_detr/configuration_dab_detr.py +2 -3
  294. transformers/models/dab_detr/modeling_dab_detr.py +69 -71
  295. transformers/models/dac/configuration_dac.py +0 -1
  296. transformers/models/dac/feature_extraction_dac.py +6 -9
  297. transformers/models/dac/modeling_dac.py +21 -23
  298. transformers/models/data2vec/configuration_data2vec_audio.py +0 -1
  299. transformers/models/data2vec/configuration_data2vec_text.py +0 -1
  300. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  301. transformers/models/data2vec/modeling_data2vec_audio.py +52 -56
  302. transformers/models/data2vec/modeling_data2vec_text.py +98 -93
  303. transformers/models/data2vec/modeling_data2vec_vision.py +41 -42
  304. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  305. transformers/models/data2vec/modular_data2vec_text.py +58 -54
  306. transformers/models/dbrx/configuration_dbrx.py +27 -20
  307. transformers/models/dbrx/modeling_dbrx.py +40 -43
  308. transformers/models/dbrx/modular_dbrx.py +31 -33
  309. transformers/models/deberta/configuration_deberta.py +0 -1
  310. transformers/models/deberta/modeling_deberta.py +59 -60
  311. transformers/models/deberta/tokenization_deberta.py +2 -5
  312. transformers/models/deberta_v2/configuration_deberta_v2.py +0 -1
  313. transformers/models/deberta_v2/modeling_deberta_v2.py +65 -65
  314. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  315. transformers/models/decision_transformer/configuration_decision_transformer.py +0 -1
  316. transformers/models/decision_transformer/modeling_decision_transformer.py +56 -55
  317. transformers/models/deepseek_v2/configuration_deepseek_v2.py +34 -37
  318. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -37
  319. transformers/models/deepseek_v2/modular_deepseek_v2.py +44 -44
  320. transformers/models/deepseek_v3/configuration_deepseek_v3.py +35 -38
  321. transformers/models/deepseek_v3/modeling_deepseek_v3.py +40 -38
  322. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -7
  323. transformers/models/deepseek_vl/configuration_deepseek_vl.py +2 -3
  324. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +25 -26
  325. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +7 -7
  326. transformers/models/deepseek_vl/modeling_deepseek_vl.py +40 -36
  327. transformers/models/deepseek_vl/modular_deepseek_vl.py +14 -43
  328. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  329. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +3 -5
  330. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
  331. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +16 -20
  332. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +42 -38
  333. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +80 -99
  334. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  335. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -3
  336. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  337. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +17 -17
  338. transformers/models/deformable_detr/modeling_deformable_detr.py +67 -68
  339. transformers/models/deformable_detr/modular_deformable_detr.py +1 -3
  340. transformers/models/deit/configuration_deit.py +0 -1
  341. transformers/models/deit/image_processing_deit.py +18 -21
  342. transformers/models/deit/image_processing_deit_fast.py +0 -1
  343. transformers/models/deit/modeling_deit.py +16 -18
  344. transformers/models/depth_anything/configuration_depth_anything.py +2 -4
  345. transformers/models/depth_anything/modeling_depth_anything.py +5 -8
  346. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  347. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  348. transformers/models/depth_pro/image_processing_depth_pro_fast.py +6 -8
  349. transformers/models/depth_pro/modeling_depth_pro.py +21 -23
  350. transformers/models/detr/configuration_detr.py +1 -2
  351. transformers/models/detr/image_processing_detr.py +64 -66
  352. transformers/models/detr/image_processing_detr_fast.py +22 -23
  353. transformers/models/detr/modeling_detr.py +78 -73
  354. transformers/models/dia/configuration_dia.py +5 -8
  355. transformers/models/dia/feature_extraction_dia.py +6 -9
  356. transformers/models/dia/generation_dia.py +42 -45
  357. transformers/models/dia/modeling_dia.py +73 -65
  358. transformers/models/dia/modular_dia.py +63 -54
  359. transformers/models/dia/processing_dia.py +39 -29
  360. transformers/models/dia/tokenization_dia.py +3 -6
  361. transformers/models/diffllama/configuration_diffllama.py +20 -23
  362. transformers/models/diffllama/modeling_diffllama.py +44 -47
  363. transformers/models/diffllama/modular_diffllama.py +17 -19
  364. transformers/models/dinat/configuration_dinat.py +0 -1
  365. transformers/models/dinat/modeling_dinat.py +40 -42
  366. transformers/models/dinov2/configuration_dinov2.py +0 -1
  367. transformers/models/dinov2/modeling_dinov2.py +11 -13
  368. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
  369. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +12 -13
  370. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +5 -7
  371. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +4 -7
  372. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +3 -6
  373. transformers/models/dinov3_vit/configuration_dinov3_vit.py +5 -8
  374. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +5 -7
  375. transformers/models/dinov3_vit/modeling_dinov3_vit.py +17 -16
  376. transformers/models/dinov3_vit/modular_dinov3_vit.py +14 -13
  377. transformers/models/distilbert/configuration_distilbert.py +0 -1
  378. transformers/models/distilbert/modeling_distilbert.py +55 -55
  379. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  380. transformers/models/doge/__init__.py +0 -1
  381. transformers/models/doge/configuration_doge.py +25 -28
  382. transformers/models/doge/modeling_doge.py +43 -46
  383. transformers/models/doge/modular_doge.py +57 -58
  384. transformers/models/donut/configuration_donut_swin.py +0 -1
  385. transformers/models/donut/image_processing_donut.py +26 -29
  386. transformers/models/donut/image_processing_donut_fast.py +5 -11
  387. transformers/models/donut/modeling_donut_swin.py +60 -58
  388. transformers/models/donut/processing_donut.py +5 -26
  389. transformers/models/dots1/configuration_dots1.py +27 -29
  390. transformers/models/dots1/modeling_dots1.py +45 -39
  391. transformers/models/dots1/modular_dots1.py +0 -1
  392. transformers/models/dpr/configuration_dpr.py +0 -1
  393. transformers/models/dpr/modeling_dpr.py +37 -39
  394. transformers/models/dpr/tokenization_dpr.py +7 -9
  395. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  396. transformers/models/dpt/configuration_dpt.py +1 -2
  397. transformers/models/dpt/image_processing_dpt.py +65 -66
  398. transformers/models/dpt/image_processing_dpt_fast.py +14 -16
  399. transformers/models/dpt/modeling_dpt.py +19 -21
  400. transformers/models/dpt/modular_dpt.py +11 -13
  401. transformers/models/edgetam/configuration_edgetam.py +1 -2
  402. transformers/models/edgetam/modeling_edgetam.py +44 -43
  403. transformers/models/edgetam/modular_edgetam.py +17 -20
  404. transformers/models/edgetam_video/__init__.py +0 -1
  405. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  406. transformers/models/edgetam_video/modeling_edgetam_video.py +131 -120
  407. transformers/models/edgetam_video/modular_edgetam_video.py +29 -37
  408. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  409. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  410. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +5 -6
  411. transformers/models/efficientloftr/modeling_efficientloftr.py +41 -30
  412. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  413. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  414. transformers/models/efficientnet/image_processing_efficientnet.py +28 -32
  415. transformers/models/efficientnet/image_processing_efficientnet_fast.py +15 -17
  416. transformers/models/efficientnet/modeling_efficientnet.py +17 -15
  417. transformers/models/electra/configuration_electra.py +0 -1
  418. transformers/models/electra/modeling_electra.py +108 -103
  419. transformers/models/emu3/configuration_emu3.py +5 -7
  420. transformers/models/emu3/image_processing_emu3.py +44 -39
  421. transformers/models/emu3/modeling_emu3.py +67 -64
  422. transformers/models/emu3/modular_emu3.py +39 -35
  423. transformers/models/emu3/processing_emu3.py +18 -43
  424. transformers/models/encodec/configuration_encodec.py +2 -4
  425. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  426. transformers/models/encodec/modeling_encodec.py +39 -29
  427. transformers/models/encoder_decoder/configuration_encoder_decoder.py +0 -1
  428. transformers/models/encoder_decoder/modeling_encoder_decoder.py +17 -19
  429. transformers/models/eomt/configuration_eomt.py +0 -1
  430. transformers/models/eomt/image_processing_eomt.py +53 -55
  431. transformers/models/eomt/image_processing_eomt_fast.py +59 -28
  432. transformers/models/eomt/modeling_eomt.py +23 -18
  433. transformers/models/eomt/modular_eomt.py +18 -13
  434. transformers/models/ernie/configuration_ernie.py +0 -1
  435. transformers/models/ernie/modeling_ernie.py +127 -132
  436. transformers/models/ernie/modular_ernie.py +97 -103
  437. transformers/models/ernie4_5/configuration_ernie4_5.py +18 -20
  438. transformers/models/ernie4_5/modeling_ernie4_5.py +32 -34
  439. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  440. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +27 -29
  441. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +52 -51
  442. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +16 -44
  443. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  444. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +329 -0
  445. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +455 -0
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +231 -0
  447. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1895 -0
  448. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1901 -0
  449. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +249 -0
  450. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +593 -0
  451. transformers/models/esm/configuration_esm.py +2 -4
  452. transformers/models/esm/modeling_esm.py +38 -34
  453. transformers/models/esm/modeling_esmfold.py +48 -45
  454. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  455. transformers/models/esm/openfold_utils/loss.py +1 -2
  456. transformers/models/esm/openfold_utils/protein.py +13 -13
  457. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  458. transformers/models/esm/tokenization_esm.py +2 -4
  459. transformers/models/evolla/configuration_evolla.py +29 -32
  460. transformers/models/evolla/modeling_evolla.py +67 -62
  461. transformers/models/evolla/modular_evolla.py +53 -47
  462. transformers/models/evolla/processing_evolla.py +23 -35
  463. transformers/models/exaone4/configuration_exaone4.py +19 -22
  464. transformers/models/exaone4/modeling_exaone4.py +33 -36
  465. transformers/models/exaone4/modular_exaone4.py +40 -42
  466. transformers/models/falcon/configuration_falcon.py +22 -25
  467. transformers/models/falcon/modeling_falcon.py +75 -78
  468. transformers/models/falcon_h1/configuration_falcon_h1.py +40 -43
  469. transformers/models/falcon_h1/modeling_falcon_h1.py +80 -78
  470. transformers/models/falcon_h1/modular_falcon_h1.py +54 -50
  471. transformers/models/falcon_mamba/configuration_falcon_mamba.py +0 -1
  472. transformers/models/falcon_mamba/modeling_falcon_mamba.py +50 -47
  473. transformers/models/falcon_mamba/modular_falcon_mamba.py +16 -14
  474. transformers/models/fast_vlm/configuration_fast_vlm.py +1 -0
  475. transformers/models/fast_vlm/modeling_fast_vlm.py +43 -39
  476. transformers/models/fast_vlm/modular_fast_vlm.py +2 -3
  477. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -5
  478. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +68 -57
  479. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +2 -3
  480. transformers/models/flaubert/configuration_flaubert.py +0 -1
  481. transformers/models/flaubert/modeling_flaubert.py +138 -143
  482. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  483. transformers/models/flava/configuration_flava.py +5 -6
  484. transformers/models/flava/image_processing_flava.py +66 -67
  485. transformers/models/flava/image_processing_flava_fast.py +42 -45
  486. transformers/models/flava/modeling_flava.py +111 -107
  487. transformers/models/flava/processing_flava.py +2 -12
  488. transformers/models/flex_olmo/__init__.py +0 -1
  489. transformers/models/flex_olmo/configuration_flex_olmo.py +23 -25
  490. transformers/models/flex_olmo/modeling_flex_olmo.py +44 -43
  491. transformers/models/flex_olmo/modular_flex_olmo.py +35 -37
  492. transformers/models/florence2/configuration_florence2.py +0 -1
  493. transformers/models/florence2/modeling_florence2.py +59 -43
  494. transformers/models/florence2/modular_florence2.py +65 -81
  495. transformers/models/florence2/processing_florence2.py +18 -47
  496. transformers/models/fnet/configuration_fnet.py +0 -1
  497. transformers/models/fnet/modeling_fnet.py +76 -80
  498. transformers/models/fnet/tokenization_fnet.py +0 -1
  499. transformers/models/focalnet/configuration_focalnet.py +0 -1
  500. transformers/models/focalnet/modeling_focalnet.py +39 -41
  501. transformers/models/fsmt/configuration_fsmt.py +0 -1
  502. transformers/models/fsmt/modeling_fsmt.py +47 -48
  503. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  504. transformers/models/funnel/configuration_funnel.py +0 -1
  505. transformers/models/funnel/modeling_funnel.py +91 -93
  506. transformers/models/funnel/tokenization_funnel.py +2 -5
  507. transformers/models/fuyu/configuration_fuyu.py +23 -26
  508. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  509. transformers/models/fuyu/image_processing_fuyu_fast.py +12 -13
  510. transformers/models/fuyu/modeling_fuyu.py +29 -30
  511. transformers/models/fuyu/processing_fuyu.py +23 -34
  512. transformers/models/gemma/configuration_gemma.py +20 -23
  513. transformers/models/gemma/modeling_gemma.py +42 -46
  514. transformers/models/gemma/modular_gemma.py +37 -40
  515. transformers/models/gemma/tokenization_gemma.py +3 -6
  516. transformers/models/gemma2/configuration_gemma2.py +25 -28
  517. transformers/models/gemma2/modeling_gemma2.py +35 -38
  518. transformers/models/gemma2/modular_gemma2.py +56 -58
  519. transformers/models/gemma3/configuration_gemma3.py +28 -29
  520. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  521. transformers/models/gemma3/image_processing_gemma3_fast.py +9 -11
  522. transformers/models/gemma3/modeling_gemma3.py +112 -94
  523. transformers/models/gemma3/modular_gemma3.py +110 -91
  524. transformers/models/gemma3/processing_gemma3.py +5 -5
  525. transformers/models/gemma3n/configuration_gemma3n.py +12 -10
  526. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  527. transformers/models/gemma3n/modeling_gemma3n.py +127 -98
  528. transformers/models/gemma3n/modular_gemma3n.py +117 -84
  529. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  530. transformers/models/git/configuration_git.py +0 -1
  531. transformers/models/git/modeling_git.py +250 -197
  532. transformers/models/git/processing_git.py +2 -14
  533. transformers/models/glm/configuration_glm.py +19 -21
  534. transformers/models/glm/modeling_glm.py +33 -36
  535. transformers/models/glm/modular_glm.py +4 -7
  536. transformers/models/glm4/configuration_glm4.py +19 -21
  537. transformers/models/glm4/modeling_glm4.py +36 -38
  538. transformers/models/glm4/modular_glm4.py +8 -10
  539. transformers/models/glm46v/configuration_glm46v.py +0 -1
  540. transformers/models/glm46v/image_processing_glm46v.py +35 -40
  541. transformers/models/glm46v/image_processing_glm46v_fast.py +7 -7
  542. transformers/models/glm46v/modeling_glm46v.py +54 -52
  543. transformers/models/glm46v/modular_glm46v.py +4 -3
  544. transformers/models/glm46v/processing_glm46v.py +7 -41
  545. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  546. transformers/models/glm4_moe/configuration_glm4_moe.py +25 -28
  547. transformers/models/glm4_moe/modeling_glm4_moe.py +41 -40
  548. transformers/models/glm4_moe/modular_glm4_moe.py +27 -30
  549. transformers/models/glm4_moe_lite/__init__.py +28 -0
  550. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +235 -0
  551. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  552. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +304 -0
  553. transformers/models/glm4v/configuration_glm4v.py +14 -17
  554. transformers/models/glm4v/image_processing_glm4v.py +34 -40
  555. transformers/models/glm4v/image_processing_glm4v_fast.py +6 -7
  556. transformers/models/glm4v/modeling_glm4v.py +148 -156
  557. transformers/models/glm4v/modular_glm4v.py +142 -185
  558. transformers/models/glm4v/processing_glm4v.py +7 -41
  559. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  560. transformers/models/glm4v_moe/configuration_glm4v_moe.py +119 -122
  561. transformers/models/glm4v_moe/modeling_glm4v_moe.py +275 -319
  562. transformers/models/glm4v_moe/modular_glm4v_moe.py +66 -163
  563. transformers/models/glm_image/__init__.py +31 -0
  564. transformers/models/glm_image/configuration_glm_image.py +352 -0
  565. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  566. transformers/models/glm_image/image_processing_glm_image_fast.py +296 -0
  567. transformers/models/glm_image/modeling_glm_image.py +1590 -0
  568. transformers/models/glm_image/modular_glm_image.py +1480 -0
  569. transformers/models/glm_image/processing_glm_image.py +217 -0
  570. transformers/models/glmasr/__init__.py +29 -0
  571. transformers/models/glmasr/configuration_glmasr.py +196 -0
  572. transformers/models/glmasr/modeling_glmasr.py +511 -0
  573. transformers/models/glmasr/modular_glmasr.py +431 -0
  574. transformers/models/glmasr/processing_glmasr.py +331 -0
  575. transformers/models/glpn/configuration_glpn.py +0 -1
  576. transformers/models/glpn/image_processing_glpn.py +11 -12
  577. transformers/models/glpn/image_processing_glpn_fast.py +8 -10
  578. transformers/models/glpn/modeling_glpn.py +10 -12
  579. transformers/models/got_ocr2/configuration_got_ocr2.py +5 -8
  580. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  581. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +6 -8
  582. transformers/models/got_ocr2/modeling_got_ocr2.py +48 -45
  583. transformers/models/got_ocr2/modular_got_ocr2.py +31 -34
  584. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  585. transformers/models/gpt2/configuration_gpt2.py +0 -1
  586. transformers/models/gpt2/modeling_gpt2.py +114 -113
  587. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  588. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +0 -1
  589. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +76 -88
  590. transformers/models/gpt_neo/configuration_gpt_neo.py +0 -1
  591. transformers/models/gpt_neo/modeling_gpt_neo.py +77 -66
  592. transformers/models/gpt_neox/configuration_gpt_neox.py +19 -22
  593. transformers/models/gpt_neox/modeling_gpt_neox.py +71 -73
  594. transformers/models/gpt_neox/modular_gpt_neox.py +64 -66
  595. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  596. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +15 -18
  597. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +42 -45
  598. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  599. transformers/models/gpt_oss/configuration_gpt_oss.py +38 -24
  600. transformers/models/gpt_oss/modeling_gpt_oss.py +40 -44
  601. transformers/models/gpt_oss/modular_gpt_oss.py +22 -26
  602. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  603. transformers/models/gptj/configuration_gptj.py +0 -1
  604. transformers/models/gptj/modeling_gptj.py +96 -86
  605. transformers/models/granite/configuration_granite.py +23 -26
  606. transformers/models/granite/modeling_granite.py +40 -42
  607. transformers/models/granite/modular_granite.py +29 -31
  608. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  609. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  610. transformers/models/granite_speech/modeling_granite_speech.py +36 -24
  611. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  612. transformers/models/granitemoe/configuration_granitemoe.py +26 -29
  613. transformers/models/granitemoe/modeling_granitemoe.py +37 -40
  614. transformers/models/granitemoe/modular_granitemoe.py +22 -25
  615. transformers/models/granitemoehybrid/__init__.py +0 -1
  616. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +41 -40
  617. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +92 -86
  618. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +29 -21
  619. transformers/models/granitemoeshared/configuration_granitemoeshared.py +27 -30
  620. transformers/models/granitemoeshared/modeling_granitemoeshared.py +50 -55
  621. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  622. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -4
  623. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  624. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +17 -18
  625. transformers/models/grounding_dino/modeling_grounding_dino.py +95 -97
  626. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  627. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  628. transformers/models/groupvit/configuration_groupvit.py +0 -1
  629. transformers/models/groupvit/modeling_groupvit.py +75 -71
  630. transformers/models/helium/configuration_helium.py +20 -22
  631. transformers/models/helium/modeling_helium.py +34 -37
  632. transformers/models/helium/modular_helium.py +3 -7
  633. transformers/models/herbert/tokenization_herbert.py +4 -6
  634. transformers/models/hgnet_v2/configuration_hgnet_v2.py +0 -1
  635. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -9
  636. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -9
  637. transformers/models/hiera/configuration_hiera.py +0 -1
  638. transformers/models/hiera/modeling_hiera.py +60 -62
  639. transformers/models/hubert/configuration_hubert.py +0 -1
  640. transformers/models/hubert/modeling_hubert.py +39 -37
  641. transformers/models/hubert/modular_hubert.py +12 -11
  642. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +21 -24
  643. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +31 -34
  644. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +4 -6
  645. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  646. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +25 -28
  647. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +44 -39
  648. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +9 -9
  649. transformers/models/ibert/configuration_ibert.py +0 -1
  650. transformers/models/ibert/modeling_ibert.py +76 -62
  651. transformers/models/ibert/quant_modules.py +0 -1
  652. transformers/models/idefics/configuration_idefics.py +0 -1
  653. transformers/models/idefics/image_processing_idefics.py +13 -15
  654. transformers/models/idefics/modeling_idefics.py +70 -61
  655. transformers/models/idefics/perceiver.py +1 -3
  656. transformers/models/idefics/processing_idefics.py +32 -48
  657. transformers/models/idefics/vision.py +22 -24
  658. transformers/models/idefics2/configuration_idefics2.py +0 -1
  659. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  660. transformers/models/idefics2/image_processing_idefics2_fast.py +7 -8
  661. transformers/models/idefics2/modeling_idefics2.py +63 -59
  662. transformers/models/idefics2/processing_idefics2.py +10 -68
  663. transformers/models/idefics3/configuration_idefics3.py +0 -1
  664. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  665. transformers/models/idefics3/image_processing_idefics3_fast.py +11 -12
  666. transformers/models/idefics3/modeling_idefics3.py +57 -55
  667. transformers/models/idefics3/processing_idefics3.py +15 -69
  668. transformers/models/ijepa/configuration_ijepa.py +0 -1
  669. transformers/models/ijepa/modeling_ijepa.py +10 -11
  670. transformers/models/ijepa/modular_ijepa.py +5 -7
  671. transformers/models/imagegpt/configuration_imagegpt.py +0 -1
  672. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  673. transformers/models/imagegpt/image_processing_imagegpt_fast.py +9 -14
  674. transformers/models/imagegpt/modeling_imagegpt.py +66 -60
  675. transformers/models/informer/configuration_informer.py +6 -9
  676. transformers/models/informer/modeling_informer.py +84 -86
  677. transformers/models/informer/modular_informer.py +13 -16
  678. transformers/models/instructblip/configuration_instructblip.py +0 -1
  679. transformers/models/instructblip/modeling_instructblip.py +45 -44
  680. transformers/models/instructblip/processing_instructblip.py +10 -36
  681. transformers/models/instructblipvideo/configuration_instructblipvideo.py +0 -1
  682. transformers/models/instructblipvideo/modeling_instructblipvideo.py +107 -105
  683. transformers/models/instructblipvideo/modular_instructblipvideo.py +34 -36
  684. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  685. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +4 -6
  686. transformers/models/internvl/configuration_internvl.py +0 -1
  687. transformers/models/internvl/modeling_internvl.py +52 -51
  688. transformers/models/internvl/modular_internvl.py +24 -30
  689. transformers/models/internvl/processing_internvl.py +12 -45
  690. transformers/models/internvl/video_processing_internvl.py +8 -10
  691. transformers/models/jais2/__init__.py +27 -0
  692. transformers/models/jais2/configuration_jais2.py +150 -0
  693. transformers/models/jais2/modeling_jais2.py +484 -0
  694. transformers/models/jais2/modular_jais2.py +194 -0
  695. transformers/models/jamba/configuration_jamba.py +0 -1
  696. transformers/models/jamba/modeling_jamba.py +67 -65
  697. transformers/models/jamba/modular_jamba.py +54 -55
  698. transformers/models/janus/configuration_janus.py +0 -1
  699. transformers/models/janus/image_processing_janus.py +35 -37
  700. transformers/models/janus/image_processing_janus_fast.py +12 -14
  701. transformers/models/janus/modeling_janus.py +56 -50
  702. transformers/models/janus/modular_janus.py +76 -70
  703. transformers/models/janus/processing_janus.py +17 -43
  704. transformers/models/jetmoe/configuration_jetmoe.py +20 -23
  705. transformers/models/jetmoe/modeling_jetmoe.py +41 -44
  706. transformers/models/jetmoe/modular_jetmoe.py +31 -33
  707. transformers/models/kosmos2/configuration_kosmos2.py +0 -1
  708. transformers/models/kosmos2/modeling_kosmos2.py +159 -148
  709. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  710. transformers/models/kosmos2_5/__init__.py +0 -1
  711. transformers/models/kosmos2_5/configuration_kosmos2_5.py +0 -1
  712. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  713. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +4 -13
  714. transformers/models/kosmos2_5/modeling_kosmos2_5.py +118 -110
  715. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  716. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +23 -25
  717. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  718. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +67 -68
  719. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +28 -22
  720. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  721. transformers/models/lasr/configuration_lasr.py +5 -3
  722. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  723. transformers/models/lasr/modeling_lasr.py +21 -23
  724. transformers/models/lasr/modular_lasr.py +16 -11
  725. transformers/models/lasr/processing_lasr.py +12 -8
  726. transformers/models/lasr/tokenization_lasr.py +2 -4
  727. transformers/models/layoutlm/configuration_layoutlm.py +0 -1
  728. transformers/models/layoutlm/modeling_layoutlm.py +72 -72
  729. transformers/models/layoutlmv2/configuration_layoutlmv2.py +0 -1
  730. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  731. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +5 -7
  732. transformers/models/layoutlmv2/modeling_layoutlmv2.py +60 -50
  733. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  734. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +64 -74
  735. transformers/models/layoutlmv3/configuration_layoutlmv3.py +0 -1
  736. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  737. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +7 -9
  738. transformers/models/layoutlmv3/modeling_layoutlmv3.py +78 -56
  739. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  740. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  741. transformers/models/layoutxlm/configuration_layoutxlm.py +0 -1
  742. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  743. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  744. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  745. transformers/models/led/configuration_led.py +1 -4
  746. transformers/models/led/modeling_led.py +119 -267
  747. transformers/models/levit/configuration_levit.py +0 -1
  748. transformers/models/levit/image_processing_levit.py +19 -21
  749. transformers/models/levit/image_processing_levit_fast.py +0 -1
  750. transformers/models/levit/modeling_levit.py +35 -19
  751. transformers/models/lfm2/configuration_lfm2.py +22 -23
  752. transformers/models/lfm2/modeling_lfm2.py +43 -45
  753. transformers/models/lfm2/modular_lfm2.py +29 -29
  754. transformers/models/lfm2_moe/__init__.py +0 -1
  755. transformers/models/lfm2_moe/configuration_lfm2_moe.py +1 -2
  756. transformers/models/lfm2_moe/modeling_lfm2_moe.py +58 -49
  757. transformers/models/lfm2_moe/modular_lfm2_moe.py +13 -37
  758. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  759. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +34 -5
  760. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -38
  761. transformers/models/lfm2_vl/modular_lfm2_vl.py +28 -29
  762. transformers/models/lfm2_vl/processing_lfm2_vl.py +96 -76
  763. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  764. transformers/models/lightglue/image_processing_lightglue_fast.py +5 -6
  765. transformers/models/lightglue/modeling_lightglue.py +28 -30
  766. transformers/models/lightglue/modular_lightglue.py +28 -28
  767. transformers/models/lighton_ocr/__init__.py +28 -0
  768. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  769. transformers/models/lighton_ocr/modeling_lighton_ocr.py +460 -0
  770. transformers/models/lighton_ocr/modular_lighton_ocr.py +403 -0
  771. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  772. transformers/models/lilt/configuration_lilt.py +0 -1
  773. transformers/models/lilt/modeling_lilt.py +72 -70
  774. transformers/models/llama/configuration_llama.py +21 -24
  775. transformers/models/llama/modeling_llama.py +32 -35
  776. transformers/models/llama/tokenization_llama.py +2 -4
  777. transformers/models/llama4/configuration_llama4.py +20 -22
  778. transformers/models/llama4/image_processing_llama4_fast.py +9 -11
  779. transformers/models/llama4/modeling_llama4.py +78 -75
  780. transformers/models/llama4/processing_llama4.py +33 -57
  781. transformers/models/llava/configuration_llava.py +0 -1
  782. transformers/models/llava/image_processing_llava.py +25 -28
  783. transformers/models/llava/image_processing_llava_fast.py +6 -8
  784. transformers/models/llava/modeling_llava.py +47 -44
  785. transformers/models/llava/processing_llava.py +18 -51
  786. transformers/models/llava_next/configuration_llava_next.py +0 -1
  787. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  788. transformers/models/llava_next/image_processing_llava_next_fast.py +5 -7
  789. transformers/models/llava_next/modeling_llava_next.py +49 -47
  790. transformers/models/llava_next/processing_llava_next.py +18 -47
  791. transformers/models/llava_next_video/configuration_llava_next_video.py +0 -1
  792. transformers/models/llava_next_video/modeling_llava_next_video.py +60 -58
  793. transformers/models/llava_next_video/modular_llava_next_video.py +51 -49
  794. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  795. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  796. transformers/models/llava_onevision/configuration_llava_onevision.py +0 -1
  797. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  798. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +6 -8
  799. transformers/models/llava_onevision/modeling_llava_onevision.py +67 -65
  800. transformers/models/llava_onevision/modular_llava_onevision.py +58 -56
  801. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  802. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  803. transformers/models/longcat_flash/__init__.py +0 -1
  804. transformers/models/longcat_flash/configuration_longcat_flash.py +32 -35
  805. transformers/models/longcat_flash/modeling_longcat_flash.py +32 -32
  806. transformers/models/longcat_flash/modular_longcat_flash.py +18 -19
  807. transformers/models/longformer/configuration_longformer.py +1 -4
  808. transformers/models/longformer/modeling_longformer.py +99 -101
  809. transformers/models/longt5/configuration_longt5.py +0 -1
  810. transformers/models/longt5/modeling_longt5.py +43 -48
  811. transformers/models/luke/configuration_luke.py +0 -1
  812. transformers/models/luke/modeling_luke.py +179 -181
  813. transformers/models/luke/tokenization_luke.py +99 -105
  814. transformers/models/lw_detr/__init__.py +27 -0
  815. transformers/models/lw_detr/configuration_lw_detr.py +374 -0
  816. transformers/models/lw_detr/modeling_lw_detr.py +1698 -0
  817. transformers/models/lw_detr/modular_lw_detr.py +1611 -0
  818. transformers/models/lxmert/configuration_lxmert.py +0 -1
  819. transformers/models/lxmert/modeling_lxmert.py +63 -74
  820. transformers/models/m2m_100/configuration_m2m_100.py +0 -1
  821. transformers/models/m2m_100/modeling_m2m_100.py +79 -71
  822. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  823. transformers/models/mamba/configuration_mamba.py +0 -1
  824. transformers/models/mamba/modeling_mamba.py +44 -44
  825. transformers/models/mamba2/configuration_mamba2.py +0 -1
  826. transformers/models/mamba2/modeling_mamba2.py +67 -68
  827. transformers/models/marian/configuration_marian.py +1 -2
  828. transformers/models/marian/modeling_marian.py +87 -86
  829. transformers/models/marian/tokenization_marian.py +6 -6
  830. transformers/models/markuplm/configuration_markuplm.py +0 -1
  831. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  832. transformers/models/markuplm/modeling_markuplm.py +65 -70
  833. transformers/models/markuplm/processing_markuplm.py +31 -38
  834. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  835. transformers/models/mask2former/configuration_mask2former.py +5 -8
  836. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  837. transformers/models/mask2former/image_processing_mask2former_fast.py +30 -33
  838. transformers/models/mask2former/modeling_mask2former.py +99 -92
  839. transformers/models/mask2former/modular_mask2former.py +6 -8
  840. transformers/models/maskformer/configuration_maskformer.py +6 -9
  841. transformers/models/maskformer/configuration_maskformer_swin.py +0 -1
  842. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  843. transformers/models/maskformer/image_processing_maskformer_fast.py +29 -33
  844. transformers/models/maskformer/modeling_maskformer.py +65 -59
  845. transformers/models/maskformer/modeling_maskformer_swin.py +34 -32
  846. transformers/models/mbart/configuration_mbart.py +1 -1
  847. transformers/models/mbart/modeling_mbart.py +118 -113
  848. transformers/models/mbart/tokenization_mbart.py +2 -4
  849. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  850. transformers/models/megatron_bert/configuration_megatron_bert.py +0 -1
  851. transformers/models/megatron_bert/modeling_megatron_bert.py +141 -150
  852. transformers/models/metaclip_2/modeling_metaclip_2.py +48 -46
  853. transformers/models/metaclip_2/modular_metaclip_2.py +21 -21
  854. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  855. transformers/models/mgp_str/modeling_mgp_str.py +14 -16
  856. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  857. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  858. transformers/models/mimi/configuration_mimi.py +38 -40
  859. transformers/models/mimi/modeling_mimi.py +100 -82
  860. transformers/models/minimax/__init__.py +0 -1
  861. transformers/models/minimax/configuration_minimax.py +32 -36
  862. transformers/models/minimax/modeling_minimax.py +57 -47
  863. transformers/models/minimax/modular_minimax.py +62 -54
  864. transformers/models/minimax_m2/__init__.py +28 -0
  865. transformers/models/minimax_m2/configuration_minimax_m2.py +211 -0
  866. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  867. transformers/models/minimax_m2/modular_minimax_m2.py +369 -0
  868. transformers/models/ministral/configuration_ministral.py +20 -22
  869. transformers/models/ministral/modeling_ministral.py +32 -34
  870. transformers/models/ministral/modular_ministral.py +27 -29
  871. transformers/models/ministral3/configuration_ministral3.py +19 -22
  872. transformers/models/ministral3/modeling_ministral3.py +32 -34
  873. transformers/models/ministral3/modular_ministral3.py +4 -5
  874. transformers/models/mistral/configuration_mistral.py +19 -22
  875. transformers/models/mistral/modeling_mistral.py +32 -34
  876. transformers/models/mistral/modular_mistral.py +11 -12
  877. transformers/models/mistral3/configuration_mistral3.py +0 -1
  878. transformers/models/mistral3/modeling_mistral3.py +53 -46
  879. transformers/models/mistral3/modular_mistral3.py +38 -36
  880. transformers/models/mixtral/configuration_mixtral.py +24 -27
  881. transformers/models/mixtral/modeling_mixtral.py +47 -42
  882. transformers/models/mixtral/modular_mixtral.py +32 -31
  883. transformers/models/mlcd/configuration_mlcd.py +0 -1
  884. transformers/models/mlcd/modeling_mlcd.py +16 -12
  885. transformers/models/mlcd/modular_mlcd.py +13 -11
  886. transformers/models/mllama/configuration_mllama.py +5 -8
  887. transformers/models/mllama/image_processing_mllama.py +23 -25
  888. transformers/models/mllama/image_processing_mllama_fast.py +5 -6
  889. transformers/models/mllama/modeling_mllama.py +94 -86
  890. transformers/models/mllama/processing_mllama.py +6 -55
  891. transformers/models/mluke/tokenization_mluke.py +97 -103
  892. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -3
  893. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +95 -97
  894. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -3
  895. transformers/models/mobilebert/configuration_mobilebert.py +0 -1
  896. transformers/models/mobilebert/modeling_mobilebert.py +77 -85
  897. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  898. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  899. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  900. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  901. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  902. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  903. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  904. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +10 -12
  905. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +17 -20
  906. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  907. transformers/models/mobilevit/image_processing_mobilevit.py +46 -49
  908. transformers/models/mobilevit/image_processing_mobilevit_fast.py +9 -11
  909. transformers/models/mobilevit/modeling_mobilevit.py +21 -19
  910. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  911. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -20
  912. transformers/models/modernbert/configuration_modernbert.py +34 -34
  913. transformers/models/modernbert/modeling_modernbert.py +135 -126
  914. transformers/models/modernbert/modular_modernbert.py +167 -156
  915. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +30 -32
  916. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -48
  917. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +78 -71
  918. transformers/models/moonshine/configuration_moonshine.py +22 -24
  919. transformers/models/moonshine/modeling_moonshine.py +64 -66
  920. transformers/models/moonshine/modular_moonshine.py +72 -73
  921. transformers/models/moshi/configuration_moshi.py +18 -21
  922. transformers/models/moshi/modeling_moshi.py +150 -183
  923. transformers/models/mpnet/configuration_mpnet.py +0 -1
  924. transformers/models/mpnet/modeling_mpnet.py +57 -57
  925. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  926. transformers/models/mpt/configuration_mpt.py +1 -9
  927. transformers/models/mpt/modeling_mpt.py +58 -60
  928. transformers/models/mra/configuration_mra.py +0 -1
  929. transformers/models/mra/modeling_mra.py +58 -57
  930. transformers/models/mt5/configuration_mt5.py +2 -4
  931. transformers/models/mt5/modeling_mt5.py +75 -87
  932. transformers/models/musicgen/configuration_musicgen.py +0 -1
  933. transformers/models/musicgen/modeling_musicgen.py +113 -120
  934. transformers/models/musicgen/processing_musicgen.py +3 -21
  935. transformers/models/musicgen_melody/configuration_musicgen_melody.py +0 -1
  936. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  937. transformers/models/musicgen_melody/modeling_musicgen_melody.py +110 -109
  938. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  939. transformers/models/mvp/configuration_mvp.py +0 -1
  940. transformers/models/mvp/modeling_mvp.py +122 -119
  941. transformers/models/myt5/tokenization_myt5.py +8 -10
  942. transformers/models/nanochat/configuration_nanochat.py +0 -1
  943. transformers/models/nanochat/modeling_nanochat.py +33 -36
  944. transformers/models/nanochat/modular_nanochat.py +12 -14
  945. transformers/models/nemotron/configuration_nemotron.py +20 -23
  946. transformers/models/nemotron/modeling_nemotron.py +51 -54
  947. transformers/models/nllb/tokenization_nllb.py +7 -9
  948. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -1
  949. transformers/models/nllb_moe/modeling_nllb_moe.py +77 -69
  950. transformers/models/nougat/image_processing_nougat.py +29 -32
  951. transformers/models/nougat/image_processing_nougat_fast.py +4 -6
  952. transformers/models/nougat/processing_nougat.py +37 -39
  953. transformers/models/nougat/tokenization_nougat.py +16 -23
  954. transformers/models/nystromformer/configuration_nystromformer.py +0 -1
  955. transformers/models/nystromformer/modeling_nystromformer.py +68 -63
  956. transformers/models/olmo/configuration_olmo.py +18 -21
  957. transformers/models/olmo/modeling_olmo.py +32 -35
  958. transformers/models/olmo/modular_olmo.py +5 -9
  959. transformers/models/olmo2/configuration_olmo2.py +18 -21
  960. transformers/models/olmo2/modeling_olmo2.py +33 -36
  961. transformers/models/olmo2/modular_olmo2.py +29 -31
  962. transformers/models/olmo3/__init__.py +0 -1
  963. transformers/models/olmo3/configuration_olmo3.py +20 -23
  964. transformers/models/olmo3/modeling_olmo3.py +32 -35
  965. transformers/models/olmo3/modular_olmo3.py +31 -33
  966. transformers/models/olmoe/configuration_olmoe.py +24 -26
  967. transformers/models/olmoe/modeling_olmoe.py +49 -43
  968. transformers/models/olmoe/modular_olmoe.py +16 -15
  969. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -3
  970. transformers/models/omdet_turbo/modeling_omdet_turbo.py +42 -40
  971. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  972. transformers/models/oneformer/configuration_oneformer.py +5 -8
  973. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  974. transformers/models/oneformer/image_processing_oneformer_fast.py +33 -34
  975. transformers/models/oneformer/modeling_oneformer.py +130 -162
  976. transformers/models/oneformer/processing_oneformer.py +28 -43
  977. transformers/models/openai/configuration_openai.py +0 -1
  978. transformers/models/openai/modeling_openai.py +62 -51
  979. transformers/models/openai/tokenization_openai.py +2 -5
  980. transformers/models/opt/configuration_opt.py +0 -1
  981. transformers/models/opt/modeling_opt.py +74 -75
  982. transformers/models/ovis2/__init__.py +0 -1
  983. transformers/models/ovis2/configuration_ovis2.py +0 -1
  984. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  985. transformers/models/ovis2/image_processing_ovis2_fast.py +6 -8
  986. transformers/models/ovis2/modeling_ovis2.py +58 -48
  987. transformers/models/ovis2/modular_ovis2.py +38 -32
  988. transformers/models/ovis2/processing_ovis2.py +12 -40
  989. transformers/models/owlv2/configuration_owlv2.py +0 -1
  990. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  991. transformers/models/owlv2/image_processing_owlv2_fast.py +7 -10
  992. transformers/models/owlv2/modeling_owlv2.py +89 -90
  993. transformers/models/owlv2/modular_owlv2.py +6 -9
  994. transformers/models/owlv2/processing_owlv2.py +20 -49
  995. transformers/models/owlvit/configuration_owlvit.py +0 -1
  996. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  997. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  998. transformers/models/owlvit/modeling_owlvit.py +88 -89
  999. transformers/models/owlvit/processing_owlvit.py +20 -48
  1000. transformers/models/paddleocr_vl/__init__.py +0 -1
  1001. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +19 -19
  1002. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +37 -37
  1003. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1004. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +104 -90
  1005. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +90 -80
  1006. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1007. transformers/models/paligemma/configuration_paligemma.py +0 -1
  1008. transformers/models/paligemma/modeling_paligemma.py +73 -67
  1009. transformers/models/paligemma/processing_paligemma.py +13 -66
  1010. transformers/models/parakeet/configuration_parakeet.py +1 -4
  1011. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1012. transformers/models/parakeet/modeling_parakeet.py +23 -22
  1013. transformers/models/parakeet/modular_parakeet.py +21 -18
  1014. transformers/models/parakeet/processing_parakeet.py +12 -5
  1015. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +5 -7
  1016. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1017. transformers/models/patchtsmixer/modeling_patchtsmixer.py +64 -62
  1018. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1019. transformers/models/patchtst/modeling_patchtst.py +77 -78
  1020. transformers/models/pe_audio/__init__.py +29 -0
  1021. transformers/models/pe_audio/configuration_pe_audio.py +204 -0
  1022. transformers/models/pe_audio/feature_extraction_pe_audio.py +160 -0
  1023. transformers/models/pe_audio/modeling_pe_audio.py +819 -0
  1024. transformers/models/pe_audio/modular_pe_audio.py +298 -0
  1025. transformers/models/pe_audio/processing_pe_audio.py +23 -0
  1026. transformers/models/pe_audio_video/__init__.py +28 -0
  1027. transformers/models/pe_audio_video/configuration_pe_audio_video.py +223 -0
  1028. transformers/models/pe_audio_video/modeling_pe_audio_video.py +971 -0
  1029. transformers/models/pe_audio_video/modular_pe_audio_video.py +763 -0
  1030. transformers/models/pe_audio_video/processing_pe_audio_video.py +24 -0
  1031. transformers/models/pe_video/__init__.py +29 -0
  1032. transformers/models/pe_video/configuration_pe_video.py +209 -0
  1033. transformers/models/pe_video/modeling_pe_video.py +635 -0
  1034. transformers/models/pe_video/modular_pe_video.py +218 -0
  1035. transformers/models/pe_video/processing_pe_video.py +10 -0
  1036. transformers/models/pe_video/video_processing_pe_video.py +64 -0
  1037. transformers/models/pegasus/configuration_pegasus.py +1 -1
  1038. transformers/models/pegasus/modeling_pegasus.py +66 -65
  1039. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1040. transformers/models/pegasus_x/configuration_pegasus_x.py +0 -1
  1041. transformers/models/pegasus_x/modeling_pegasus_x.py +51 -52
  1042. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1043. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1044. transformers/models/perceiver/image_processing_perceiver_fast.py +5 -7
  1045. transformers/models/perceiver/modeling_perceiver.py +140 -137
  1046. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1047. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1048. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -10
  1049. transformers/models/perception_lm/modeling_perception_lm.py +45 -43
  1050. transformers/models/perception_lm/modular_perception_lm.py +38 -36
  1051. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1052. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1053. transformers/models/persimmon/configuration_persimmon.py +18 -21
  1054. transformers/models/persimmon/modeling_persimmon.py +40 -43
  1055. transformers/models/phi/configuration_phi.py +19 -22
  1056. transformers/models/phi/modeling_phi.py +36 -38
  1057. transformers/models/phi/modular_phi.py +23 -23
  1058. transformers/models/phi3/configuration_phi3.py +23 -26
  1059. transformers/models/phi3/modeling_phi3.py +34 -37
  1060. transformers/models/phi3/modular_phi3.py +13 -17
  1061. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +25 -26
  1062. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1063. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +7 -7
  1064. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +58 -57
  1065. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +62 -60
  1066. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -44
  1067. transformers/models/phimoe/configuration_phimoe.py +26 -29
  1068. transformers/models/phimoe/modeling_phimoe.py +47 -42
  1069. transformers/models/phimoe/modular_phimoe.py +1 -2
  1070. transformers/models/phobert/tokenization_phobert.py +4 -6
  1071. transformers/models/pix2struct/configuration_pix2struct.py +0 -1
  1072. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1073. transformers/models/pix2struct/image_processing_pix2struct_fast.py +7 -10
  1074. transformers/models/pix2struct/modeling_pix2struct.py +42 -45
  1075. transformers/models/pix2struct/processing_pix2struct.py +5 -30
  1076. transformers/models/pixio/__init__.py +29 -0
  1077. transformers/models/pixio/configuration_pixio.py +150 -0
  1078. transformers/models/pixio/modeling_pixio.py +505 -0
  1079. transformers/models/pixio/modular_pixio.py +401 -0
  1080. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1081. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1082. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -6
  1083. transformers/models/pixtral/modeling_pixtral.py +23 -26
  1084. transformers/models/pixtral/processing_pixtral.py +21 -53
  1085. transformers/models/plbart/configuration_plbart.py +1 -1
  1086. transformers/models/plbart/modeling_plbart.py +107 -102
  1087. transformers/models/plbart/modular_plbart.py +36 -32
  1088. transformers/models/plbart/tokenization_plbart.py +4 -5
  1089. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1090. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1091. transformers/models/poolformer/image_processing_poolformer_fast.py +6 -8
  1092. transformers/models/poolformer/modeling_poolformer.py +21 -13
  1093. transformers/models/pop2piano/configuration_pop2piano.py +0 -2
  1094. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1095. transformers/models/pop2piano/modeling_pop2piano.py +22 -23
  1096. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1097. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1098. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +3 -3
  1099. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1100. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +14 -15
  1101. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +9 -10
  1102. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +9 -10
  1103. transformers/models/prophetnet/configuration_prophetnet.py +26 -28
  1104. transformers/models/prophetnet/modeling_prophetnet.py +111 -131
  1105. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1106. transformers/models/pvt/configuration_pvt.py +0 -1
  1107. transformers/models/pvt/image_processing_pvt.py +17 -20
  1108. transformers/models/pvt/image_processing_pvt_fast.py +0 -1
  1109. transformers/models/pvt/modeling_pvt.py +19 -21
  1110. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  1111. transformers/models/pvt_v2/modeling_pvt_v2.py +21 -23
  1112. transformers/models/qwen2/configuration_qwen2.py +18 -21
  1113. transformers/models/qwen2/modeling_qwen2.py +32 -34
  1114. transformers/models/qwen2/modular_qwen2.py +11 -12
  1115. transformers/models/qwen2/tokenization_qwen2.py +2 -5
  1116. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +20 -23
  1117. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +239 -192
  1118. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +174 -127
  1119. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1120. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +22 -25
  1121. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +112 -101
  1122. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +72 -107
  1123. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1124. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1125. transformers/models/qwen2_audio/modeling_qwen2_audio.py +29 -31
  1126. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1127. transformers/models/qwen2_moe/configuration_qwen2_moe.py +28 -31
  1128. transformers/models/qwen2_moe/modeling_qwen2_moe.py +48 -43
  1129. transformers/models/qwen2_moe/modular_qwen2_moe.py +7 -10
  1130. transformers/models/qwen2_vl/configuration_qwen2_vl.py +22 -24
  1131. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +41 -42
  1132. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +8 -9
  1133. transformers/models/qwen2_vl/modeling_qwen2_vl.py +108 -96
  1134. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1135. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +35 -13
  1136. transformers/models/qwen3/configuration_qwen3.py +20 -23
  1137. transformers/models/qwen3/modeling_qwen3.py +32 -35
  1138. transformers/models/qwen3/modular_qwen3.py +4 -6
  1139. transformers/models/qwen3_moe/configuration_qwen3_moe.py +25 -28
  1140. transformers/models/qwen3_moe/modeling_qwen3_moe.py +48 -43
  1141. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1142. transformers/models/qwen3_next/configuration_qwen3_next.py +31 -34
  1143. transformers/models/qwen3_next/modeling_qwen3_next.py +43 -48
  1144. transformers/models/qwen3_next/modular_qwen3_next.py +33 -34
  1145. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +89 -88
  1146. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +199 -156
  1147. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +170 -152
  1148. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1149. transformers/models/qwen3_vl/configuration_qwen3_vl.py +21 -24
  1150. transformers/models/qwen3_vl/modeling_qwen3_vl.py +91 -81
  1151. transformers/models/qwen3_vl/modular_qwen3_vl.py +86 -112
  1152. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1153. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1154. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +21 -25
  1155. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +174 -195
  1156. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +65 -117
  1157. transformers/models/rag/configuration_rag.py +0 -9
  1158. transformers/models/rag/modeling_rag.py +123 -127
  1159. transformers/models/rag/retrieval_rag.py +2 -4
  1160. transformers/models/rag/tokenization_rag.py +0 -50
  1161. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +21 -24
  1162. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +34 -36
  1163. transformers/models/reformer/configuration_reformer.py +0 -1
  1164. transformers/models/reformer/modeling_reformer.py +76 -69
  1165. transformers/models/reformer/tokenization_reformer.py +3 -6
  1166. transformers/models/regnet/configuration_regnet.py +0 -1
  1167. transformers/models/regnet/modeling_regnet.py +11 -9
  1168. transformers/models/rembert/configuration_rembert.py +0 -1
  1169. transformers/models/rembert/modeling_rembert.py +115 -111
  1170. transformers/models/rembert/tokenization_rembert.py +1 -4
  1171. transformers/models/resnet/configuration_resnet.py +0 -1
  1172. transformers/models/resnet/modeling_resnet.py +16 -13
  1173. transformers/models/roberta/configuration_roberta.py +0 -1
  1174. transformers/models/roberta/modeling_roberta.py +94 -93
  1175. transformers/models/roberta/modular_roberta.py +58 -58
  1176. transformers/models/roberta/tokenization_roberta.py +2 -5
  1177. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1178. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +0 -1
  1179. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +94 -93
  1180. transformers/models/roc_bert/configuration_roc_bert.py +0 -1
  1181. transformers/models/roc_bert/modeling_roc_bert.py +122 -121
  1182. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1183. transformers/models/roformer/configuration_roformer.py +0 -1
  1184. transformers/models/roformer/modeling_roformer.py +79 -81
  1185. transformers/models/roformer/tokenization_roformer.py +3 -6
  1186. transformers/models/roformer/tokenization_utils.py +0 -1
  1187. transformers/models/rt_detr/configuration_rt_detr.py +1 -2
  1188. transformers/models/rt_detr/configuration_rt_detr_resnet.py +0 -1
  1189. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1190. transformers/models/rt_detr/image_processing_rt_detr_fast.py +15 -15
  1191. transformers/models/rt_detr/modeling_rt_detr.py +84 -82
  1192. transformers/models/rt_detr/modeling_rt_detr_resnet.py +10 -7
  1193. transformers/models/rt_detr/modular_rt_detr.py +14 -14
  1194. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -4
  1195. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +86 -81
  1196. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +10 -7
  1197. transformers/models/rwkv/configuration_rwkv.py +0 -1
  1198. transformers/models/rwkv/modeling_rwkv.py +30 -32
  1199. transformers/models/sam/configuration_sam.py +1 -1
  1200. transformers/models/sam/image_processing_sam.py +59 -60
  1201. transformers/models/sam/image_processing_sam_fast.py +21 -23
  1202. transformers/models/sam/modeling_sam.py +37 -36
  1203. transformers/models/sam/processing_sam.py +39 -27
  1204. transformers/models/sam2/configuration_sam2.py +1 -2
  1205. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1206. transformers/models/sam2/modeling_sam2.py +50 -48
  1207. transformers/models/sam2/modular_sam2.py +48 -45
  1208. transformers/models/sam2/processing_sam2.py +31 -47
  1209. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1210. transformers/models/sam2_video/modeling_sam2_video.py +119 -112
  1211. transformers/models/sam2_video/modular_sam2_video.py +91 -97
  1212. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1213. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1214. transformers/models/sam3/configuration_sam3.py +21 -2
  1215. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1216. transformers/models/sam3/modeling_sam3.py +77 -56
  1217. transformers/models/sam3/modular_sam3.py +3 -8
  1218. transformers/models/sam3/processing_sam3.py +29 -48
  1219. transformers/models/sam3_tracker/__init__.py +0 -1
  1220. transformers/models/sam3_tracker/configuration_sam3_tracker.py +0 -1
  1221. transformers/models/sam3_tracker/modeling_sam3_tracker.py +36 -36
  1222. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -1
  1223. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -47
  1224. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1225. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -1
  1226. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +96 -85
  1227. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +27 -6
  1228. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1229. transformers/models/sam3_video/configuration_sam3_video.py +14 -1
  1230. transformers/models/sam3_video/modeling_sam3_video.py +32 -34
  1231. transformers/models/sam3_video/processing_sam3_video.py +26 -46
  1232. transformers/models/sam_hq/__init__.py +1 -1
  1233. transformers/models/sam_hq/configuration_sam_hq.py +1 -1
  1234. transformers/models/sam_hq/modeling_sam_hq.py +65 -64
  1235. transformers/models/sam_hq/modular_sam_hq.py +17 -19
  1236. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +39 -28
  1237. transformers/models/seamless_m4t/configuration_seamless_m4t.py +0 -1
  1238. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1239. transformers/models/seamless_m4t/modeling_seamless_m4t.py +207 -193
  1240. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1241. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1242. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +0 -1
  1243. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +199 -195
  1244. transformers/models/seed_oss/configuration_seed_oss.py +23 -25
  1245. transformers/models/seed_oss/modeling_seed_oss.py +31 -33
  1246. transformers/models/seed_oss/modular_seed_oss.py +3 -4
  1247. transformers/models/segformer/configuration_segformer.py +0 -10
  1248. transformers/models/segformer/image_processing_segformer.py +39 -42
  1249. transformers/models/segformer/image_processing_segformer_fast.py +7 -9
  1250. transformers/models/segformer/modeling_segformer.py +26 -28
  1251. transformers/models/segformer/modular_segformer.py +5 -7
  1252. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1253. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1254. transformers/models/seggpt/modeling_seggpt.py +28 -30
  1255. transformers/models/sew/configuration_sew.py +0 -1
  1256. transformers/models/sew/modeling_sew.py +33 -35
  1257. transformers/models/sew/modular_sew.py +10 -12
  1258. transformers/models/sew_d/configuration_sew_d.py +0 -1
  1259. transformers/models/sew_d/modeling_sew_d.py +28 -30
  1260. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1261. transformers/models/shieldgemma2/modeling_shieldgemma2.py +16 -17
  1262. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1263. transformers/models/siglip/configuration_siglip.py +0 -1
  1264. transformers/models/siglip/image_processing_siglip.py +17 -20
  1265. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1266. transformers/models/siglip/modeling_siglip.py +62 -41
  1267. transformers/models/siglip/processing_siglip.py +2 -14
  1268. transformers/models/siglip/tokenization_siglip.py +6 -7
  1269. transformers/models/siglip2/configuration_siglip2.py +1 -1
  1270. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1271. transformers/models/siglip2/image_processing_siglip2_fast.py +4 -5
  1272. transformers/models/siglip2/modeling_siglip2.py +114 -92
  1273. transformers/models/siglip2/modular_siglip2.py +23 -25
  1274. transformers/models/siglip2/processing_siglip2.py +2 -14
  1275. transformers/models/smollm3/configuration_smollm3.py +23 -26
  1276. transformers/models/smollm3/modeling_smollm3.py +32 -35
  1277. transformers/models/smollm3/modular_smollm3.py +27 -29
  1278. transformers/models/smolvlm/configuration_smolvlm.py +1 -1
  1279. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1280. transformers/models/smolvlm/image_processing_smolvlm_fast.py +12 -12
  1281. transformers/models/smolvlm/modeling_smolvlm.py +56 -53
  1282. transformers/models/smolvlm/modular_smolvlm.py +15 -17
  1283. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1284. transformers/models/smolvlm/video_processing_smolvlm.py +7 -9
  1285. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1286. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +20 -23
  1287. transformers/models/speech_to_text/configuration_speech_to_text.py +0 -1
  1288. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1289. transformers/models/speech_to_text/modeling_speech_to_text.py +62 -54
  1290. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1291. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1292. transformers/models/speecht5/configuration_speecht5.py +0 -1
  1293. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1294. transformers/models/speecht5/modeling_speecht5.py +200 -174
  1295. transformers/models/speecht5/number_normalizer.py +0 -1
  1296. transformers/models/speecht5/processing_speecht5.py +3 -37
  1297. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1298. transformers/models/splinter/configuration_splinter.py +0 -1
  1299. transformers/models/splinter/modeling_splinter.py +63 -59
  1300. transformers/models/splinter/tokenization_splinter.py +2 -4
  1301. transformers/models/squeezebert/configuration_squeezebert.py +0 -1
  1302. transformers/models/squeezebert/modeling_squeezebert.py +62 -62
  1303. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1304. transformers/models/stablelm/configuration_stablelm.py +20 -23
  1305. transformers/models/stablelm/modeling_stablelm.py +40 -43
  1306. transformers/models/starcoder2/configuration_starcoder2.py +19 -22
  1307. transformers/models/starcoder2/modeling_starcoder2.py +34 -37
  1308. transformers/models/starcoder2/modular_starcoder2.py +13 -15
  1309. transformers/models/superglue/configuration_superglue.py +3 -3
  1310. transformers/models/superglue/image_processing_superglue.py +15 -15
  1311. transformers/models/superglue/image_processing_superglue_fast.py +5 -7
  1312. transformers/models/superglue/modeling_superglue.py +32 -33
  1313. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1314. transformers/models/superpoint/image_processing_superpoint_fast.py +5 -7
  1315. transformers/models/superpoint/modeling_superpoint.py +13 -14
  1316. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1317. transformers/models/swiftformer/modeling_swiftformer.py +16 -14
  1318. transformers/models/swin/configuration_swin.py +0 -1
  1319. transformers/models/swin/modeling_swin.py +74 -82
  1320. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1321. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1322. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -6
  1323. transformers/models/swin2sr/modeling_swin2sr.py +75 -61
  1324. transformers/models/swinv2/configuration_swinv2.py +0 -1
  1325. transformers/models/swinv2/modeling_swinv2.py +96 -100
  1326. transformers/models/switch_transformers/configuration_switch_transformers.py +0 -1
  1327. transformers/models/switch_transformers/modeling_switch_transformers.py +34 -41
  1328. transformers/models/switch_transformers/modular_switch_transformers.py +31 -38
  1329. transformers/models/t5/configuration_t5.py +7 -2
  1330. transformers/models/t5/modeling_t5.py +76 -84
  1331. transformers/models/t5/tokenization_t5.py +1 -3
  1332. transformers/models/t5gemma/configuration_t5gemma.py +33 -34
  1333. transformers/models/t5gemma/modeling_t5gemma.py +97 -100
  1334. transformers/models/t5gemma/modular_t5gemma.py +117 -118
  1335. transformers/models/t5gemma2/configuration_t5gemma2.py +59 -96
  1336. transformers/models/t5gemma2/modeling_t5gemma2.py +109 -103
  1337. transformers/models/t5gemma2/modular_t5gemma2.py +375 -91
  1338. transformers/models/table_transformer/configuration_table_transformer.py +1 -2
  1339. transformers/models/table_transformer/modeling_table_transformer.py +47 -49
  1340. transformers/models/tapas/configuration_tapas.py +0 -1
  1341. transformers/models/tapas/modeling_tapas.py +64 -66
  1342. transformers/models/tapas/tokenization_tapas.py +115 -153
  1343. transformers/models/textnet/configuration_textnet.py +0 -1
  1344. transformers/models/textnet/image_processing_textnet.py +22 -25
  1345. transformers/models/textnet/image_processing_textnet_fast.py +5 -7
  1346. transformers/models/textnet/modeling_textnet.py +13 -14
  1347. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1348. transformers/models/time_series_transformer/modeling_time_series_transformer.py +79 -81
  1349. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1350. transformers/models/timesfm/modeling_timesfm.py +29 -19
  1351. transformers/models/timesfm/modular_timesfm.py +28 -18
  1352. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1353. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1354. transformers/models/timm_backbone/configuration_timm_backbone.py +0 -1
  1355. transformers/models/timm_backbone/modeling_timm_backbone.py +17 -15
  1356. transformers/models/timm_wrapper/configuration_timm_wrapper.py +5 -3
  1357. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1358. transformers/models/timm_wrapper/modeling_timm_wrapper.py +32 -28
  1359. transformers/models/trocr/configuration_trocr.py +0 -1
  1360. transformers/models/trocr/modeling_trocr.py +39 -42
  1361. transformers/models/trocr/processing_trocr.py +5 -25
  1362. transformers/models/tvp/configuration_tvp.py +5 -2
  1363. transformers/models/tvp/image_processing_tvp.py +50 -52
  1364. transformers/models/tvp/image_processing_tvp_fast.py +9 -10
  1365. transformers/models/tvp/modeling_tvp.py +25 -27
  1366. transformers/models/tvp/processing_tvp.py +2 -14
  1367. transformers/models/udop/configuration_udop.py +1 -1
  1368. transformers/models/udop/modeling_udop.py +63 -70
  1369. transformers/models/udop/processing_udop.py +7 -26
  1370. transformers/models/udop/tokenization_udop.py +80 -93
  1371. transformers/models/umt5/configuration_umt5.py +2 -3
  1372. transformers/models/umt5/modeling_umt5.py +80 -87
  1373. transformers/models/unispeech/configuration_unispeech.py +0 -1
  1374. transformers/models/unispeech/modeling_unispeech.py +47 -49
  1375. transformers/models/unispeech/modular_unispeech.py +20 -22
  1376. transformers/models/unispeech_sat/configuration_unispeech_sat.py +0 -1
  1377. transformers/models/unispeech_sat/modeling_unispeech_sat.py +63 -65
  1378. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1379. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1380. transformers/models/univnet/modeling_univnet.py +7 -8
  1381. transformers/models/upernet/configuration_upernet.py +0 -1
  1382. transformers/models/upernet/modeling_upernet.py +10 -13
  1383. transformers/models/vaultgemma/__init__.py +0 -1
  1384. transformers/models/vaultgemma/configuration_vaultgemma.py +24 -26
  1385. transformers/models/vaultgemma/modeling_vaultgemma.py +35 -37
  1386. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1387. transformers/models/video_llama_3/image_processing_video_llama_3.py +43 -42
  1388. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +8 -8
  1389. transformers/models/video_llama_3/modeling_video_llama_3.py +77 -66
  1390. transformers/models/video_llama_3/modular_video_llama_3.py +110 -112
  1391. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1392. transformers/models/video_llama_3/video_processing_video_llama_3.py +18 -18
  1393. transformers/models/video_llava/configuration_video_llava.py +0 -1
  1394. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1395. transformers/models/video_llava/modeling_video_llava.py +59 -57
  1396. transformers/models/video_llava/processing_video_llava.py +38 -78
  1397. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1398. transformers/models/videomae/configuration_videomae.py +0 -1
  1399. transformers/models/videomae/image_processing_videomae.py +31 -34
  1400. transformers/models/videomae/modeling_videomae.py +13 -15
  1401. transformers/models/videomae/video_processing_videomae.py +0 -1
  1402. transformers/models/vilt/configuration_vilt.py +2 -3
  1403. transformers/models/vilt/image_processing_vilt.py +29 -30
  1404. transformers/models/vilt/image_processing_vilt_fast.py +9 -10
  1405. transformers/models/vilt/modeling_vilt.py +83 -78
  1406. transformers/models/vilt/processing_vilt.py +2 -14
  1407. transformers/models/vipllava/configuration_vipllava.py +0 -1
  1408. transformers/models/vipllava/modeling_vipllava.py +45 -42
  1409. transformers/models/vipllava/modular_vipllava.py +30 -32
  1410. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1411. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +18 -21
  1412. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1413. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +18 -21
  1414. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1415. transformers/models/visual_bert/configuration_visual_bert.py +0 -1
  1416. transformers/models/visual_bert/modeling_visual_bert.py +92 -92
  1417. transformers/models/vit/configuration_vit.py +0 -1
  1418. transformers/models/vit/image_processing_vit.py +19 -22
  1419. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1420. transformers/models/vit/modeling_vit.py +13 -15
  1421. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1422. transformers/models/vit_mae/modeling_vit_mae.py +21 -23
  1423. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1424. transformers/models/vit_msn/modeling_vit_msn.py +10 -12
  1425. transformers/models/vitdet/configuration_vitdet.py +0 -1
  1426. transformers/models/vitdet/modeling_vitdet.py +12 -14
  1427. transformers/models/vitmatte/configuration_vitmatte.py +2 -5
  1428. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1429. transformers/models/vitmatte/image_processing_vitmatte_fast.py +14 -16
  1430. transformers/models/vitmatte/modeling_vitmatte.py +13 -11
  1431. transformers/models/vitpose/configuration_vitpose.py +4 -7
  1432. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1433. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -11
  1434. transformers/models/vitpose/modeling_vitpose.py +10 -12
  1435. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +0 -1
  1436. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +8 -10
  1437. transformers/models/vits/configuration_vits.py +0 -1
  1438. transformers/models/vits/modeling_vits.py +34 -35
  1439. transformers/models/vits/tokenization_vits.py +3 -4
  1440. transformers/models/vivit/configuration_vivit.py +0 -1
  1441. transformers/models/vivit/image_processing_vivit.py +36 -39
  1442. transformers/models/vivit/modeling_vivit.py +5 -7
  1443. transformers/models/vjepa2/__init__.py +0 -1
  1444. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1445. transformers/models/vjepa2/modeling_vjepa2.py +30 -32
  1446. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1447. transformers/models/voxtral/__init__.py +0 -1
  1448. transformers/models/voxtral/configuration_voxtral.py +0 -1
  1449. transformers/models/voxtral/modeling_voxtral.py +19 -27
  1450. transformers/models/voxtral/modular_voxtral.py +12 -21
  1451. transformers/models/voxtral/processing_voxtral.py +25 -48
  1452. transformers/models/wav2vec2/configuration_wav2vec2.py +0 -1
  1453. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1454. transformers/models/wav2vec2/modeling_wav2vec2.py +67 -122
  1455. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1456. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1457. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +0 -1
  1458. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +65 -62
  1459. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +52 -48
  1460. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1461. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +0 -1
  1462. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +84 -77
  1463. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +37 -30
  1464. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1465. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1466. transformers/models/wavlm/configuration_wavlm.py +0 -1
  1467. transformers/models/wavlm/modeling_wavlm.py +45 -48
  1468. transformers/models/wavlm/modular_wavlm.py +4 -5
  1469. transformers/models/whisper/configuration_whisper.py +0 -1
  1470. transformers/models/whisper/english_normalizer.py +3 -4
  1471. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1472. transformers/models/whisper/generation_whisper.py +27 -48
  1473. transformers/models/whisper/modeling_whisper.py +73 -73
  1474. transformers/models/whisper/processing_whisper.py +3 -20
  1475. transformers/models/whisper/tokenization_whisper.py +9 -30
  1476. transformers/models/x_clip/configuration_x_clip.py +0 -1
  1477. transformers/models/x_clip/modeling_x_clip.py +70 -69
  1478. transformers/models/x_clip/processing_x_clip.py +2 -14
  1479. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1480. transformers/models/xcodec/modeling_xcodec.py +20 -17
  1481. transformers/models/xglm/configuration_xglm.py +0 -1
  1482. transformers/models/xglm/modeling_xglm.py +59 -55
  1483. transformers/models/xglm/tokenization_xglm.py +1 -4
  1484. transformers/models/xlm/configuration_xlm.py +0 -1
  1485. transformers/models/xlm/modeling_xlm.py +139 -144
  1486. transformers/models/xlm/tokenization_xlm.py +3 -5
  1487. transformers/models/xlm_roberta/configuration_xlm_roberta.py +0 -1
  1488. transformers/models/xlm_roberta/modeling_xlm_roberta.py +195 -194
  1489. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1490. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1491. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +0 -1
  1492. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +94 -93
  1493. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1494. transformers/models/xlnet/configuration_xlnet.py +0 -11
  1495. transformers/models/xlnet/modeling_xlnet.py +152 -163
  1496. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1497. transformers/models/xlstm/configuration_xlstm.py +3 -5
  1498. transformers/models/xlstm/modeling_xlstm.py +62 -65
  1499. transformers/models/xmod/configuration_xmod.py +0 -1
  1500. transformers/models/xmod/modeling_xmod.py +101 -100
  1501. transformers/models/yolos/configuration_yolos.py +0 -1
  1502. transformers/models/yolos/image_processing_yolos.py +60 -62
  1503. transformers/models/yolos/image_processing_yolos_fast.py +18 -18
  1504. transformers/models/yolos/modeling_yolos.py +12 -14
  1505. transformers/models/yolos/modular_yolos.py +2 -4
  1506. transformers/models/yoso/configuration_yoso.py +0 -1
  1507. transformers/models/yoso/modeling_yoso.py +64 -63
  1508. transformers/models/zamba/configuration_zamba.py +0 -1
  1509. transformers/models/zamba/modeling_zamba.py +70 -70
  1510. transformers/models/zamba2/configuration_zamba2.py +36 -37
  1511. transformers/models/zamba2/modeling_zamba2.py +87 -89
  1512. transformers/models/zamba2/modular_zamba2.py +43 -45
  1513. transformers/models/zoedepth/configuration_zoedepth.py +1 -2
  1514. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1515. transformers/models/zoedepth/image_processing_zoedepth_fast.py +12 -15
  1516. transformers/models/zoedepth/modeling_zoedepth.py +21 -16
  1517. transformers/pipelines/__init__.py +59 -55
  1518. transformers/pipelines/any_to_any.py +14 -22
  1519. transformers/pipelines/audio_utils.py +1 -2
  1520. transformers/pipelines/automatic_speech_recognition.py +20 -12
  1521. transformers/pipelines/base.py +13 -17
  1522. transformers/pipelines/deprecated/__init__.py +0 -1
  1523. transformers/pipelines/document_question_answering.py +1 -1
  1524. transformers/pipelines/image_text_to_text.py +0 -1
  1525. transformers/pipelines/image_to_text.py +4 -44
  1526. transformers/pipelines/question_answering.py +5 -44
  1527. transformers/pipelines/text_classification.py +1 -14
  1528. transformers/pipelines/text_to_audio.py +2 -2
  1529. transformers/pipelines/token_classification.py +1 -22
  1530. transformers/pipelines/video_classification.py +1 -9
  1531. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1532. transformers/pipelines/zero_shot_classification.py +0 -6
  1533. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1534. transformers/processing_utils.py +222 -151
  1535. transformers/quantizers/auto.py +2 -4
  1536. transformers/quantizers/base.py +19 -64
  1537. transformers/quantizers/quantizer_aqlm.py +1 -18
  1538. transformers/quantizers/quantizer_auto_round.py +1 -10
  1539. transformers/quantizers/quantizer_awq.py +3 -8
  1540. transformers/quantizers/quantizer_bitnet.py +1 -6
  1541. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  1542. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  1543. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  1544. transformers/quantizers/quantizer_eetq.py +2 -12
  1545. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  1546. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  1547. transformers/quantizers/quantizer_fp_quant.py +4 -4
  1548. transformers/quantizers/quantizer_gptq.py +1 -4
  1549. transformers/quantizers/quantizer_higgs.py +2 -6
  1550. transformers/quantizers/quantizer_mxfp4.py +2 -28
  1551. transformers/quantizers/quantizer_quanto.py +14 -14
  1552. transformers/quantizers/quantizer_quark.py +0 -1
  1553. transformers/quantizers/quantizer_spqr.py +3 -8
  1554. transformers/quantizers/quantizer_torchao.py +31 -127
  1555. transformers/quantizers/quantizer_vptq.py +1 -10
  1556. transformers/testing_utils.py +31 -49
  1557. transformers/tokenization_mistral_common.py +554 -902
  1558. transformers/tokenization_utils_base.py +112 -124
  1559. transformers/tokenization_utils_sentencepiece.py +5 -6
  1560. transformers/tokenization_utils_tokenizers.py +30 -7
  1561. transformers/trainer.py +30 -11
  1562. transformers/trainer_callback.py +8 -0
  1563. transformers/trainer_jit_checkpoint.py +1 -2
  1564. transformers/trainer_seq2seq.py +4 -0
  1565. transformers/training_args.py +11 -13
  1566. transformers/utils/__init__.py +4 -0
  1567. transformers/utils/attention_visualizer.py +5 -5
  1568. transformers/utils/auto_docstring.py +598 -37
  1569. transformers/utils/doc.py +1 -1
  1570. transformers/utils/dummy_pt_objects.py +0 -42
  1571. transformers/utils/generic.py +21 -1
  1572. transformers/utils/import_utils.py +51 -9
  1573. transformers/utils/kernel_config.py +71 -18
  1574. transformers/utils/loading_report.py +3 -3
  1575. transformers/utils/quantization_config.py +16 -18
  1576. transformers/video_processing_utils.py +35 -32
  1577. transformers/video_utils.py +18 -22
  1578. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/METADATA +23 -24
  1579. transformers-5.0.0rc3.dist-info/RECORD +2067 -0
  1580. transformers-5.0.0rc1.dist-info/RECORD +0 -2003
  1581. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/WHEEL +0 -0
  1582. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/entry_points.txt +0 -0
  1583. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/licenses/LICENSE +0 -0
  1584. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
- # coding=utf-8
  # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
  # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  #
@@ -16,7 +15,6 @@
  import collections
  import copy
  import functools
- import gc
  import importlib.metadata
  import inspect
  import json
@@ -26,13 +24,13 @@ import sys
  import warnings
  from abc import abstractmethod
  from collections import defaultdict
- from collections.abc import Callable, Sequence
+ from collections.abc import Callable, Iterator, Sequence
  from contextlib import contextmanager
  from enum import Enum
  from functools import partial, wraps
  from itertools import cycle
  from threading import Thread
- from typing import Optional, TypeVar, Union, get_type_hints
+ from typing import Optional, TypeVar, get_type_hints
  from zipfile import is_zipfile

  import torch
@@ -63,7 +61,8 @@ from .integrations.accelerate import (
      accelerate_dispatch,
      check_and_set_device_map,
      expand_device_map,
-     init_empty_weights,
+     get_device,
+     load_offloaded_parameter,
  )
  from .integrations.deepspeed import _load_state_dict_into_zero3_model
  from .integrations.eager_paged import eager_paged_attention_forward
@@ -86,6 +85,7 @@ from .integrations.tensor_parallel import (
  )
  from .loss.loss_utils import LOSS_MAPPING
  from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
+ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
  from .pytorch_utils import id_tensor_storage
  from .quantizers import HfQuantizer
  from .quantizers.auto import get_hf_quantizer
@@ -108,6 +108,7 @@ from .utils import (
      is_accelerate_available,
      is_flash_attn_2_available,
      is_flash_attn_3_available,
+     is_grouped_mm_available,
      is_kernels_available,
      is_torch_flex_attn_available,
      is_torch_greater_or_equal,
@@ -130,7 +131,6 @@ from .utils.quantization_config import QuantizationMethod
  if is_accelerate_available():
      from accelerate.hooks import add_hook_to_module
      from accelerate.utils import extract_model_from_parallel
-     from accelerate.utils.modeling import get_state_dict_from_offload


  _torch_distributed_available = torch.distributed.is_available()
@@ -152,10 +152,15 @@ logger = logging.get_logger(__name__)
  XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
  XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
  SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
- _init_weights = True
  _is_quantized = False
  _is_ds_init_called = False

+ # Mapping from flash attention implementations to their kernel fallback repositories
+ FLASH_ATTN_KERNEL_FALLBACK = {
+     "flash_attention_2": "kernels-community/flash-attn2",
+     "flash_attention_3": "kernels-community/vllm-flash-attn3",
+ }
+

  def is_local_dist_rank_0():
      return (
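
Illustrative sketch (editor's note, not part of the diff): one way a mapping like the new FLASH_ATTN_KERNEL_FALLBACK can be consulted to pick a Hub kernel repository when the requested flash-attention backend is not installed. Only the dictionary mirrors the hunk above; resolve_flash_attn_fallback and its is_installed flag are hypothetical names used purely for illustration.

FLASH_ATTN_KERNEL_FALLBACK = {
    "flash_attention_2": "kernels-community/flash-attn2",
    "flash_attention_3": "kernels-community/vllm-flash-attn3",
}

def resolve_flash_attn_fallback(requested: str, is_installed: bool) -> str | None:
    """Return the kernels Hub repo to fall back to, or None when no fallback is needed or known."""
    if is_installed:
        return None
    return FLASH_ATTN_KERNEL_FALLBACK.get(requested)

print(resolve_flash_attn_fallback("flash_attention_2", is_installed=False))
# -> kernels-community/flash-attn2
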
@@ -165,51 +170,6 @@ def is_local_dist_rank_0():
      )


- TORCH_INIT_FUNCTIONS = {
-     "uniform_": nn.init.uniform_,
-     "normal_": nn.init.normal_,
-     "trunc_normal_": nn.init.trunc_normal_,
-     "constant_": nn.init.constant_,
-     "xavier_uniform_": nn.init.xavier_uniform_,
-     "xavier_normal_": nn.init.xavier_normal_,
-     "kaiming_uniform_": nn.init.kaiming_uniform_,
-     "kaiming_normal_": nn.init.kaiming_normal_,
-     "uniform": nn.init.uniform,
-     "normal": nn.init.normal,
-     "xavier_uniform": nn.init.xavier_uniform,
-     "xavier_normal": nn.init.xavier_normal,
-     "kaiming_uniform": nn.init.kaiming_uniform,
-     "kaiming_normal": nn.init.kaiming_normal,
-     "orthogonal_": nn.init.orthogonal_,
- }
-
-
- @contextmanager
- def no_init_weights():
-     """
-     Context manager to globally disable weight initialization to speed up loading large models.
-     """
-     global _init_weights
-     old_init_weights = _init_weights
-
-     _init_weights = False
-
-     def _skip_init(*args, **kwargs):
-         pass
-
-     # Save the original initialization functions
-     for name, init_func in TORCH_INIT_FUNCTIONS.items():
-         setattr(torch.nn.init, name, _skip_init)
-
-     try:
-         yield
-     finally:
-         _init_weights = old_init_weights
-         # Restore the original initialization functions
-         for name, init_func in TORCH_INIT_FUNCTIONS.items():
-             setattr(torch.nn.init, name, init_func)
-
-
  @contextmanager
  def set_quantized_state():
      global _is_quantized
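
Illustrative sketch (editor's note, not part of the diff): the pattern the removed no_init_weights helper enabled. While the context is active, the torch.nn.init functions listed in TORCH_INIT_FUNCTIONS are swapped for no-ops, so constructing a module skips random initialization that a checkpoint load would overwrite anyway. The import path assumes the hunk comes from transformers.modeling_utils in rc1; per the hunk above, the helper no longer exists in rc3.

import torch.nn as nn
from transformers.modeling_utils import no_init_weights  # rc1 only, removed in rc3

with no_init_weights():
    block = nn.Linear(4096, 4096)   # the kaiming_uniform_/uniform_ calls in reset_parameters become no-ops
    head = nn.Linear(4096, 32000)   # weights stay as allocated, i.e. uninitialized memory
# Outside the context the original torch.nn.init functions are restored, and the caller
# is expected to copy pretrained weights into these modules.
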
@@ -233,23 +193,28 @@ def set_zero3_state():
          _is_ds_init_called = False


- def restore_default_dtype(func):
+ @contextmanager
+ def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
      """
-     Decorator to restore the default torch dtype
-     at the end of the function. Serves
-     as a backup in case calling the function raises
-     an error after the function has changed the default dtype but before it could restore it.
+     Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
+     If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
      """
+     # Just a more helping error before we set `torch.set_default_dtype` later on which would crash in this case
+     if not dtype.is_floating_point:
+         if model_class_name is not None:
+             error_message = (
+                 f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
+             )
+         else:
+             error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
+         raise ValueError(error_message)

-     @wraps(func)
-     def _wrapper(*args, **kwargs):
-         old_dtype = torch.get_default_dtype()
-         try:
-             return func(*args, **kwargs)
-         finally:
-             torch.set_default_dtype(old_dtype)
-
-     return _wrapper
+     original_dtype = torch.get_default_dtype()
+     try:
+         torch.set_default_dtype(dtype)
+         yield
+     finally:
+         torch.set_default_dtype(original_dtype)


  def get_torch_context_manager_or_global_device():
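
Illustrative usage sketch (editor's note, not part of the diff): how the new local_torch_dtype context manager behaves, judging only from the hunk above. The default dtype changes inside the with block and is restored on exit, even if the body raises; non-floating dtypes are rejected up front. The import assumes the hunk comes from transformers.modeling_utils, since the file name is not shown in this excerpt.

import torch
from transformers.modeling_utils import local_torch_dtype  # assumed location of the hunk above

print(torch.get_default_dtype())           # torch.float32
with local_torch_dtype(torch.bfloat16):
    layer = torch.nn.Linear(8, 8)          # parameters are created in bfloat16
    print(layer.weight.dtype)              # torch.bfloat16
print(torch.get_default_dtype())           # torch.float32 again

# Entering `with local_torch_dtype(torch.int64):` raises ValueError before the default dtype is touched.
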
@@ -305,7 +270,7 @@ if is_torch_greater_or_equal("2.3.0"):


  def load_state_dict(
-     checkpoint_file: Union[str, os.PathLike], map_location: Union[str, torch.device] = "cpu", weights_only: bool = True
+     checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
  ) -> dict[str, torch.Tensor]:
      """
      Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
@@ -405,14 +370,97 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]
405
370
  return shared_tensors, identical
406
371
 
407
372
 
373
+ def remove_tied_weights_from_state_dict(
374
+ state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
375
+ ) -> dict[str, torch.Tensor]:
376
+ """
377
+ Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
378
+ will expect when reloading (even if we tie weights symmetrically, it's better to keep the intended one).
379
+ This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
380
+ """
381
+ # To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
382
+ # of the Tensors themselves -> we are guaranteed to find all the actual tied weights
383
+ ptrs = collections.defaultdict(list)
384
+ for name, tensor in state_dict.items():
385
+ if not isinstance(tensor, torch.Tensor):
386
+ # Sometimes in the state_dict we have non-tensor objects.
387
+ # e.g. in bitsandbytes we have some `str` objects in the state_dict
388
+ # In the non-tensor case, fall back to the pointer of the object itself
389
+ ptrs[id(tensor)].append(name)
390
+
391
+ elif tensor.device.type == "meta":
392
+ # In offloaded cases, there may be meta tensors in the state_dict.
393
+ # For these cases, key by the pointer of the original tensor object
394
+ # (state_dict tensors are detached and therefore no longer shared)
395
+ tensor = model.get_parameter(name)
396
+ ptrs[id(tensor)].append(name)
397
+
398
+ else:
399
+ ptrs[id_tensor_storage(tensor)].append(name)
400
+
401
+ shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
402
+
403
+ # Recursively descend to find tied weight keys
404
+ all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
405
+ error_names = []
406
+ to_delete_names = set()
407
+ # Removing the keys which are declared as known duplicates on load. This ensures that the name which is
408
+ # kept is consistent
409
+ if all_potential_tied_weights_keys is not None:
410
+ for names in shared_ptrs.values():
411
+ found = 0
412
+ for name in sorted(names):
413
+ matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
414
+ if matches_pattern and name in state_dict:
415
+ found += 1
416
+ if found < len(names):
417
+ to_delete_names.add(name)
418
+ # We are entering a place where the weights and the transformers configuration do NOT match.
419
+ shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
420
+ # Those tensors actually share storage but are disjoint from each other, so we can safely clone them
421
+ # Reloaded models won't have the same property, but it shouldn't matter in any meaningful way.
422
+ for name in disjoint_names:
423
+ state_dict[name] = state_dict[name].clone()
424
+
425
+ # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
426
+ # If the link between tensors was done at runtime then `from_pretrained` will not get
427
+ # the key back, leading to a random tensor. A proper warning will be shown
428
+ # during reload (if applicable), but since the file is not necessarily compatible with
429
+ # the config, better show a proper warning.
430
+ shared_names, identical_names = _find_identical(shared_names, state_dict)
431
+ # delete tensors that have identical storage
432
+ for inames in identical_names:
433
+ known = inames.intersection(to_delete_names)
434
+ for name in known:
435
+ del state_dict[name]
436
+ unknown = inames.difference(to_delete_names)
437
+ if len(unknown) > 1:
438
+ error_names.append(unknown)
439
+
440
+ if shared_names:
441
+ error_names.extend(shared_names)
442
+
443
+ if len(error_names) > 0:
444
+ raise RuntimeError(
445
+ f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
446
+ f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
447
+ "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
448
+ )
449
+
450
+ return state_dict
451
+
452
+
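Conceptually, `remove_tied_weights_from_state_dict` groups state_dict entries by underlying storage and drops every alias except the one the model expects back on reload, because safetensors refuses aliased tensors. A simplified sketch of that grouping (the helper name and `keep` argument are illustrative; the real function also handles meta tensors, non-tensor entries and disjoint views):

import collections
import torch

def drop_aliased_tensors(state_dict, keep):
    # Group parameter names by underlying storage pointer; aliases share one storage.
    by_storage = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        by_storage[tensor.untyped_storage().data_ptr()].append(name)
    for names in by_storage.values():
        if len(names) > 1:
            # Keep the preferred alias (the name expected on reload), drop the others.
            preferred = next((n for n in names if n in keep), sorted(names)[0])
            for name in names:
                if name != preferred:
                    state_dict.pop(name)
    return state_dict

emb = torch.nn.Embedding(10, 4)
tied = {"embed.weight": emb.weight, "lm_head.weight": emb.weight}  # two aliases, one storage
print(drop_aliased_tensors(tied, keep={"embed.weight"}).keys())  # only 'embed.weight' survives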
408
453
  def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
409
- """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
410
- module, param_type = get_module_from_name(model, param_name)
411
- # This will check potential shape mismatch if skipped before
412
- module.load_state_dict({param_type: tensor}, strict=False, assign=True)
454
+ """Cast a single parameter or buffer `param_name` into the `model`, with value `tensor`."""
455
+ parent, param_type = get_module_from_name(model, param_name)
456
+ if param_type in parent._parameters and not isinstance(tensor, nn.Parameter):
457
+ tensor = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
458
+ # We need to use setattr here, as we set non-persistent buffers as well with this function (`load_state_dict`
459
+ # does not allow that)
460
+ setattr(parent, param_type, tensor)
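The switch from `module.load_state_dict(..., assign=True)` to a plain `setattr` matters because non-persistent buffers never appear in a state_dict, so `load_state_dict` cannot fill them. A small sketch of the difference (the `Block` module is illustrative):

import torch
from torch import nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 2))
        self.register_buffer("scale", torch.ones(2), persistent=False)

m = Block()
assert "scale" not in m.state_dict()        # non-persistent buffer is absent from the state_dict
setattr(m, "scale", torch.full((2,), 2.0))  # but attribute assignment still updates the buffer
assert torch.equal(m.scale, torch.full((2,), 2.0))
# Parameters must be wrapped in nn.Parameter before assignment, as done above.
setattr(m, "weight", nn.Parameter(torch.randn(2, 2)))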
413
461
 
414
462
 
415
- def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
463
+ def _add_variant(weights_name: str, variant: str | None = None) -> str:
416
464
  if variant is not None:
417
465
  path, name = weights_name.rsplit(".", 1)
418
466
  weights_name = f"{path}.{variant}.{name}"
@@ -420,15 +468,15 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
420
468
 
421
469
 
422
470
  def _get_resolved_checkpoint_files(
423
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
424
- variant: Optional[str],
425
- gguf_file: Optional[str],
426
- use_safetensors: Optional[bool],
471
+ pretrained_model_name_or_path: str | os.PathLike | None,
472
+ variant: str | None,
473
+ gguf_file: str | None,
474
+ use_safetensors: bool | None,
427
475
  download_kwargs: DownloadKwargs,
428
476
  user_agent: dict,
429
477
  is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in
430
- transformers_explicit_filename: Optional[str] = None,
431
- ) -> tuple[Optional[list[str]], Optional[dict]]:
478
+ transformers_explicit_filename: str | None = None,
479
+ ) -> tuple[list[str] | None, dict | None]:
432
480
  """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
433
481
  checkpoints are sharded.
434
482
  This function will download the data if necessary.
@@ -696,22 +744,20 @@ def _get_resolved_checkpoint_files(
696
744
 
697
745
 
698
746
  def _get_dtype(
699
- cls,
700
- dtype: Optional[Union[str, torch.dtype, dict]],
701
- checkpoint_files: Optional[list[str]],
747
+ dtype: str | torch.dtype | dict | None,
748
+ checkpoint_files: list[str] | None,
702
749
  config: PreTrainedConfig,
703
- sharded_metadata: Optional[dict],
704
- state_dict: Optional[dict],
750
+ sharded_metadata: dict | None,
751
+ state_dict: dict | None,
705
752
  weights_only: bool,
706
- ) -> tuple[PreTrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
753
+ hf_quantizer: HfQuantizer | None = None,
754
+ ) -> tuple[PreTrainedConfig, torch.dtype]:
707
755
  """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
708
756
  inferred dtype. We do the following:
709
- 1. If dtype is not None, we use that dtype
710
- 2. If dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
711
- weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
712
- we also may have config.dtype available, but we won't rely on it till v5
757
+ 1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
758
+ its first floating-point weights entry - we assume all floating-point weights share the same dtype
759
+ 2. Otherwise, use the dtype provided as a dict or str
713
760
  """
714
- dtype_orig = None
715
761
  is_sharded = sharded_metadata is not None
716
762
 
717
763
  if dtype is not None:
@@ -736,43 +782,46 @@ def _get_dtype(
736
782
  )
737
783
  elif hasattr(torch, dtype):
738
784
  dtype = getattr(torch, dtype)
739
- config.dtype = dtype
740
- for sub_config_key in config.sub_configs:
741
- if (sub_config := getattr(config, sub_config_key)) is not None:
742
- sub_config.dtype = dtype
743
- elif isinstance(dtype, torch.dtype):
744
- config.dtype = dtype
745
- for sub_config_key in config.sub_configs:
746
- if (sub_config := getattr(config, sub_config_key)) is not None:
747
- sub_config.dtype = dtype
748
- elif isinstance(dtype, dict):
749
- for key, curr_dtype in dtype.items():
750
- if hasattr(config, key):
751
- value = getattr(config, key)
752
- curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
753
- value.dtype = curr_dtype
754
- # main torch dtype for modules that aren't part of any sub-config
755
- dtype = dtype.get("")
756
- dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
757
- config.dtype = dtype
758
- if dtype is None:
759
- dtype = torch.float32
760
- else:
785
+ else:
786
+ raise ValueError(
787
+ "`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
788
+ )
789
+
790
+ # cast it to a proper `torch.dtype` object
791
+ dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
792
+ elif not isinstance(dtype, (dict, torch.dtype)):
761
793
  raise ValueError(
762
794
  f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
763
795
  f"for each sub-config in composite configs, but received {dtype}"
764
796
  )
797
+ else:
798
+ # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
799
+ dtype = torch.get_default_dtype()
800
+
801
+ if hf_quantizer is not None:
802
+ hf_quantizer.update_dtype(dtype)
803
+
804
+ # Get the main dtype
805
+ if isinstance(dtype, dict):
806
+ main_dtype = dtype.get("", torch.get_default_dtype())
807
+ main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
808
+
809
+ logger.warning_once(
810
+ "Using different dtypes per module is deprecated and will be removed in future versions "
811
+ "Setting different dtypes per backbone model might cause device errors downstream, therefore "
812
+ f"setting the dtype={main_dtype} for all modules."
813
+ )
765
814
 
766
- dtype_orig = cls._set_default_dtype(dtype)
767
815
  else:
768
- # set fp32 as the default dtype for BC
769
- default_dtype = torch.get_default_dtype()
770
- config.dtype = default_dtype
771
- for key in config.sub_configs:
772
- if (sub_config := getattr(config, key)) is not None:
773
- sub_config.dtype = default_dtype
774
- dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
775
- return config, dtype, dtype_orig
816
+ main_dtype = dtype
817
+
818
+ # Set it on the config and subconfigs
819
+ config.dtype = main_dtype
820
+ for sub_config_key in config.sub_configs:
821
+ if (sub_config := getattr(config, sub_config_key)) is not None:
822
+ sub_config.dtype = main_dtype
823
+
824
+ return config, main_dtype
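A rough sketch of the resolution order `_get_dtype` now follows ("auto" reads the config, then the checkpoint, then torch's default; explicit values win). The helper name and arguments below are illustrative and skip the per-sub-config dict handling:

import torch

def resolve_dtype(requested, config_dtype=None, first_checkpoint_dtype=None):
    if requested == "auto":
        candidate = config_dtype or first_checkpoint_dtype or torch.get_default_dtype()
    elif requested is None:
        candidate = torch.get_default_dtype()
    else:
        candidate = requested
    if isinstance(candidate, str):
        candidate = getattr(torch, candidate)
    if not isinstance(candidate, torch.dtype):
        raise ValueError(f"Cannot interpret {candidate!r} as a torch.dtype")
    return candidate

print(resolve_dtype("auto", config_dtype="bfloat16"))  # torch.bfloat16
print(resolve_dtype(None))                             # torch.float32 on the usual default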
776
825
 
777
826
 
778
827
  class PipelineParallel(Enum):
@@ -798,11 +847,7 @@ class ModuleUtilsMixin:
798
847
  """
799
848
  `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
800
849
  """
801
- dtype = self._dtype or next(param.dtype for param in self.parameters() if param.is_floating_point())
802
- if isinstance(dtype, str):
803
- if hasattr(torch, dtype):
804
- dtype = getattr(torch, dtype)
805
- return dtype
850
+ return next(param.dtype for param in self.parameters() if param.is_floating_point())
806
851
 
807
852
  def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
808
853
  """
@@ -827,13 +872,8 @@ class ModuleUtilsMixin:
827
872
  return encoder_extended_attention_mask
828
873
 
829
874
  @staticmethod
830
- def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
831
- if device is not None:
832
- warnings.warn(
833
- "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
834
- )
835
- else:
836
- device = attention_mask.device
875
+ def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
876
+ device = attention_mask.device
837
877
  batch_size, seq_length = input_shape
838
878
  seq_ids = torch.arange(seq_length, device=device)
839
879
  causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
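The broadcasting trick on the line above builds a lower-triangular causal mask: position i may attend to position j only when j <= i. A tiny worked example:

import torch

batch_size, seq_length = 1, 4
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
print(causal_mask.int()[0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])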
@@ -857,8 +897,7 @@ class ModuleUtilsMixin:
857
897
  self,
858
898
  attention_mask: Tensor,
859
899
  input_shape: tuple[int, ...],
860
- device: Optional[torch.device] = None,
861
- dtype: Optional[torch.dtype] = None,
900
+ dtype: torch.dtype | None = None,
862
901
  ) -> Tensor:
863
902
  """
864
903
  Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -875,12 +914,6 @@ class ModuleUtilsMixin:
875
914
  if dtype is None:
876
915
  dtype = self.dtype
877
916
 
878
- if not (attention_mask.dim() == 2 and self.config.is_decoder):
879
- # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
880
- if device is not None:
881
- warnings.warn(
882
- "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
883
- )
884
917
  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
918
  # ourselves in which case we just need to make it broadcastable to all heads.
886
919
  if attention_mask.dim() == 3:
@@ -891,7 +924,7 @@ class ModuleUtilsMixin:
891
924
  # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
892
925
  if self.config.is_decoder:
893
926
  extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
894
- input_shape, attention_mask, device
927
+ input_shape, attention_mask
895
928
  )
896
929
  else:
897
930
  extended_attention_mask = attention_mask[:, None, None, :]
@@ -972,54 +1005,52 @@ class EmbeddingAccessMixin:
972
1005
  `nn.Module`: A torch module mapping vocabulary to hidden states.
973
1006
  """
974
1007
 
975
- # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
976
- # for most NLP models), and if so, return it.
977
-
978
1008
  name = getattr(self, "_input_embed_layer", "embed_tokens")
979
1009
 
1010
+ # 1) Direct attribute (most NLP models).
980
1011
  if (default_embedding := getattr(self, name, None)) is not None:
981
1012
  return default_embedding
982
- # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1013
+ # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
1014
+ if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
1015
+ return getattr(self.embeddings, name)
1016
+ # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
1017
+ if hasattr(self, "model") and hasattr(self.model, name):
1018
+ return getattr(self.model, name)
983
1019
 
984
- if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
985
- return self.model.embed_tokens
1020
+ if hasattr(self, "base_model"):
1021
+ base_model = self.base_model
1022
+ if base_model is not None and base_model is not self:
1023
+ return base_model.get_input_embeddings()
986
1024
 
987
- # 3) vanilla decoder‑only architectures
988
- elif hasattr(self, "embed_tokens"):
989
- return self.embed_tokens
990
- else:
991
- base_model = getattr(self, "base_model_prefix", None)
992
- if base_model is not None:
993
- base_model = getattr(self, base_model, None)
994
- if base_model is not None and base_model is not self:
995
- return base_model.get_input_embeddings()
996
- raise NotImplementedError(
997
- f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
998
- "please override in the subclass."
999
- )
1025
+ raise NotImplementedError(
1026
+ f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
1027
+ )
1000
1028
 
1001
1029
  def set_input_embeddings(self, value: nn.Module):
1002
1030
  """Fallback setter that handles **~70%** of models in the code-base.
1003
1031
 
1004
1032
  Order of attempts:
1005
- 1. `self.model.embed_tokens`
1006
- 2. `self.embed_tokens`
1007
- 3. delegate to the *base model* if one exists
1008
- 4. otherwise raise `NotImplementedError` so subclasses still can (and
1033
+ 1. `self.<_input_embed_layer>` (direct attribute)
1034
+ 2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
1035
+ 3. `self.model.<_input_embed_layer>` (encoder/decoder models)
1036
+ 4. delegate to the *base model* if one exists
1037
+ 5. otherwise raise `NotImplementedError` so subclasses still can (and
1009
1038
  should) override for exotic layouts.
1010
1039
  """
1011
1040
 
1012
- # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1013
1041
  name = getattr(self, "_input_embed_layer", "embed_tokens")
1014
- if hasattr(self, "model") and hasattr(self.model, name):
1015
- setattr(self.model, name, value)
1016
- # 2) as well as vanilla decoder‑only architectures
1017
- elif hasattr(self, name):
1042
+ # 1) Direct attribute (most NLP models)
1043
+ if hasattr(self, name):
1018
1044
  setattr(self, name, value)
1019
- # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
1020
- elif getattr(self, self.base_model_prefix, self) is not self:
1021
- base_model = getattr(self, self.base_model_prefix, self)
1022
- base_model.set_input_embeddings(value)
1045
+ # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
1046
+ elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
1047
+ setattr(self.embeddings, name, value)
1048
+ # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1049
+ elif hasattr(self, "model") and hasattr(self.model, name):
1050
+ setattr(self.model, name, value)
1051
+ # 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
1052
+ elif hasattr(self, "base_model") and self.base_model is not self:
1053
+ self.base_model.set_input_embeddings(value)
1023
1054
  else:
1024
1055
  raise NotImplementedError(
1025
1056
  f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
@@ -1080,8 +1111,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1080
1111
  # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
1081
1112
  _keep_in_fp32_modules_strict = None
1082
1113
 
1083
- dtype_plan: Optional[dict[str, torch.dtype]] = None
1084
- _dtype: Optional[Union[str, torch.dtype]] = torch.get_default_dtype()
1114
+ dtype_plan: dict[str, torch.dtype] | None = None
1085
1115
 
1086
1116
  # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
1087
1117
  # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -1141,7 +1171,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1141
1171
 
1142
1172
  # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
1143
1173
  # Possible values are: text, image, video, audio and time
1144
- input_modalities: Union[str, list[str]] = "text" # most models are text
1174
+ input_modalities: str | list[str] = "text" # most models are text
1145
1175
 
1146
1176
  @property
1147
1177
  @torch._dynamo.allow_in_graph
@@ -1226,14 +1256,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1226
1256
  f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
1227
1257
  )
1228
1258
  self.config = config
1229
- default_dtype = torch.get_default_dtype()
1230
- self._dtype = default_dtype
1231
1259
 
1232
1260
  # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
1233
1261
  # setting it recursively)
1234
1262
  self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
1235
1263
  self.config._attn_implementation, is_init_check=True
1236
1264
  )
1265
+ # Check the experts implementation is supported, or set it if not yet set (on the internal attr, to avoid
1266
+ # setting it recursively)
1267
+ self.config._experts_implementation_internal = self._check_and_adjust_experts_implementation(
1268
+ self.config._experts_implementation
1269
+ )
1237
1270
  if self.can_generate():
1238
1271
  self.generation_config = GenerationConfig.from_model_config(config)
1239
1272
 
@@ -1349,7 +1382,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1349
1382
  def pp_plan(self, plan: dict[str, tuple[str, str]]):
1350
1383
  self._pp_plan = plan
1351
1384
 
1352
- def dequantize(self):
1385
+ def dequantize(self, dtype=None):
1353
1386
  """
1354
1387
  Potentially dequantize the model in case it has been quantized by a quantization method that support
1355
1388
  dequantization.
@@ -1359,7 +1392,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1359
1392
  if hf_quantizer is None:
1360
1393
  raise ValueError("You need to first quantize your model in order to dequantize it")
1361
1394
 
1362
- return hf_quantizer.dequantize(self)
1395
+ return hf_quantizer.dequantize(self, dtype=dtype)
1363
1396
 
1364
1397
  def _backward_compatibility_gradient_checkpointing(self):
1365
1398
  if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
@@ -1367,7 +1400,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1367
1400
  # Remove the attribute now that it has been consumed, so it's not saved in the config.
1368
1401
  delattr(self.config, "gradient_checkpointing")
1369
1402
 
1370
- def add_model_tags(self, tags: Union[list[str], str]) -> None:
1403
+ def add_model_tags(self, tags: list[str] | str) -> None:
1371
1404
  r"""
1372
1405
  Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
1373
1406
  not overwrite existing tags in the model.
@@ -1400,7 +1433,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1400
1433
  self.model_tags.append(tag)
1401
1434
 
1402
1435
  @classmethod
1403
- @restore_default_dtype
1404
1436
  def _from_config(cls, config, **kwargs):
1405
1437
  """
1406
1438
  All context managers that the model should be initialized under go here.
@@ -1409,9 +1441,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1409
1441
  dtype (`torch.dtype`, *optional*):
1410
1442
  Override the default `dtype` and load the model under this dtype.
1411
1443
  """
1412
- # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
1413
- # a warning is raised that dtype should be fp16. Since we never pass dtype from within
1414
- # modeling code, we can try to infer it here same way as done in `from_pretrained`
1415
1444
  # For BC on the old `torch_dtype`
1416
1445
  dtype = kwargs.pop("dtype", config.dtype)
1417
1446
  if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1421,67 +1450,32 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1421
1450
  if isinstance(dtype, str):
1422
1451
  dtype = getattr(torch, dtype)
1423
1452
 
1424
- # override default dtype if needed
1425
- dtype_orig = None
1426
- if dtype is not None:
1427
- dtype_orig = cls._set_default_dtype(dtype)
1428
-
1429
1453
  # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
1430
1454
  if "attn_implementation" in kwargs:
1431
1455
  config._attn_implementation = kwargs.pop("attn_implementation")
1432
1456
 
1457
+ # If passing `experts_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
1458
+ if "experts_implementation" in kwargs:
1459
+ config._experts_implementation = kwargs.pop("experts_implementation")
1460
+
1461
+ init_contexts = []
1462
+ if dtype is not None:
1463
+ init_contexts.append(local_torch_dtype(dtype, cls.__name__))
1464
+
1433
1465
  if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
1434
1466
  logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
1435
1467
  # this immediately partitions the model across all gpus, to avoid the overhead in time
1436
1468
  # and memory copying it on CPU or each GPU first
1437
1469
  import deepspeed
1438
1470
 
1439
- init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
1440
- with ContextManagers(init_contexts):
1441
- model = cls(config, **kwargs)
1471
+ init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
1442
1472
 
1443
- else:
1473
+ # Instantiate the model
1474
+ with ContextManagers(init_contexts):
1444
1475
  model = cls(config, **kwargs)
1445
1476
 
1446
- # restore default dtype if it was modified
1447
- if dtype_orig is not None:
1448
- torch.set_default_dtype(dtype_orig)
1449
-
1450
1477
  return model
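`_from_config` now collects its optional initialization contexts (local dtype, DeepSpeed ZeRO-3 init) into one list and enters them together via `ContextManagers`. The same composition can be sketched with the standard library; the meta-device example below is only an illustration of entering a list of contexts at once:

from contextlib import ExitStack
import torch

def init_under_contexts(contexts):
    # Enter an arbitrary list of context managers together before building the module.
    with ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        return torch.nn.Linear(8, 8)

meta_model = init_under_contexts([torch.device("meta")])  # e.g. allocate without materializing weights
print(meta_model.weight.device)  # meta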
1451
1478
 
1452
- @classmethod
1453
- def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
1454
- """
1455
- Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
1456
- under specific dtype.
1457
-
1458
- Args:
1459
- dtype (`torch.dtype`):
1460
- a floating dtype to set to.
1461
-
1462
- Returns:
1463
- `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
1464
- modified. If it wasn't, returns `None`.
1465
-
1466
- Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
1467
- `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
1468
- """
1469
- if isinstance(dtype, str):
1470
- if hasattr(torch, dtype):
1471
- dtype = getattr(torch, dtype)
1472
- else:
1473
- raise ValueError(f"Received an invalid string dtype: {dtype}")
1474
- if not dtype.is_floating_point:
1475
- raise ValueError(
1476
- f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
1477
- )
1478
-
1479
- logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
1480
- dtype_orig = torch.get_default_dtype()
1481
- torch.set_default_dtype(dtype)
1482
- cls._dtype = dtype
1483
- return dtype_orig
1484
-
1485
1479
  @property
1486
1480
  def base_model(self) -> nn.Module:
1487
1481
  """
@@ -1558,7 +1552,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1558
1552
  return True
1559
1553
 
1560
1554
  if is_torch_xpu_available():
1561
- logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
1555
+ logger.info(
1556
+ f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
1557
+ )
1562
1558
  return True
1563
1559
 
1564
1560
  if importlib.util.find_spec("flash_attn") is None:
@@ -1727,6 +1723,22 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1727
1723
 
1728
1724
  return True
1729
1725
 
1726
+ def _grouped_mm_can_dispatch(self) -> bool:
1727
+ """
1728
+ Check the availability of Grouped MM for a given model.
1729
+ """
1730
+
1731
+ if not self._can_set_experts_implementation():
1732
+ raise ValueError(f"{self.__class__.__name__} does not support setting experts implementation.")
1733
+
1734
+ if not is_grouped_mm_available():
1735
+ raise ImportError(
1736
+ "PyTorch Grouped MM requirements in Transformers are not met. Please install torch>=2.9.0."
1737
+ )
1738
+
1739
+ # If no error raised by this point, we can return `True`
1740
+ return True
1741
+
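`is_grouped_mm_available` is not shown in this diff, but the error message above implies it gates on `torch>=2.9.0`. A plausible sketch of such a version gate (the helper name is an assumption, not the package's import path):

import torch
from packaging import version

def grouped_mm_available() -> bool:
    # Strip any local build suffix such as "+cu121" before comparing.
    return version.parse(torch.__version__.split("+")[0]) >= version.parse("2.9.0")

print(grouped_mm_available())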
1730
1742
  def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
1731
1743
  """
1732
1744
  Check the availability of Flex Attention for a given model.
@@ -1755,7 +1767,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1755
1767
  return True
1756
1768
 
1757
1769
  def _check_and_adjust_attn_implementation(
1758
- self, attn_implementation: Optional[str], is_init_check: bool = False
1770
+ self, attn_implementation: str | None, is_init_check: bool = False
1759
1771
  ) -> str:
1760
1772
  """
1761
1773
  Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
@@ -1790,14 +1802,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1790
1802
  and is_kernels_available()
1791
1803
  and not is_torch_npu_available()
1792
1804
  ):
1793
- if attn_implementation.endswith("2"):
1794
- applicable_attn_implementation = "kernels-community/flash-attn2"
1795
- if is_torch_xpu_available():
1796
- # On XPU, kernels library is the native implementation
1797
- # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
1798
- requested_original_flash_attn = False
1799
- else:
1800
- applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
1805
+ applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]
1806
+
1807
+ if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
1808
+ # On XPU, kernels library is the native implementation
1809
+ # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
1810
+ requested_original_flash_attn = False
1801
1811
 
1802
1812
  if is_paged:
1803
1813
  applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
@@ -1837,7 +1847,20 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1837
1847
 
1838
1848
  return applicable_attn_implementation
1839
1849
 
1840
- def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
1850
+ def _check_and_adjust_experts_implementation(self, experts_implementation: str | None) -> str:
1851
+ """
1852
+ Check that the `experts_implementation` exists and is supported by the models.
1853
+
1854
+ Args:
1855
+ experts_implementation (`str` or `None`):
1856
+ The experts implementation to check for existence/validity.
1857
+ Returns:
1858
+ `str`: The final experts implementation to use.
1859
+ """
1860
+ applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
1861
+ return applicable_experts_implementation
1862
+
1863
+ def get_correct_attn_implementation(self, requested_attention: str | None, is_init_check: bool = False) -> str:
1841
1864
  applicable_attention = "sdpa" if requested_attention is None else requested_attention
1842
1865
  if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
1843
1866
  message = (
@@ -1871,6 +1894,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1871
1894
 
1872
1895
  return applicable_attention
1873
1896
 
1897
+ def get_correct_experts_implementation(self, requested_experts: str | None) -> str:
1898
+ applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
1899
+ if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
1900
+ message = (
1901
+ f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are '
1902
+ '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"` and `"experts_implementation=batched_mm"`.'
1903
+ )
1904
+ raise ValueError(message)
1905
+
1906
+ # Perform relevant checks
1907
+ if applicable_experts == "grouped_mm":
1908
+ try:
1909
+ self._grouped_mm_can_dispatch()
1910
+ except (ValueError, ImportError) as e:
1911
+ if requested_experts == "grouped_mm":
1912
+ raise e
1913
+ applicable_experts = "eager"
1914
+
1915
+ return applicable_experts
1916
+
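The selection logic above defaults to the fast path but only hard-fails when the user explicitly asked for an unavailable implementation. A condensed sketch of that pattern (the function name and `grouped_mm_ok` flag are illustrative):

from __future__ import annotations

def pick_experts_impl(requested: str | None, grouped_mm_ok: bool) -> str:
    chosen = "grouped_mm" if requested is None else requested
    if chosen not in ("eager", "grouped_mm", "batched_mm"):
        raise ValueError(f"unknown experts implementation: {chosen!r}")
    if chosen == "grouped_mm" and not grouped_mm_ok:
        if requested == "grouped_mm":
            # the user explicitly asked for it -> surface the failure
            raise ImportError("grouped_mm was requested but is not available")
        chosen = "eager"  # silent fallback when it was only the default
    return chosen

print(pick_experts_impl(None, grouped_mm_ok=False))          # 'eager'
print(pick_experts_impl("batched_mm", grouped_mm_ok=False))  # 'batched_mm'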
1874
1917
  @classmethod
1875
1918
  def _can_set_attn_implementation(cls) -> bool:
1876
1919
  """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
@@ -1889,7 +1932,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1889
1932
  # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
1890
1933
  return True
1891
1934
 
1892
- def set_attn_implementation(self, attn_implementation: Union[str, dict]):
1935
+ @classmethod
1936
+ def _can_set_experts_implementation(cls) -> bool:
1937
+ """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
1938
+ opening the file, but avoids maintaining yet another property flag.
1939
+ """
1940
+ class_file = sys.modules[cls.__module__].__file__
1941
+ with open(class_file, "r") as f:
1942
+ code = f.read()
1943
+ # heuristic -> if the @use_experts_implementation decorator is used, then we can set it
1944
+ return "@use_experts_implementation" in code
1945
+
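The `_can_set_experts_implementation` check above scans the source file of the class for a decorator marker instead of adding another class flag. The same idea can be sketched with `inspect` (the helper name and marker string below are illustrative):

import inspect
from torch import nn

def class_source_mentions(cls, needle: str) -> bool:
    # Read the source of the module that defines `cls` and look for a marker string.
    return needle in inspect.getsource(inspect.getmodule(cls))

print(class_source_mentions(nn.Linear, "def forward"))  # True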
1946
+ def set_attn_implementation(self, attn_implementation: str | dict):
1893
1947
  """
1894
1948
  Set the requested `attn_implementation` for this model.
1895
1949
 
@@ -1988,6 +2042,50 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1988
2042
  if hasattr(subconfig, "_attn_was_changed"):
1989
2043
  del subconfig._attn_was_changed
1990
2044
 
2045
+ def set_experts_implementation(self, experts_implementation: str | dict):
2046
+ """
2047
+ Set the requested `experts_implementation` for this model.
2048
+
2049
+ Args:
2050
+ experts_implementation (`str` or `dict`):
2051
+ The experts implementation to set for this model. It can be either a `str`, in which case it will be
2052
+ dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
2053
+ submodel will dispatch the corresponding value.
2054
+ """
2055
+ requested_implementation = (
2056
+ experts_implementation
2057
+ if not isinstance(experts_implementation, dict)
2058
+ else experts_implementation.get("", self.config._experts_implementation)
2059
+ )
2060
+
2061
+ if requested_implementation != self.config._experts_implementation:
2062
+ requested_implementation = self._check_and_adjust_experts_implementation(requested_implementation)
2063
+ # Apply the change (on the internal attr, to avoid setting it recursively)
2064
+ self.config._experts_implementation_internal = requested_implementation
2065
+
2066
+ # Apply it to all submodels as well
2067
+ for submodule in self.modules():
2068
+ # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
2069
+ # e.g. ForCausalLM has a Model inside, but no need to check it again)
2070
+ if (
2071
+ submodule is not self
2072
+ and isinstance(submodule, PreTrainedModel)
2073
+ and submodule.config.__class__ != self.config.__class__
2074
+ ):
2075
+ # Set the experts on the submodule
2076
+ sub_implementation = requested_implementation
2077
+ if isinstance(experts_implementation, dict):
2078
+ for subconfig_key in self.config.sub_configs:
2079
+ # We need to check for exact object match here, with `is`
2080
+ if getattr(self.config, subconfig_key) is submodule.config:
2081
+ sub_implementation = experts_implementation.get(
2082
+ subconfig_key, submodule.config._experts_implementation
2083
+ )
2084
+ break
2085
+ # Check the submodule can use it correctly, and raise an error if the requested experts implementation can't be set for the submodule
2086
+ sub_implementation = submodule.get_correct_experts_implementation(sub_implementation)
2087
+ submodule.config._experts_implementation_internal = sub_implementation
2088
+
1991
2089
  def enable_input_require_grads(self):
1992
2090
  """
1993
2091
  Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
@@ -1999,14 +2097,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1999
2097
 
2000
2098
  hooks = []
2001
2099
  seen_modules = set()
2100
+ found_embeddings = False
2002
2101
 
2003
2102
  for module in self.modules():
2004
2103
  if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
2005
2104
  continue
2006
2105
 
2007
- input_embeddings = module.get_input_embeddings()
2106
+ try:
2107
+ input_embeddings = module.get_input_embeddings()
2108
+ except NotImplementedError:
2109
+ continue
2008
2110
 
2009
- if input_embeddings is None:
2111
+ if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
2010
2112
  continue
2011
2113
 
2012
2114
  embedding_id = id(input_embeddings)
@@ -2015,11 +2117,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2015
2117
 
2016
2118
  seen_modules.add(embedding_id)
2017
2119
  hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
2120
+ found_embeddings = True
2018
2121
 
2019
2122
  self._require_grads_hooks = hooks
2020
2123
  if hooks:
2021
2124
  # for BC
2022
2125
  self._require_grads_hook = hooks[0]
2126
+ if not found_embeddings:
2127
+ logger.warning_once(
2128
+ f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
2129
+ "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
2130
+ "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
2131
+ )
2023
2132
 
2024
2133
  def disable_input_require_grads(self):
2025
2134
  """
@@ -2036,7 +2145,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2036
2145
  if hasattr(self, "_require_grads_hook"):
2037
2146
  del self._require_grads_hook
2038
2147
 
2039
- def get_encoder(self, modality: Optional[str] = None):
2148
+ def get_encoder(self, modality: str | None = None):
2040
2149
  """
2041
2150
  Best-effort lookup of the *encoder* module. If provided with `modality` argument,
2042
2151
  it looks for a modality-specific encoder in multimodal models (e.g. "image_encoder")
@@ -2068,7 +2177,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2068
2177
  # If this is a base transformer model (no encoder/model attributes), return self
2069
2178
  return self
2070
2179
 
2071
- def set_encoder(self, encoder, modality: Optional[str] = None):
2180
+ def set_encoder(self, encoder, modality: str | None = None):
2072
2181
  """
2073
2182
  Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
2074
2183
  """
@@ -2154,14 +2263,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2154
2263
  if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
2155
2264
  if getattr(module, "weight", None) is not None:
2156
2265
  init.normal_(module.weight, mean=0.0, std=std)
2157
- if getattr(module, "bias", None) is not None:
2266
+ if module.bias is not None:
2158
2267
  init.zeros_(module.bias)
2159
2268
  elif isinstance(module, nn.Embedding):
2160
- if getattr(module, "weight", None) is not None:
2161
- init.normal_(module.weight, mean=0.0, std=std)
2162
- # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
2163
- if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
2164
- init.zeros_(module.weight[module.padding_idx])
2269
+ init.normal_(module.weight, mean=0.0, std=std)
2270
+ # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
2271
+ if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
2272
+ init.zeros_(module.weight[module.padding_idx])
2165
2273
  elif isinstance(module, nn.MultiheadAttention):
2166
2274
  # This uses torch's original init
2167
2275
  module._reset_parameters()
@@ -2173,10 +2281,25 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2173
2281
  or "RMSNorm" in module.__class__.__name__
2174
2282
  ):
2175
2283
  # Norms can exist without weights (in which case they are None from torch primitives)
2176
- if hasattr(module, "weight") and module.weight is not None:
2284
+ if getattr(module, "weight", None) is not None:
2177
2285
  init.ones_(module.weight)
2178
- if hasattr(module, "bias") and module.bias is not None:
2286
+ if getattr(module, "bias", None) is not None:
2179
2287
  init.zeros_(module.bias)
2288
+ # And the potential buffers for the BatchNorms
2289
+ if getattr(module, "running_mean", None) is not None:
2290
+ init.zeros_(module.running_mean)
2291
+ init.ones_(module.running_var)
2292
+ init.zeros_(module.num_batches_tracked)
2293
+ # This matches all the usual RotaryEmbeddings modules
2294
+ elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
2295
+ rope_fn = (
2296
+ ROPE_INIT_FUNCTIONS[module.rope_type]
2297
+ if module.rope_type != "default"
2298
+ else module.compute_default_rope_parameters
2299
+ )
2300
+ buffer_value, _ = rope_fn(module.config)
2301
+ init.copy_(module.inv_freq, buffer_value)
2302
+ init.copy_(module.original_inv_freq, buffer_value)
2180
2303
 
2181
2304
  def _initialize_weights(self, module):
2182
2305
  """
@@ -2281,7 +2404,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2281
2404
 
2282
2405
  tied_mapping = self._tied_weights_keys
2283
2406
  # If the config does not specify any tying, return empty dict
2284
- if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
2407
+ if not self.config.tie_word_embeddings:
2285
2408
  return {}
2286
2409
  # If None, return empty dict
2287
2410
  elif tied_mapping is None:
@@ -2327,7 +2450,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2327
2450
 
2328
2451
  return expanded_tied_weights
2329
2452
 
2330
- def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
2453
+ def tie_weights(self, missing_keys: set[str] | None = None, recompute_mapping: bool = True):
2331
2454
  """
2332
2455
  Tie the model weights. If `recompute_mapping=False` (default when called internally), it will rely on the
2333
2456
  `model.all_tied_weights_keys` attribute, containing the `{target: source}` mapping for the tied params.
@@ -2347,30 +2470,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2347
2470
 
2348
2471
  tied_keys = list(tied_keys.items())
2349
2472
  for i, (target_param_name, source_param_name) in enumerate(tied_keys):
2350
- # Usually we tie a single target to a single source, but when both are missing we may later tie
2351
- # both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
2352
- # a list here
2353
- target_param_names = [target_param_name]
2354
-
2355
2473
  # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
2356
2474
  if missing_keys is not None:
2357
2475
  remove_from_missing = True
2358
2476
  source_is_there = source_param_name not in missing_keys
2359
2477
  target_is_there = target_param_name not in missing_keys
2360
2478
  # Both are already present -> it means the config is wrong and do not reflect the actual
2361
- # checkpoint -> let's raise a warning and do nothing
2479
+ # checkpoint -> let's raise a warning and NOT tie them
2362
2480
  if source_is_there and target_is_there:
2363
2481
  logger.warning(
2364
2482
  f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
2365
2483
  f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
2366
2484
  "You should update the config with `tie_word_embeddings=False` to silence this warning"
2367
2485
  )
2486
+ # Remove from internal attribute to correctly reflect actual tied weights
2487
+ self.all_tied_weights_keys.pop(target_param_name)
2368
2488
  # Skip to next iteration
2369
2489
  continue
2370
2490
  # We're missing the source but we have the target -> we swap them, tying the parameter that exists
2371
2491
  elif not source_is_there and target_is_there:
2372
2492
  target_param_name, source_param_name = source_param_name, target_param_name
2373
- target_param_names = [target_param_name]
2374
2493
  # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
2375
2494
  elif not source_is_there and not target_is_there:
2376
2495
  for target_backup, source_backup in tied_keys[i + 1 :]:
@@ -2379,10 +2498,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2379
2498
  if source_backup == source_param_name:
2380
2499
  target_backup_is_there = target_backup not in missing_keys
2381
2500
  # If the target is present, we found the correct weight to tie into (we know the source is missing)
2501
+ # Note that we do not also tie the missing source right now, as it will be done anyway when
2502
+ # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
2382
2503
  if target_backup_is_there:
2383
2504
  source_param_name = target_backup
2384
- # Append the source as well, since both are missing we'll tie both
2385
- target_param_names.append(source_param_name)
2386
2505
  break
2387
2506
  # If we did not break from the loop, it was impossible to find a source key -> let's raise
2388
2507
  else:
@@ -2398,19 +2517,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2398
2517
 
2399
2518
  # Perform the actual tying
2400
2519
  source_param = self.get_parameter_or_buffer(source_param_name)
2401
- for target_param_name in target_param_names:
2402
- if "." in target_param_name:
2403
- parent_name, name = target_param_name.rsplit(".", 1)
2404
- parent = self.get_submodule(parent_name)
2405
- else:
2406
- name = target_param_name
2407
- parent = self
2408
- # Tie the weights
2409
- setattr(parent, name, source_param)
2410
- self._adjust_bias(parent, source_param)
2411
- # Remove from missing if necesary
2412
- if missing_keys is not None and remove_from_missing:
2413
- missing_keys.discard(target_param_name)
2520
+ if "." in target_param_name:
2521
+ parent_name, name = target_param_name.rsplit(".", 1)
2522
+ parent = self.get_submodule(parent_name)
2523
+ else:
2524
+ name = target_param_name
2525
+ parent = self
2526
+ # Tie the weights
2527
+ setattr(parent, name, source_param)
2528
+ self._adjust_bias(parent, source_param)
2529
+ # Remove from missing if necesary
2530
+ if missing_keys is not None and remove_from_missing:
2531
+ missing_keys.discard(target_param_name)
2414
2532
 
2415
2533
  def _adjust_bias(self, output_embeddings, input_embeddings):
2416
2534
  if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
@@ -2455,8 +2573,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2455
2573
 
2456
2574
  def resize_token_embeddings(
2457
2575
  self,
2458
- new_num_tokens: Optional[int] = None,
2459
- pad_to_multiple_of: Optional[int] = None,
2576
+ new_num_tokens: int | None = None,
2577
+ pad_to_multiple_of: int | None = None,
2460
2578
  mean_resizing: bool = True,
2461
2579
  ) -> nn.Embedding:
2462
2580
  """
@@ -2557,8 +2675,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2557
2675
  def _get_resized_embeddings(
2558
2676
  self,
2559
2677
  old_embeddings: nn.Embedding,
2560
- new_num_tokens: Optional[int] = None,
2561
- pad_to_multiple_of: Optional[int] = None,
2678
+ new_num_tokens: int | None = None,
2679
+ pad_to_multiple_of: int | None = None,
2562
2680
  mean_resizing: bool = True,
2563
2681
  ) -> nn.Embedding:
2564
2682
  """
@@ -2715,7 +2833,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2715
2833
  def _get_resized_lm_head(
2716
2834
  self,
2717
2835
  old_lm_head: nn.Linear,
2718
- new_num_tokens: Optional[int] = None,
2836
+ new_num_tokens: int | None = None,
2719
2837
  transposed: bool = False,
2720
2838
  mean_resizing: bool = True,
2721
2839
  ) -> nn.Linear:
@@ -2912,7 +3030,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2912
3030
  f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
2913
3031
  )
2914
3032
 
2915
- def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
3033
+ def get_position_embeddings(self) -> nn.Embedding | tuple[nn.Embedding]:
2916
3034
  raise NotImplementedError(
2917
3035
  f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
2918
3036
  f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
@@ -2923,7 +3041,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2923
3041
  Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
2924
3042
  initialization logic in `_init_weights`.
2925
3043
  """
2926
- if _init_weights:
3044
+ # If we are initializing on meta device, there is no point in trying to run inits
3045
+ if get_torch_context_manager_or_global_device() != torch.device("meta"):
2927
3046
  # Initialize weights
2928
3047
  self.initialize_weights()
2929
3048
  # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
@@ -2961,7 +3080,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2961
3080
  "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
2962
3081
  )
2963
3082
 
2964
- if getattr(self, "_hf_peft_config_loaded", False):
3083
+ needs_embedding_grads = self.main_input_name == "input_ids"
3084
+ # we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
3085
+ enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
3086
+ if enable_input_grads:
2965
3087
  # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
2966
3088
  # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
2967
3089
  # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
@@ -3019,13 +3141,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3019
3141
 
3020
3142
  def save_pretrained(
3021
3143
  self,
3022
- save_directory: Union[str, os.PathLike],
3144
+ save_directory: str | os.PathLike,
3023
3145
  is_main_process: bool = True,
3024
- state_dict: Optional[dict] = None,
3146
+ state_dict: dict | None = None,
3025
3147
  push_to_hub: bool = False,
3026
- max_shard_size: Union[int, str] = "50GB",
3027
- variant: Optional[str] = None,
3028
- token: Optional[Union[str, bool]] = None,
3148
+ max_shard_size: int | str = "50GB",
3149
+ variant: str | None = None,
3150
+ token: str | bool | None = None,
3029
3151
  save_peft_format: bool = True,
3030
3152
  save_original_format: bool = True,
3031
3153
  **kwargs,
@@ -3092,12 +3214,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3092
3214
  " the logger on the traceback to understand the reason why the quantized model is not serializable."
3093
3215
  )
3094
3216
 
3095
- if "save_config" in kwargs:
3096
- warnings.warn(
3097
- "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
3098
- )
3099
- is_main_process = kwargs.pop("save_config")
3100
-
3101
3217
  # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
3102
3218
  if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
3103
3219
  raise ImportError(
@@ -3172,29 +3288,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3172
3288
  current_peft_config = self.peft_config[active_adapter]
3173
3289
  current_peft_config.save_pretrained(save_directory)
3174
3290
 
3175
- # for offloaded modules
3176
- module_map = {}
3177
-
3178
- # Save the model
3291
+ # Get the model state_dict
3179
3292
  if state_dict is None:
3180
- # if any model parameters are offloaded, make module map
3181
- if (
3182
- hasattr(self, "hf_device_map")
3183
- and len(set(self.hf_device_map.values())) > 1
3184
- and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
3185
- ):
3186
- warnings.warn(
3187
- "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
3188
- )
3189
- for name, module in model_to_save.named_modules():
3190
- if name == "":
3191
- continue
3192
- module_state_dict = module.state_dict()
3193
-
3194
- for key in module_state_dict:
3195
- module_map[name + f".{key}"] = module
3196
3293
  state_dict = model_to_save.state_dict()
3197
3294
 
3295
+ # if any model parameters are offloaded, we need to know it for later
3296
+ is_offloaded = False
3297
+ if (
3298
+ hasattr(self, "hf_device_map")
3299
+ and len(set(self.hf_device_map.values())) > 1
3300
+ and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
3301
+ ):
3302
+ is_offloaded = True
3303
+ warnings.warn(
3304
+ "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
3305
+ "exceeds the `shard_size` (50GB default)"
3306
+ )
3307
+
3198
3308
  # Translate state_dict from smp to hf if saving with smp >= 1.10
3199
3309
  if IS_SAGEMAKER_MP_POST_1_10:
3200
3310
  for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -3211,76 +3321,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  if self._tp_size is not None:
  state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

- # Safetensors does not allow tensor aliasing - we're going to remove aliases before saving
- ptrs = collections.defaultdict(list)
- for name, tensor in state_dict.items():
- if not isinstance(tensor, torch.Tensor):
- # Sometimes in the state_dict we have non-tensor objects.
- # e.g. in bitsandbytes we have some `str` objects in the state_dict
- # In the non-tensor case, fall back to the pointer of the object itself
- ptrs[id(tensor)].append(name)
-
- elif tensor.device.type == "meta":
- # In offloaded cases, there may be meta tensors in the state_dict.
- # For these cases, key by the pointer of the original tensor object
- # (state_dict tensors are detached and therefore no longer shared)
- tensor = self.get_parameter(name)
- ptrs[id(tensor)].append(name)
-
- else:
- ptrs[id_tensor_storage(tensor)].append(name)
-
- shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
- # Recursively descend to find tied weight keys
- _tied_weights_keys = set(_get_tied_weight_keys(self))
- error_names = []
- to_delete_names = set()
- for names in shared_ptrs.values():
- # Removing the keys which are declared as known duplicates on
- # load. This allows to make sure the name which is kept is consistent.
- if _tied_weights_keys is not None:
- found = 0
- for name in sorted(names):
- matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
- if matches_pattern and name in state_dict:
- found += 1
- if found < len(names):
- to_delete_names.add(name)
- # We are entering a place where the weights and the transformers configuration do NOT match.
- shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
- # Those are actually tensor sharing but disjoint from each other, we can safely clone them
- # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
- for name in disjoint_names:
- state_dict[name] = state_dict[name].clone()
-
- # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
- # If the link between tensors was done at runtime then `from_pretrained` will not get
- # the key back leading to random tensor. A proper warning will be shown
- # during reload (if applicable), but since the file is not necessarily compatible with
- # the config, better show a proper warning.
- shared_names, identical_names = _find_identical(shared_names, state_dict)
- # delete tensors that have identical storage
- for inames in identical_names:
- known = inames.intersection(to_delete_names)
- for name in known:
- del state_dict[name]
- unknown = inames.difference(to_delete_names)
- if len(unknown) > 1:
- error_names.append(unknown)
-
- if shared_names:
- error_names.extend(shared_names)
-
- if len(error_names) > 0:
- raise RuntimeError(
- f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
- "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
- )
+ # Remove tied weights as safetensors do not handle them
+ state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)

  # Revert all renaming and/or weight operations
  if save_original_format:
- state_dict = revert_weight_conversion(self, state_dict)
+ state_dict = revert_weight_conversion(model_to_save, state_dict)

  # Shard the model if it is too big.
  if not _hf_peft_config_loaded:
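The long inline dedup logic above is collapsed into a single call to `remove_tied_weights_from_state_dict`. As a hedged illustration only (the `drop_aliased_tensors` helper and its selection rule are assumptions, not the library's code): safetensors rejects aliased tensors, so names can be grouped by underlying storage pointer and only one name per storage kept before saving.

```python
import torch

def drop_aliased_tensors(
    state_dict: dict[str, torch.Tensor], tied_keys: set[str]
) -> dict[str, torch.Tensor]:
    """Keep a single name per underlying storage, dropping declared tied duplicates."""
    by_storage: dict[int, list[str]] = {}
    for name, tensor in state_dict.items():
        by_storage.setdefault(tensor.untyped_storage().data_ptr(), []).append(name)
    kept: dict[str, torch.Tensor] = {}
    for names in by_storage.values():
        # Prefer a name that is NOT declared tied, so the canonical weight is the one saved
        keep = next((n for n in sorted(names) if n not in tied_keys), sorted(names)[0])
        kept[keep] = state_dict[keep]
    return kept
```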
@@ -3320,47 +3366,39 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  and reg.fullmatch(filename_no_suffix) is not None
  ):
  os.remove(full_filename)
+
  # Save the model
- filename_to_tensors = state_dict_split.filename_to_tensors.items()
- if module_map:
- filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
- for shard_file, tensors in filename_to_tensors:
- shard = {}
- for tensor in tensors:
- if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
- full_tensor = state_dict[tensor].full_tensor()
+ for shard_file, tensor_names in logging.tqdm(
+ state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
+ ):
+ filename = os.path.join(save_directory, shard_file)
+ shard_state_dict = {}
+ for tensor_name in tensor_names:
+ # Get the tensor, and remove it from state_dict to avoid keeping the ref
+ tensor = state_dict.pop(tensor_name)
+
+ # In case of TP, get the full parameter back
+ if _is_dtensor_available and isinstance(tensor, DTensor):
+ tensor = tensor.full_tensor()
  # to get the correctly ordered tensor we need to repack if packed
- if _get_parameter_tp_plan(tensor, self._tp_plan) == "local_packed_rowwise":
- full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
- shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
- else:
- shard[tensor] = state_dict[tensor].contiguous()
- # delete reference, see https://github.com/huggingface/transformers/pull/34890
- del state_dict[tensor]
-
- # remake shard with onloaded parameters if necessary
- if module_map:
- # init state_dict for this shard
- shard_state_dict = dict.fromkeys(shard, "")
- for module_name in shard:
- # note that get_state_dict_from_offload can update with meta tensors
- # if both a parent module and its descendant are offloaded
- tensor = shard_state_dict[module_name]
- if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
- # update state dict with onloaded parameters
- module = module_map[module_name]
- shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
-
- # assign shard to be the completed state dict
- shard = shard_state_dict
- del shard_state_dict
- gc.collect()
-
- # TODO: we should def parallelize this we are otherwise just waiting
- # too much before scheduling the next write when its in a different file
- safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
-
- del state_dict
+ if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
+ tensor = repack_weights(tensor, -1, self._tp_size, 2)
+
+ # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
+ # but it would otherwise not be contained in the saved shard if we were to simply move the file
+ # or something
+ if is_offloaded and tensor.device.type == "meta":
+ tensor = load_offloaded_parameter(model_to_save, tensor_name)
+
+ # only do contiguous after it's permuted correctly in case of TP
+ shard_state_dict[tensor_name] = tensor.contiguous()
+
+ # TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
+ # so it's not possible for now....
+ # Write the shard to disk
+ safe_save_file(shard_state_dict, filename, metadata=metadata)
+ # Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
+ del shard_state_dict

  if index is None:
  path_to_weights = os.path.join(save_directory, weights_name)
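The new save loop pops each tensor out of the state dict, writes one shard at a time, and releases the shard before moving on. A rough standalone sketch of that pattern, assuming the `huggingface_hub` sharding helper and `safetensors` (the `save_sharded` wrapper itself is hypothetical, not the library's implementation):

```python
import os
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

def save_sharded(state_dict: dict[str, torch.Tensor], save_directory: str, max_shard_size: str = "5GB") -> None:
    os.makedirs(save_directory, exist_ok=True)
    split = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size)
    for shard_file, tensor_names in split.filename_to_tensors.items():
        # Pop tensors so the reference is dropped as soon as the shard is written
        shard = {name: state_dict.pop(name).contiguous() for name in tensor_names}
        save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
        del shard
```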
@@ -3537,19 +3575,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  return super().float(*args)

  @classmethod
- def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
+ def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
+ # Need to instantiate with correct dtype
+ init_contexts = [local_torch_dtype(dtype, cls.__name__)]
  if is_deepspeed_zero3_enabled():
  import deepspeed

- init_contexts = [no_init_weights()]
  # We cannot initialize the model on meta device with deepspeed when not quantized
  if not is_quantized and not _is_ds_init_called:
  logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
- init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
+ init_contexts.extend(
+ [
+ init.no_init_weights(),
+ deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
+ set_zero3_state(),
+ ]
+ )
  elif is_quantized:
- init_contexts.extend([init_empty_weights(), set_quantized_state()])
+ init_contexts.extend([torch.device("meta"), set_quantized_state()])
  else:
- init_contexts = [no_init_weights(), init_empty_weights()]
+ init_contexts.append(torch.device("meta"))

  return init_contexts
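`get_init_context` now builds the model on the meta device under the requested default dtype instead of relying on the removed `@restore_default_dtype` decorator; `local_torch_dtype` is internal to transformers. A minimal sketch of the same idea with plain PyTorch context managers (the `default_dtype` helper below is an assumption, not the library's code):

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def default_dtype(dtype: torch.dtype):
    # Temporarily switch the global default dtype, restoring it afterwards
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

# Parameters are created on the meta device (no memory is allocated) in bfloat16
with default_dtype(torch.bfloat16), torch.device("meta"):
    skeleton = nn.Linear(4096, 4096)
print(skeleton.weight.device, skeleton.weight.dtype)  # meta torch.bfloat16
```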
@@ -3574,7 +3619,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  # This is a context manager to override the default kernel mapping
  # We are calling kernelize inside this context manager using the use_kernels setter
- with use_kernel_mapping(kernel_config.kernel_mapping):
+ # Param inherit_mapping should be False to avoid still loading kernel from remote
+ inherit_mapping = not kernel_config.use_local_kernel
+ with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
  self.use_kernels = True
  # We use the default kernel mapping in .integrations.hub_kernels
  else:
@@ -3583,19 +3630,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  self.use_kernels = False

  @classmethod
- @restore_default_dtype
  def from_pretrained(
  cls: type[SpecificPreTrainedModelType],
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ pretrained_model_name_or_path: str | os.PathLike | None,
  *model_args,
- config: Optional[Union[PreTrainedConfig, str, os.PathLike]] = None,
- cache_dir: Optional[Union[str, os.PathLike]] = None,
+ config: PreTrainedConfig | str | os.PathLike | None = None,
+ cache_dir: str | os.PathLike | None = None,
  ignore_mismatched_sizes: bool = False,
  force_download: bool = False,
  local_files_only: bool = False,
- token: Optional[Union[str, bool]] = None,
+ token: str | bool | None = None,
  revision: str = "main",
- use_safetensors: Optional[bool] = True,
+ use_safetensors: bool | None = True,
  weights_only: bool = True,
  **kwargs,
  ) -> SpecificPreTrainedModelType:
@@ -3692,10 +3738,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  "org/model@main"
  "org/model:custom_kernel"
  "org/model@v1.2.3:custom_kernel"
+ experts_implementation (`str`, *optional*):
+ The experts implementation to use in the model (if relevant). Can be any of:
+
+ - `"eager"` (sequential implementation of the experts matrix multiplications).
+ - `"batched_mm"` (using [`torch.bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html)).
+ - `"grouped_mm"` (using [`torch._grouped_mm`](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.grouped_mm.html)).
+
+ By default, if available, `grouped_mm` will be used for torch>=2.9.0. The default is otherwise the sequential `"eager"` implementation.

  > Parameters for big model inference

- dtype (`str` or `torch.dtype`, *optional*):
+ dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
  Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
  are:

@@ -3915,8 +3969,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  if "attn_implementation" in kwargs:
  config._attn_implementation = kwargs.pop("attn_implementation")

- hf_quantizer, config, dtype, device_map = get_hf_quantizer(
- config, quantization_config, dtype, device_map, weights_only, user_agent
+ if "experts_implementation" in kwargs:
+ config._experts_implementation = kwargs.pop("experts_implementation")
+
+ hf_quantizer, config, device_map = get_hf_quantizer(
+ config, quantization_config, device_map, weights_only, user_agent
  )

  if gguf_file:
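The new `experts_implementation` kwarg documented earlier in this hunk is simply popped from `kwargs` and stored on the config. A hedged usage example (the repo id `org/moe-model` is a placeholder, not a real checkpoint):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "org/moe-model",                      # placeholder repo id
    dtype="auto",                         # the documented default in this release
    experts_implementation="grouped_mm",  # or "eager" / "batched_mm", per the docstring above
)
```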
@@ -3963,33 +4020,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  ]

  # Find the correct dtype based on current state
- config, dtype, dtype_orig = _get_dtype(
- cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
+ config, dtype = _get_dtype(
+ dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
  )

  config.name_or_path = pretrained_model_name_or_path
- model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
+ model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
  config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
  with ContextManagers(model_init_context):
  # Let's make sure we don't run the init function of buffer modules
  model = cls(config, *model_args, **model_kwargs)

+ if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
+ hf_quantizer.preprocess_model(
+ model=model,
+ dtype=dtype,
+ device_map=device_map,
+ checkpoint_files=checkpoint_files,
+ use_kernels=use_kernels,
+ )
+
  # Obtain the weight conversion mapping for this model if any are registered
  weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)

- # make sure we use the model's config since the __init__ call might have copied it
- config = model.config
-
- if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
- hf_quantizer.preprocess_model(
- model=model,
- device_map=device_map,
- keep_in_fp32_modules=model._keep_in_fp32_modules, # TODO prob no longer needed?
- config=config,
- checkpoint_files=checkpoint_files,
- use_kernels=use_kernels,
- )
-
  if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
  model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)

@@ -3997,10 +4050,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  if device_map is not None:
  device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)

- # restore default dtype
- if dtype_orig is not None:
- torch.set_default_dtype(dtype_orig)
-
  # Finalize model weight initialization
  model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
  model,
@@ -4011,6 +4060,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  sharded_metadata=sharded_metadata,
  device_map=device_map,
  disk_offload_folder=offload_folder,
+ offload_buffers=offload_buffers,
  dtype=dtype,
  hf_quantizer=hf_quantizer,
  device_mesh=device_mesh,
@@ -4018,7 +4068,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  weight_mapping=weight_conversions,
  )

- model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
  model.set_use_kernels(use_kernels, kernel_config)

  # If it is a model with generation capabilities, attempt to load generation files (generation config,
@@ -4034,13 +4084,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  **kwargs,
  )

- # for device_map="auto" : dispatch model with hooks on all devices if necessary
- if device_map is not None and device_mesh is None:
+ # If the device_map has more than 1 device: dispatch model with hooks on all devices
+ if device_map is not None and len(set(device_map.values())) > 1:
  accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)

  if hf_quantizer is not None:
  model.hf_quantizer = hf_quantizer
- hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
+ hf_quantizer.postprocess_model(
+ model
+ ) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing

  if _adapter_model_path is not None:
  adapter_kwargs["key_mapping"] = key_mapping
@@ -4065,18 +4117,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  def _load_pretrained_model(
  cls,
  model: "PreTrainedModel",
- state_dict: Optional[dict],
- checkpoint_files: Optional[list[str]],
- pretrained_model_name_or_path: Optional[str],
+ state_dict: dict | None,
+ checkpoint_files: list[str] | None,
+ pretrained_model_name_or_path: str | None,
  ignore_mismatched_sizes: bool = False,
- sharded_metadata: Optional[dict] = None,
- device_map: Optional[dict] = None,
- disk_offload_folder: Optional[str] = None,
- dtype: Optional[torch.dtype] = None,
- hf_quantizer: Optional[HfQuantizer] = None,
+ sharded_metadata: dict | None = None,
+ device_map: dict | None = None,
+ disk_offload_folder: str | None = None,
+ offload_buffers: bool = False,
+ dtype: torch.dtype | None = None,
+ hf_quantizer: HfQuantizer | None = None,
  device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
  weights_only: bool = True,
- weight_mapping: Optional[Sequence[WeightConverter | WeightRenaming]] = None,
+ weight_mapping: Sequence[WeightConverter | WeightRenaming] | None = None,
  ):
  is_quantized = hf_quantizer is not None
  is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
@@ -4086,6 +4139,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  # Model's definition arriving here is final (TP hooks added, quantized layers replaces)
  expected_keys = list(model.state_dict().keys())
+
  if logger.level >= logging.WARNING:
  verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))

@@ -4108,7 +4162,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  expanded_device_map = expand_device_map(device_map, expected_keys)
  caching_allocator_warmup(model, expanded_device_map, hf_quantizer)

- tp_plan = getattr(model, "_tp_plan", None)
  error_msgs = []

  if is_deepspeed_zero3_enabled() and not is_quantized:
@@ -4117,9 +4170,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  for ckpt_file in checkpoint_files:
  merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
  state_dict = merged_state_dict
- error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
+ error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
  # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
- missing_keys, unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set(), set()
+ unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set()
  else:
  all_pointer = set()
  # Checkpoints are safetensors
@@ -4143,17 +4196,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
  convert_and_load_state_dict_in_model(
- model,
- merged_state_dict,
- weight_mapping,
- tp_plan,
- hf_quantizer,
- dtype,
- device_map,
- model.dtype_plan,
- device_mesh,
- disk_offload_index,
- disk_offload_folder,
+ model=model,
+ state_dict=merged_state_dict,
+ weight_mapping=weight_mapping,
+ tp_plan=model._tp_plan,
+ hf_quantizer=hf_quantizer,
+ dtype=dtype,
+ device_map=device_map,
+ dtype_plan=model.dtype_plan,
+ device_mesh=device_mesh,
+ disk_offload_index=disk_offload_index,
+ disk_offload_folder=disk_offload_folder,
+ offload_buffers=offload_buffers,
  )
  )

@@ -4164,12 +4218,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
  model.mark_tied_weights_as_initialized()

- # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
- # loading the weights as they are not in the loaded state dict)
- miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
- model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer)
+ # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
+ # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
+ missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
+ model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)

- # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialzed` flag)
+ # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
  model._initialize_missing_keys(is_quantized)

  # Tie the weights
@@ -4178,34 +4232,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  # Adjust missing and unexpected keys
  missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)

- # Post-processing for tensor parallelism
- if device_mesh is not None:
- # When using TP, the device map is a single device for all parameters
- tp_device = list(device_map.values())[0]
- # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
- # not part of the state_dict (persistent=False)
- for buffer in model.buffers(): # TODO to avoid this buffer could be added to the ckpt
- if buffer.device != tp_device:
- buffer.data = buffer.to(tp_device)
-
- # In this case, the top-most task module weights were not moved to device and parallelized as they
- # were not part of the loaded weights: do it now
- if missing_keys:
- state_dict = model.state_dict()
- for name in missing_keys:
- param = state_dict[name]
- # Shard the param
- shard_and_distribute_module(
- model,
- param.to(tp_device),
- param,
- name,
- None,
- False,
- device_mesh.get_local_rank(),
- device_mesh,
- )
-
  log_state_dict_report(
  model=model,
  pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -4381,7 +4407,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  )
  self._use_kernels = False

- def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
+ def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable:
  """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
  non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
  want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -4403,33 +4429,54 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  def is_backend_compatible(cls):
  return cls._supports_attention_backend

- def _move_missing_keys_from_meta_to_cpu(
- self, missing_keys: list[str], dtype: torch.dtype, hf_quantizer: Optional[HfQuantizer]
+ def _move_missing_keys_from_meta_to_device(
+ self,
+ missing_keys: list[str],
+ device_map: dict | None,
+ device_mesh: "torch.distributed.device_mesh.DeviceMesh | None",
+ hf_quantizer: HfQuantizer | None,
  ) -> None:
- """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back
- from meta device to cpu.
+ """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
+ back from meta device to their device according to the `device_map` if any, else cpu. Takes care of sharding those
+ missing parameters if `device_mesh` is provided, i.e. we are using TP.
+ All non-persistent buffers are also moved back to the correct device (they are not part of the state_dict, but are
+ not missing either).
  """
  is_quantized = hf_quantizer is not None
+ # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
+ if is_deepspeed_zero3_enabled() and not is_quantized:
+ return

  # In this case we need to move everything back
  if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
- # We only do it for the parameters, as the buffers are not initialized on the meta device by default
  for key, param in self.named_parameters():
- value = torch.empty_like(param, dtype=dtype, device="cpu")
+ value = torch.empty_like(param, device="cpu")
+ _load_parameter_into_model(self, key, value)
+ for key, buffer in self.named_buffers():
+ value = torch.empty_like(buffer, device="cpu")
  _load_parameter_into_model(self, key, value)
  return

- model_state_dict = self.state_dict()
  # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
  # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
  # will be re-initialized for nothing (which can be quite long)
  for key in missing_keys - self.all_tied_weights_keys.keys():
- param = model_state_dict[key]
- # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
- if param.device == torch.device("meta"):
- value = torch.empty_like(param, dtype=dtype, device="cpu")
- if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
- _load_parameter_into_model(self, key, value)
+ param = self.get_parameter_or_buffer(key)
+ param_device = get_device(device_map, key, valid_torch_device=True)
+ value = torch.empty_like(param, device=param_device)
+ # For TP, we may need to shard the param
+ if device_mesh is not None:
+ shard_and_distribute_module(
+ self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh
+ )
+ # Otherwise, just move it to device
+ else:
+ _load_parameter_into_model(self, key, value)
+ # We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway
+ for key, buffer in self.named_non_persistent_buffers():
+ buffer_device = get_device(device_map, key, valid_torch_device=True)
+ value = torch.empty_like(buffer, device=buffer_device)
+ _load_parameter_into_model(self, key, value)

  def _initialize_missing_keys(self, is_quantized: bool) -> None:
  """
@@ -4457,8 +4504,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  ) -> tuple[set[str], set[str]]:
  """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
  raising unneeded warnings/errors.
- Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to
- be tied anyway.
  """
  # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
  # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4517,6 +4562,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")

+ def named_non_persistent_buffers(
+ self, recurse: bool = True, remove_duplicate: bool = True
+ ) -> Iterator[tuple[str, torch.Tensor]]:
+ """Similar to `named_buffers`, but only yield non-persistent ones. It is handy as it's not perfectly straightforward
+ to know if they are persistent or not"""
+ for name, tensor in self.named_buffers(recurse=recurse, remove_duplicate=remove_duplicate):
+ # We have to grab the parent here, as the attribute `_non_persistent_buffers_set` is on the immediate
+ # parent only
+ parent, buf_name = name.rsplit(".", 1) if "." in name else ("", name)
+ parent = self.get_submodule(parent)
+ if buf_name in parent._non_persistent_buffers_set:
+ yield name, tensor
+
  def train(self, mode: bool = True):
  out = super().train(mode)
  if self.use_kernels:
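The added `named_non_persistent_buffers` helper leans on PyTorch's per-module `_non_persistent_buffers_set` bookkeeping. A short standalone example of that mechanism (the `WithBuffers` module is illustrative only):

```python
import torch
import torch.nn as nn

class WithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(4))               # persistent: saved in the state_dict
        self.register_buffer("inv_freq", torch.ones(4), persistent=False)  # non-persistent: recreated at load time

module = WithBuffers()
print("inv_freq" in module.state_dict())   # False
print(module._non_persistent_buffers_set)  # {'inv_freq'}
```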
@@ -4559,7 +4617,7 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
  return model


- def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+ def is_accelerator_device(device: str | int | torch.device) -> bool:
  """Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
  a proper `torch.device`.
  """
@@ -4569,7 +4627,41 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
  return torch.device(device).type not in ["meta", "cpu"]


- def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
+ def get_total_byte_count(
+ model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: HfQuantizer | None = None
+ ):
+ """
+ This utility function calculates the total bytes count needed to load the model on each device.
+ This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
+ """
+
+ total_byte_count = defaultdict(lambda: 0)
+ tied_param_names = model.all_tied_weights_keys.keys()
+ tp_plan = model._tp_plan if torch.distributed.is_available() and torch.distributed.is_initialized() else []
+
+ for param_name, device in accelerator_device_map.items():
+ # Skip if the parameter has already been accounted for (tied weights)
+ if param_name in tied_param_names:
+ continue
+
+ param = model.get_parameter_or_buffer(param_name)
+
+ if hf_quantizer is not None:
+ dtype_size = hf_quantizer.param_element_size(model, param_name, param)
+ else:
+ dtype_size = param.element_size()
+
+ param_byte_count = param.numel() * dtype_size
+
+ if len(tp_plan) > 0:
+ is_part_of_plan = _get_parameter_tp_plan(param_name, tp_plan, is_weight=True) is not None
+ param_byte_count //= torch.distributed.get_world_size() if is_part_of_plan else 1
+
+ total_byte_count[device] += param_byte_count
+ return total_byte_count
+
+
+ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: HfQuantizer | None):
  """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
  device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
  the model, which is actually the loading speed bottleneck.
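The byte accounting in the new `get_total_byte_count` boils down to `numel() * element_size()` summed per device, with tied weights counted once. A toy illustration of that arithmetic using meta tensors (the shapes and device map are made up, and no real memory is allocated):

```python
from collections import defaultdict
import torch

params = {
    "embed.weight": torch.empty(32000, 4096, dtype=torch.bfloat16, device="meta"),
    "lm_head.weight": torch.empty(32000, 4096, dtype=torch.bfloat16, device="meta"),  # tied to embed.weight
    "layers.0.mlp.weight": torch.empty(4096, 11008, dtype=torch.float32, device="meta"),
}
device_map = {"embed.weight": "cuda:0", "lm_head.weight": "cuda:0", "layers.0.mlp.weight": "cuda:1"}
tied = {"lm_head.weight"}

total = defaultdict(int)
for name, device in device_map.items():
    if name in tied:
        continue  # tied weights share storage, count them once
    p = params[name]
    total[device] += p.numel() * p.element_size()  # bytes = element count * bytes per element

print(dict(total))  # {'cuda:0': 262144000, 'cuda:1': 180355072}
```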
@@ -4588,8 +4680,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
  However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
  """
- factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()
-
  # Remove disk, cpu and meta devices, and cast to proper torch.device
  accelerator_device_map = {
  param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
@@ -4597,40 +4687,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  if not accelerator_device_map:
  return

- tp_plan = getattr(model, "_tp_plan", []) or []
- tp_plan_regex = (
- re.compile("|".join([re.escape(plan) for plan in tp_plan]))
- if _torch_distributed_available and torch.distributed.is_initialized()
- else None
- )
- total_byte_count = defaultdict(lambda: 0)
- tied_param_names = model.all_tied_weights_keys.keys()
- for param_name, device in accelerator_device_map.items():
- # Skip if the parameter has already been accounted for (tied weights)
- if param_name in tied_param_names:
- continue
-
- # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
- # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
- if hf_quantizer is not None:
- param_name = hf_quantizer.get_param_name(param_name)
-
- try:
- param = model.get_parameter_or_buffer(param_name)
- except AttributeError:
- # TODO: for now let's skip if we can't find the parameters
- if hf_quantizer is not None:
- continue
- raise AttributeError(f"Parameter {param_name} not found in model")
-
- # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
- param_byte_count = param.numel() * param.element_size()
-
- if tp_plan_regex is not None:
- generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
- param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
-
- total_byte_count[device] += param_byte_count
+ total_byte_count = get_total_byte_count(model, accelerator_device_map, hf_quantizer)

  # This will kick off the caching allocator to avoid having to Malloc afterwards
  for device, byte_count in total_byte_count.items():
@@ -4650,9 +4707,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  unused_memory = torch_accelerator_module.memory_reserved(
  index
  ) - torch_accelerator_module.memory_allocated(index)
- byte_count = max(0, byte_count - unused_memory)
- # Allocate memory
- _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
+ byte_count = int(max(0, byte_count - unused_memory))
+ # We divide by 2 here as we allocate in fp16
+ _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)

 
4657
4714
 
4658
4715
  class AttentionInterface(GeneralInterface):