transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835):
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -19,14 +19,21 @@ from .base import HfQuantizer
19
19
  if TYPE_CHECKING:
20
20
  from ..modeling_utils import PreTrainedModel
21
21
 
22
- from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
22
+ from ..utils import (
23
+ is_accelerate_available,
24
+ is_fbgemm_gpu_available,
25
+ is_kernels_available,
26
+ is_torch_available,
27
+ is_torch_cuda_available,
28
+ is_torch_xpu_available,
29
+ logging,
30
+ )
23
31
  from .quantizers_utils import get_module_from_name
24
32
 
25
33
 
26
34
  if is_torch_available():
27
35
  import torch
28
36
 
29
-
30
37
  logger = logging.get_logger(__name__)
31
38
 
32
39
 
@@ -35,54 +42,41 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
35
42
  FP8 quantization using fbgemm kernels
36
43
  """
37
44
 
38
- requires_parameters_quantization = True
39
45
  requires_calibration = False
40
46
 
41
- required_packages = ["fbgemm-gpu", "accelerate"]
42
-
43
47
  def __init__(self, quantization_config, **kwargs):
44
48
  super().__init__(quantization_config, **kwargs)
45
- self.quantization_config = quantization_config
46
49
 
47
50
  def validate_environment(self, *args, **kwargs):
48
- if not is_torch_available():
49
- raise ImportError(
50
- "Using fbgemm fp8 quantization requires torch >= 2.1.0"
51
- "Please install the latest version of torch ( pip install --upgrade torch )"
52
- )
53
- if not is_fbgemm_gpu_available():
51
+ if not is_torch_cuda_available() and not is_torch_xpu_available():
52
+ raise ImportError("Using fbgemm fp8 quantization requires a GPU or XPU")
53
+ if is_torch_xpu_available() and not is_kernels_available():
54
+ raise ImportError("Using FP8 fbgemm on XPU requires kernels (`pip install kernels`)")
55
+ if is_torch_cuda_available() and not is_fbgemm_gpu_available():
54
56
  raise ImportError(
55
- "Using fbgemm fp8 quantization requires fbgemm-gpu library"
57
+ "Loading an FP8 fbgemm quantized model on CUDA requires fbgemm-gpu library"
56
58
  "Please install the latest version of fbgemm-gpu library by following : https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
57
59
  )
58
-
59
60
  if not is_accelerate_available():
60
61
  raise ImportError(
61
62
  "Loading an FP8 quantized model requires accelerate (`pip install --upgrade accelerate`)"
62
63
  )
63
-
64
- if not torch.cuda.is_available():
65
- raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU")
66
-
67
- compute_capability = torch.cuda.get_device_capability()
68
- major, minor = compute_capability
69
- if major < 9:
70
- raise ValueError(
71
- "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
72
- )
64
+ if is_torch_cuda_available():
65
+ compute_capability = torch.cuda.get_device_capability()
66
+ major, _ = compute_capability
67
+ if major < 9:
68
+ raise ValueError(
69
+ "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
70
+ )
73
71
 
74
72
  device_map = kwargs.get("device_map")
75
73
  if device_map is None:
76
74
  logger.warning_once(
77
- "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
78
- "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
75
+ "You have loaded an FP8 model on CPU and have a CUDA/XPU device available, make sure to set "
76
+ "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or 'xpu' or 'auto'. "
79
77
  )
80
- elif device_map is not None:
81
- if (
82
- not self.pre_quantized
83
- and isinstance(device_map, dict)
84
- and ("cpu" in device_map.values() or "disk" in device_map.values())
85
- ):
78
+ elif isinstance(device_map, dict):
79
+ if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
86
80
  raise ValueError(
87
81
  "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device."
88
82
  "This is not supported when the model is quantized on the fly. "
@@ -90,19 +84,11 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
90
84
  )
91
85
 
92
86
  def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
93
- if dtype is None:
94
- dtype = torch.bfloat16
95
- logger.info(
96
- "Overriding dtype=%s with `dtype=torch.bloat16` due to "
97
- "requirements of `fbgemm-gpu` to enable model loading in fp8. "
98
- "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
99
- " dtype=torch.bfloat16 to remove this warning.",
100
- dtype,
101
- )
102
- elif dtype == torch.float16:
103
- raise ValueError(
104
- "You cannot use FP8 with dtype=torch.float16.We recommend you passing dtype=torch.bfloat16"
87
+ if dtype != torch.bfloat16:
88
+ logger.warning_once(
89
+ f"Setting dtype to {dtype}, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
105
90
  )
91
+ dtype = torch.bfloat16
106
92
  return dtype
107
93
 
108
94
  def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
@@ -122,116 +108,25 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
122
108
  return True
123
109
  return False
124
110
 
125
- def create_quantized_param(
126
- self,
127
- model: "PreTrainedModel",
128
- param_value: "torch.Tensor",
129
- param_name: str,
130
- target_device: "torch.device",
131
- **kwargs,
132
- ):
133
- from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
134
-
135
- module, tensor_name = get_module_from_name(model, param_name)
136
-
137
- # Sanity checks
138
- if isinstance(module, FbgemmFp8Linear):
139
- if self.pre_quantized or tensor_name == "bias":
140
- if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
141
- raise ValueError("Expect quantized weights but got an unquantized weight")
142
- else:
143
- if tensor_name == "weight_scale":
144
- raise ValueError("Expect unquantized weights but got a quantized weight_scale")
145
- if isinstance(module, FbgemmFp8Llama4TextExperts):
146
- if not (self.pre_quantized or tensor_name == "bias"):
147
- if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
148
- raise ValueError("Expect unquantized weights but got a quantized weight_scale")
149
-
150
- if isinstance(module, FbgemmFp8Llama4TextExperts):
151
- if tensor_name == "gate_up_proj":
152
- # Process each expert separately
153
- # Transpose the second and third dimension
154
- transposed_param = param_value.transpose(1, 2)
155
-
156
- # Reshape to 2D for quantization
157
- original_shape = transposed_param.shape
158
- flattened_param = transposed_param.reshape(-1, original_shape[-1])
159
-
160
- # Quantize using per row instead of per column
161
- new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
162
-
163
- # Reshape back to original dimensions
164
- new_value = new_value_flat.reshape(original_shape)
165
- new_value = new_value.transpose(1, 2)
166
- weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
167
- elif tensor_name == "down_proj":
168
- # Process each expert separately
169
- # Transpose the weights for proper quantization
170
- transposed_param = param_value.transpose(1, 2)
171
-
172
- # Reshape to 2D for quantization
173
- original_shape = transposed_param.shape
174
- flattened_param = transposed_param.reshape(-1, original_shape[-1])
175
-
176
- # Quantize using per column
177
- new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
178
-
179
- # Reshape back to original dimensions
180
- new_value = new_value_flat.reshape(original_shape)
181
- new_value = new_value.transpose(1, 2)
182
- weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
183
-
184
- module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
185
- else:
186
- new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
187
- module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
188
- weight_scale.view(weight_scale.shape[0], 1).to(target_device)
189
- )
190
-
191
- module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))
192
-
193
- del param_name
194
-
195
111
  def _process_model_before_weight_loading(
196
112
  self,
197
113
  model: "PreTrainedModel",
198
- keep_in_fp32_modules: list[str] | None = None,
199
114
  **kwargs,
200
115
  ):
201
116
  from ..integrations import replace_with_fbgemm_fp8_linear
202
117
 
203
- tp_plan = model._tp_plan
204
118
  self.modules_to_not_convert = self.get_modules_to_not_convert(
205
- model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
119
+ model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
206
120
  )
207
121
 
208
- config = model.config
209
122
  model = replace_with_fbgemm_fp8_linear(
210
123
  model,
211
124
  modules_to_not_convert=self.modules_to_not_convert,
212
125
  quantization_config=self.quantization_config,
213
126
  pre_quantized=self.pre_quantized,
214
- config=config,
215
- tp_plan=tp_plan,
127
+ tp_plan=model._tp_plan,
216
128
  )
217
129
 
218
- model.config.quantization_config = self.quantization_config
219
-
220
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
221
- from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
222
-
223
- not_missing_keys = []
224
- for name, module in model.named_modules():
225
- if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):
226
- for missing in missing_keys:
227
- if (
228
- (name in missing or name in f"{prefix}.{missing}")
229
- and not missing.endswith(".weight")
230
- and not missing.endswith(".bias")
231
- ):
232
- not_missing_keys.append(missing)
233
- return [k for k in missing_keys if k not in not_missing_keys]
234
-
235
130
  def update_tp_plan(self, config):
236
131
  if "Llama4" in config.__class__.__name__:
237
132
  text_plan = {
@@ -279,9 +174,14 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
279
174
 
280
175
  return config
281
176
 
282
- def is_serializable(self, safe_serialization=None):
177
+ def is_serializable(self):
283
178
  return True
284
179
 
285
180
  @property
286
181
  def is_trainable(self) -> bool:
287
182
  return False
183
+
184
+ def get_quantize_ops(self):
185
+ from ..integrations.fbgemm_fp8 import FbgemmFp8Quantize
186
+
187
+ return FbgemmFp8Quantize(self)
@@ -20,25 +20,19 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
20
20
  Supports both e4m3fn formats based on platform.
21
21
  """
22
22
 
23
- requires_parameters_quantization = True
24
23
  requires_calibration = False
25
- required_packages = ["accelerate"]
26
24
 
27
25
  def __init__(self, quantization_config, **kwargs):
28
26
  super().__init__(quantization_config, **kwargs)
29
- self.quantization_config = quantization_config
30
27
 
31
28
  def validate_environment(self, *args, **kwargs):
32
- if not is_torch_available():
33
- raise ImportError(
34
- "Using fp8 quantization requires torch >= 2.1.0"
35
- "Please install the latest version of torch ( pip install --upgrade torch )"
36
- )
37
-
38
29
  if not is_accelerate_available():
39
30
  raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")
40
31
 
41
- if (not (torch.cuda.is_available() or is_torch_xpu_available())) and not self.quantization_config.dequantize:
32
+ if self.quantization_config.dequantize:
33
+ return
34
+
35
+ if not torch.cuda.is_available() and not is_torch_xpu_available():
42
36
  if self.pre_quantized:
43
37
  logger.warning_once(
44
38
  "Using FP8 quantized models requires a GPU or XPU, we will default to dequantizing the model to bf16 since no GPU or XPU is available"
@@ -52,10 +46,13 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
52
46
  compute_capability = torch.cuda.get_device_capability()
53
47
  major, minor = compute_capability
54
48
  if (major < 8) or (major == 8 and minor < 9):
55
- raise ValueError(
49
+ logger.warning_once(
56
50
  "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
57
- f", actual = `{major}.{minor}`"
51
+ f", actual = `{major}.{minor}`. We will default to dequantizing the model to bf16. Feel free "
52
+ f"to use a different quantization method like bitsandbytes or torchao"
58
53
  )
54
+ self.quantization_config.dequantize = True
55
+ return
59
56
 
60
57
  device_map = kwargs.get("device_map")
61
58
  if device_map is None:
@@ -64,11 +61,12 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
64
61
  "your model on a GPU or XPU device in order to run your model. To remove this warning, "
65
62
  "pass device_map = 'cuda' or 'xpu'. "
66
63
  )
67
- elif device_map is not None:
64
+ elif isinstance(device_map, dict):
68
65
  if (
69
66
  not self.pre_quantized
70
- and isinstance(device_map, dict)
71
- and ("cpu" in device_map.values() or "disk" in device_map.values())
67
+ and len(device_map) > 1
68
+ and "cpu" in device_map.values()
69
+ or "disk" in device_map.values()
72
70
  ):
73
71
  raise ValueError(
74
72
  "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
@@ -76,76 +74,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
76
74
  "Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
77
75
  )
78
76
 
79
- def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
80
- if dtype is None:
81
- logger.info("Setting dtype to torch.float32 as no dtype was specified in from_pretrained")
82
- dtype = torch.float32
83
- return dtype
84
-
85
- # TODO: make this into a `ConversionType` ops -> potentially requires all weights on all ranks
86
- # depending on the layer type (moe -> no if ep)
87
- def create_quantized_param(
88
- self,
89
- model: "PreTrainedModel",
90
- param_value: "torch.Tensor",
91
- param_name: str,
92
- target_device: "torch.device",
93
- **kwargs,
94
- ):
95
- from ..integrations.finegrained_fp8 import FP8Linear
96
- from ..modeling_utils import _load_parameter_into_model
97
-
98
- # Sanity checks
99
- module, tensor_name = get_module_from_name(model, param_name)
100
- if isinstance(module, FP8Linear):
101
- if self.pre_quantized or tensor_name == "bias":
102
- if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
103
- raise ValueError("Expect quantized weights but got an unquantized weight")
104
- else:
105
- return
106
- # if tensor_name == "weight_scale_inv":
107
- # raise ValueError("Expect unquantized weights but got a quantized weight_scale")
108
-
109
- param_value = param_value.to(target_device)
110
-
111
- # Get FP8 min/max values
112
- fp8_min = torch.finfo(torch.float8_e4m3fn).min
113
- fp8_max = torch.finfo(torch.float8_e4m3fn).max
114
-
115
- block_size_m, block_size_n = self.quantization_config.weight_block_size
116
-
117
- rows, cols = param_value.shape[-2:]
118
-
119
- if rows % block_size_m != 0 or cols % block_size_n != 0:
120
- raise ValueError(
121
- f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
122
- )
123
- param_value_orig_shape = param_value.shape
124
-
125
- param_value = param_value.reshape(
126
- -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
127
- ).permute(0, 1, 3, 2, 4)
128
-
129
- # Calculate scaling factor for each block
130
- max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
131
- scale = fp8_max / max_abs
132
- scale_orig_shape = scale.shape
133
- scale = scale.unsqueeze(-1).unsqueeze(-1)
134
-
135
- # Quantize the weights
136
- quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
137
-
138
- quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
139
- # Reshape back to matrix shape
140
- quantized_param = quantized_param.reshape(param_value_orig_shape)
141
-
142
- # Reshape scale to match the number of blocks
143
- scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
144
-
145
- # Load into the model
146
- _load_parameter_into_model(model, param_name, quantized_param)
147
- _load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale)
148
-
149
77
  def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
150
78
  from ..integrations.finegrained_fp8 import FP8Expert, FP8Linear
151
79
 
@@ -157,46 +85,34 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
157
85
  return True
158
86
  return False
159
87
 
88
+ def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
89
+ "Return the element size (in bytes) for `param_name`."
90
+ if self.param_needs_quantization(model, param_name):
91
+ # 8 bit, this is neeed as when `pre_quantized`` is False, we don't set the dtype of the FP8Linear in order to correctly load the weights
92
+ return 1
93
+ return super().param_element_size(model, param_name, param)
94
+
160
95
  def _process_model_before_weight_loading(
161
96
  self,
162
97
  model: "PreTrainedModel",
163
- keep_in_fp32_modules: list[str] | None = None,
164
98
  **kwargs,
165
99
  ):
166
100
  from ..integrations.finegrained_fp8 import replace_with_fp8_linear
167
101
 
168
- # takes 2 fucking seconds
169
102
  self.modules_to_not_convert = self.get_modules_to_not_convert(
170
- model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
103
+ model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
171
104
  )
172
105
 
173
- # while this one is 81ms :)
174
106
  model = replace_with_fp8_linear(
175
107
  model,
176
108
  modules_to_not_convert=self.modules_to_not_convert,
177
109
  quantization_config=self.quantization_config,
110
+ pre_quantized=self.pre_quantized,
178
111
  )
179
112
 
180
- model.config.quantization_config = self.quantization_config
181
-
182
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
183
- from ..integrations import FP8Linear
184
-
185
- not_missing_keys = []
186
- for name, module in model.named_modules():
187
- if isinstance(module, FP8Linear):
188
- for missing in missing_keys:
189
- if (
190
- (name in missing or name in f"{prefix}.{missing}")
191
- and not missing.endswith(".weight")
192
- and not missing.endswith(".bias")
193
- ):
194
- not_missing_keys.append(missing)
195
- return [k for k in missing_keys if k not in not_missing_keys]
196
-
197
113
  # NOTE: TP is applied before quantization so this is only to add hooks.
198
114
  # Quantization is incompatible with DTensors, so we have to anyway have
199
- # gathers! But it should be model independant -> figure out where to put
115
+ # gathers! But it should be model independent -> figure out where to put
200
116
  # the gather and that's it.
201
117
  def update_tp_plan(self, config):
202
118
  if "Qwen3" in config.__class__.__name__:
@@ -223,17 +139,13 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
223
139
 
224
140
  return config
225
141
 
226
- def is_serializable(self, safe_serialization=None):
142
+ def is_serializable(self):
227
143
  return True
228
144
 
229
145
  @property
230
146
  def is_trainable(self) -> bool:
231
147
  return False
232
148
 
233
- def get_accelerator_warm_up_factor(self):
234
- # Pre-processing is done cleanly, so we can allocate everything here
235
- return 2
236
-
237
149
  def get_quantize_ops(self):
238
150
  from ..integrations.finegrained_fp8 import Fp8Quantize
239
151
 
@@ -246,8 +158,9 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
246
158
  if self.pre_quantized and self.quantization_config.dequantize:
247
159
  return [
248
160
  # either use the dollar sign, or permute the source patterns to start matching against the scales first
161
+ # We also collect the activation scales, they will not be used
249
162
  WeightConverter(
250
- source_patterns=["weight$", "weight_scale_inv"],
163
+ source_patterns=["weight$", "weight_scale_inv", "activation_scale"],
251
164
  target_patterns="weight",
252
165
  operations=[Fp8Dequantize(self)],
253
166
  )
@@ -36,13 +36,10 @@ class FPQuantHfQuantizer(HfQuantizer):
36
36
  """
37
37
 
38
38
  requires_calibration = False
39
- requires_parameters_quantization = True
40
39
  is_qat_trainable = True
41
- required_packages = ["fp_quant"]
42
40
 
43
41
  def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
44
42
  super().__init__(quantization_config, **kwargs)
45
- self.quantization_config = quantization_config
46
43
 
47
44
  def validate_environment(self, device_map, **kwargs):
48
45
  if not torch.cuda.is_available() and not is_torch_xpu_available():
@@ -68,66 +65,35 @@ class FPQuantHfQuantizer(HfQuantizer):
68
65
  "You are attempting to load a FPQuant model without setting device_map."
69
66
  " Please set device_map comprised of 'cuda' devices."
70
67
  )
71
- elif (
72
- isinstance(device_map, dict)
73
- and ("cpu" in device_map.values() or "disk" in device_map.values())
74
- and not self.quantization_config.pseudoquantization
75
- ):
76
- raise ValueError(
77
- "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
78
- " This is not supported. Please remove the CPU or disk device from the device_map."
79
- )
68
+ elif isinstance(device_map, dict):
69
+ if (
70
+ not self.quantization_config.pseudoquantization
71
+ and len(device_map) > 1
72
+ and "cpu" in device_map.values()
73
+ or "disk" in device_map.values()
74
+ ):
75
+ raise ValueError(
76
+ "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
77
+ " This is not supported. Please remove the CPU or disk device from the device_map."
78
+ )
80
79
 
81
80
  def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
82
- if dtype is None:
83
- logger.info("`dtype` is None. Setting `dtype=torch.bfloat16` for qutlass compatibility.")
81
+ if dtype != torch.bfloat16:
82
+ logger.warning_once(
83
+ f"Setting dtype to {dtype}, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
84
+ )
84
85
  dtype = torch.bfloat16
85
- elif dtype != torch.bfloat16:
86
- raise ValueError(f"Invalid `dtype` {dtype}. fp_quant quantization only supports `dtype=torch.bfloat16`.")
87
-
88
86
  return dtype
89
87
 
90
- def create_quantized_param(
91
- self,
92
- model: "PreTrainedModel",
93
- param_value: "torch.Tensor",
94
- param_name: str,
95
- target_device: "torch.device",
96
- **kwargs,
97
- ):
98
- module, _ = get_module_from_name(model, param_name)
99
-
100
- if target_device == "cpu" and param_name.endswith("weight"):
101
- # Works agains hard-coded missing key dispatch to CPU
102
- return
103
-
104
- # The module holds either:
105
- # * `weight` when `store_master_weights=True`
106
- # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
107
- # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`
108
-
109
- if param_name.endswith(".qweight"):
110
- # Loading a real quantized checkpoint without master weights
111
- module.qweight = torch.nn.Parameter(
112
- param_value.to(target_device),
113
- requires_grad=False,
114
- )
115
- module.weight = None
116
- module.dqweight = None
117
- return
118
-
119
- if param_name.endswith(".dqweight"):
120
- # Loading a pseudo-quantized checkpoint without master weights
121
- module.dqweight = torch.nn.Parameter(param_value.to(target_device))
122
- module.weight = None
123
- module.qweight = None
124
- module.scales = None
125
- return
126
-
127
- # Loading master weights or an unquantized checkpoint
128
- module.weight = torch.nn.Parameter(param_value.to(target_device))
129
- # Let pre-forward handle the quantization and set None where necessary
130
- module.pre_forward()
88
+ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
89
+ from fp_quant import FPQuantLinear
90
+
91
+ module, tensor_name = get_module_from_name(model, param_name)
92
+ if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
93
+ # Only quantize weights of FPQuantLinear modules that are not already quantized
94
+ return True
95
+ else:
96
+ return False
131
97
 
132
98
  def _process_model_before_weight_loading(
133
99
  self,
@@ -142,20 +108,6 @@ class FPQuantHfQuantizer(HfQuantizer):
142
108
  model,
143
109
  fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
144
110
  )
145
- model.config.quantization_config = self.quantization_config
146
-
147
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
148
- from fp_quant import FPQuantLinear
149
-
150
- fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
151
-
152
- def should_exclude(key: str) -> bool:
153
- if key.endswith(".weight") or key.endswith(".bias"):
154
- return False
155
- full_key = f"{prefix}.{key}"
156
- return any(name in key or name in full_key for name in fp_quant_names)
157
-
158
- return [key for key in missing_keys if not should_exclude(key)]
159
111
 
160
112
  @property
161
113
  def is_trainable(self, model: Optional["PreTrainedModel"] = None):
@@ -166,15 +118,33 @@ class FPQuantHfQuantizer(HfQuantizer):
166
118
  )
167
119
  return trainable
168
120
 
169
- def is_serializable(self, safe_serialization=None):
121
+ def is_serializable(self):
170
122
  return True
171
123
 
172
- def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
173
- from fp_quant import FPQuantLinear
174
-
175
- module, tensor_name = get_module_from_name(model, param_name)
176
- if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
177
- # Only quantize weights of FPQuantLinear modules that are not already quantized
178
- return True
179
- else:
180
- return False
124
+ def get_quantize_ops(self):
125
+ from ..integrations.fp_quant import FpQuantQuantize
126
+
127
+ return FpQuantQuantize(self)
128
+
129
+ def get_weight_conversions(self):
130
+ from ..core_model_loading import WeightConverter
131
+ from ..integrations.fp_quant import FpQuantDeserialize
132
+
133
+ if self.pre_quantized:
134
+ if self.quantization_config.pseudoquantization:
135
+ return [
136
+ WeightConverter(
137
+ source_patterns=[".dqweight"],
138
+ target_patterns=".dqweight",
139
+ operations=[FpQuantDeserialize(self)],
140
+ ),
141
+ ]
142
+ else:
143
+ return [
144
+ WeightConverter(
145
+ source_patterns=[".qweight"],
146
+ target_patterns=".qweight",
147
+ operations=[FpQuantDeserialize(self)],
148
+ ),
149
+ ]
150
+ return []