transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,9 @@ import os
 import re
 from functools import partial, reduce

+from ..distributed import DistributedConfig
+from ..utils import is_torch_greater_or_equal, logging
+from ..utils.generic import GeneralInterface
 from ..utils.import_utils import is_torch_available


@@ -27,14 +30,6 @@ if is_torch_available():
     import torch.distributed as dist
     from torch import nn

-from ..distributed import DistributedConfig
-from ..utils import is_torch_greater_or_equal, logging
-from ..utils.generic import GeneralInterface
-
-
-logger = logging.get_logger(__name__)
-
-if is_torch_available():
     # Cache this result has it's a C FFI call which can be pretty time-consuming
     _torch_distributed_available = torch.distributed.is_available()

@@ -42,6 +37,9 @@ if is_torch_available():
     from torch.distributed.tensor import DTensor, Placement, Replicate, Shard


+logger = logging.get_logger(__name__)
+
+
 def initialize_tensor_parallelism(
     tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
 ):
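
Taken together, the three hunks above consolidate the module preamble: imports that do not need torch (DistributedConfig, the logging helpers, GeneralInterface) move to module level, the duplicate is_torch_available() guard is dropped, and the logger is created unconditionally after the single guard. A generic, runnable sketch of that pattern, using standard-library stand-ins rather than the transformers helpers:

import importlib.util
import logging

# Torch-free imports and the logger live at module level; only the optional,
# heavy dependency stays behind a single availability guard.
_torch_available = importlib.util.find_spec("torch") is not None

if _torch_available:
    import torch  # noqa: F401  (optional dependency)

logger = logging.getLogger(__name__)
logger.debug("torch available: %s", _torch_available)
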
@@ -470,7 +468,12 @@ class TensorParallelLayer:
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ...

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+    def shard_tensor(
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
         raise NotImplementedError

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
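
The base class now splits parameter placement into a two-step contract: shard_tensor materializes this rank's local shard of a checkpoint tensor (recording the placement in self.shard), and partition_tensor builds on it to handle contiguity and optional DTensor wrapping. A sketch of a conforming subclass, assuming the module's own helpers (get_tensor_shard, self.empty_param, self.rank, self.device_mesh) as they appear elsewhere in this diff; the class itself is illustrative, not part of the release:

class ExampleColwiseLayer(TensorParallelLayer):  # hypothetical subclass
    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        # Take this rank's slice along the output dim and record the placement.
        parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2, tensor_idx)
        self.shard = [Shard(-2)]
        return parameter.to(device=device, dtype=dtype)

    def partition_tensor(self, param, dtype, to_contiguous):
        # Reuse shard_tensor; only layout concerns are handled here.
        parameter = self.shard_tensor(param, dtype=dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
        if self.use_dtensor:
            parameter = DTensor.from_local(parameter, self.device_mesh, self.shard, run_check=False)
        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
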
@@ -519,19 +522,10 @@ class GatherParallel(TensorParallelLayer):
         return outputs

     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        shard = [Replicate()]
-        parameter = param[...].to(param_casting_dtype)
-        self.shard = shard
-        return parameter, shard
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        self.shard = [Replicate()]
+        return param[...].to(device=device, dtype=dtype)

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         distribute_module(
@@ -562,29 +556,20 @@ class IsolatedParallel(TensorParallelLayer):
         return outputs

     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        mesh = device_mesh or self.device_mesh
-        parameter = param[...].to(param_casting_dtype)
-        if mesh is not None:
-            parameter = parameter / mesh.size()
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        parameter = param[...].to(device=device, dtype=dtype)
+        if self.device_mesh is not None:
+            parameter = parameter / self.device_mesh.size()
         self.shard = None
-        return parameter, None
+        return parameter

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        param = param[...].to(param_casting_dtype)
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
+        parameter = self.shard_tensor(param, dtype=dtype)
         if to_contiguous:
-            param = param.contiguous()
-        param = param / device_mesh.size()  # TODO should be optionable
+            parameter = parameter.contiguous()
         # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
-        return param
+        return parameter

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         distribute_module(
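
The division by self.device_mesh.size() now lives in shard_tensor and pairs with the TODO comment above: each rank keeps param / world_size so that when the parent layer allreduce-sums its outputs (e.g. a row-linear bias held by IsolatedParallel under a GatherParallel parent), the bias contributes exactly once rather than world_size times. A plain-Python check of the arithmetic, no torch.distributed required:

# Simulate 4 ranks that each add their scaled copy of the bias, then allreduce.
world_size = 4
bias = 0.8
per_rank_bias = bias / world_size                        # what IsolatedParallel stores
partial_outputs = [1.0 + per_rank_bias] * world_size     # each rank's local result
allreduced = sum(partial_outputs)                        # what the parent's allreduce computes
assert abs(allreduced - (world_size * 1.0 + bias)) < 1e-9  # bias counted once, not 4 times
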
@@ -623,31 +608,15 @@ class ReplicateParallel(TensorParallelLayer):
         return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs

     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        parameter = param[...].to(param_casting_dtype)
-        shard = [Replicate()]
-        self.shard = shard
-        return parameter, shard
-
-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        parameter, shard = self.shard_tensor(
-            param,
-            param_type=param_type,
-            param_casting_dtype=param_casting_dtype,
-            to_contiguous=to_contiguous,
-            rank=rank,
-            device_mesh=device_mesh,
-        )
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        self.shard = [Replicate()]
+        return param[...].to(device=device, dtype=dtype)
+
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
+        parameter = self.shard_tensor(param, dtype=dtype)
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
+            parameter = DTensor.from_local(parameter, self.device_mesh, self.shard, run_check=False)
         return parameter

@@ -685,38 +654,34 @@ class ColwiseParallel(TensorParallelLayer):
         return input_tensor

     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        device_mesh = self.device_mesh
-        empty_param = self.empty_param
-        rank = self.rank
-        if param_type == "bias":
-            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        # If only 1 dim, shard this one (usually it's a `bias`)
+        dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
+        if dim == 1:
+            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1, tensor_idx)
             shard = [Shard(-1)]
         else:
             shard = [Shard(-2)]
-            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
-        parameter = parameter.to(param_casting_dtype)
+            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2, tensor_idx)
         self.shard = shard
-        return parameter, shard
+        return parameter.to(device=device, dtype=dtype)

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
-        parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
+        parameter = self.shard_tensor(param, dtype=dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(
-                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+                parameter,
+                self.device_mesh,
+                self.shard,
+                run_check=False,
+                shape=self.empty_param.size(),
+                stride=self.empty_param.stride(),
             )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

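ColwiseParallel now infers the placement from the tensor's rank instead of the old param_type == "bias" string check: a 2-D weight of shape (out, in) is sharded along dim -2, and a 1-D bias of shape (out,) along its only dim. A small sketch of what each rank ends up holding, using torch.chunk as a stand-in for get_tensor_shard (the real helper additionally handles lazy and packed tensors):

import torch

world_size, rank = 2, 0
weight = torch.randn(8, 4)  # (out_features, in_features)
bias = torch.randn(8)       # 1-D, so the dim == 1 branch applies

w_local = torch.chunk(weight, world_size, dim=-2)[rank]  # Shard(-2) -> (4, 4)
b_local = torch.chunk(bias, world_size, dim=-1)[rank]    # Shard(-1) -> (4,)
assert w_local.shape == (4, 4) and b_local.shape == (4,)
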
@@ -731,33 +696,41 @@

 class PackedColwiseParallel(ColwiseParallel):
     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        device_mesh = device_mesh or self.device_mesh
-        empty_param = self.empty_param
-        rank = rank if rank is not None else self.rank
-        return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)]
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
+        return parameter.to(device=device, dtype=dtype)

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
-        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
-        parameter = parameter.to(param_casting_dtype)
+        parameter = self.shard_tensor(param, dtype=dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
+            parameter = DTensor.from_local(parameter, self.device_mesh, [Shard(-2)], run_check=False)
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


+class LocalColwiseParallel(ColwiseParallel):
+    """
+    Colwise parallel with use_dtensor=False for local tensor operations.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(use_dtensor=False, **kwargs)
+
+
+class ColwiseParallelReplicate(ColwiseParallel):
+    """
+    Colwise parallel with output layouts replicated.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(output_layouts=Replicate(), **kwargs)
+
+
 class RowwiseParallel(TensorParallelLayer):
     """
     Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
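
The two classes added in the hunk above are thin presets: LocalColwiseParallel pins use_dtensor=False so shards stay plain local tensors, and ColwiseParallelReplicate pins output_layouts=Replicate() so sharded outputs are gathered back on every rank. A minimal stand-in for the pattern (Base and LocalPreset here are illustrative, not the transformers classes):

class Base:
    def __init__(self, use_dtensor: bool = True, **kwargs):
        self.use_dtensor = use_dtensor

class LocalPreset(Base):
    def __init__(self, **kwargs):
        super().__init__(use_dtensor=False, **kwargs)  # same trick as LocalColwiseParallel

assert LocalPreset().use_dtensor is False
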
@@ -782,7 +755,7 @@ class RowwiseParallel(TensorParallelLayer):
         input_layouts: Placement | None = None,
         output_layouts: Placement | None = None,
         use_local_output: bool = True,
-        use_dtensor=True,
+        use_dtensor: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -792,45 +765,36 @@ class RowwiseParallel(TensorParallelLayer):
792
765
  self.use_dtensor = use_dtensor
793
766
 
794
767
  def shard_tensor(
795
- self,
796
- param,
797
- param_type=None,
798
- param_casting_dtype=None,
799
- to_contiguous=None,
800
- rank=None,
801
- device_mesh=None,
802
- tensor_idx=None,
803
- ):
804
- device_mesh = device_mesh or self.device_mesh
805
- empty_param = self.empty_param
806
- rank = rank if rank is not None else self.rank
807
- if param_type == "bias":
768
+ self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
769
+ ) -> torch.Tensor:
770
+ # If only 1 dim, it should not be sharded (usually it's a `bias`)
771
+ dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
772
+ if dim == 1:
808
773
  shard = [Replicate()]
809
774
  parameter = param[...]
810
775
  else:
811
- parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
776
+ parameter = get_tensor_shard(
777
+ param, self.empty_param, self.device_mesh, self.rank, -1, tensor_idx=tensor_idx
778
+ )
812
779
  shard = [Shard(-1)]
813
- parameter = parameter.to(param_casting_dtype)
814
780
  self.shard = shard
815
- return parameter, shard
781
+ return parameter.to(device=device, dtype=dtype)
816
782
 
817
- def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
783
+ def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
818
784
  # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
819
785
  # means Rowwise as nn.Linear is input * weight^T + bias, where
820
786
  # weight would become Shard(0)
821
- if param_type != "bias":
822
- parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
823
- shard = [Shard(-1)]
824
- else:
825
- shard = [Replicate()]
826
- parameter = param[:]
827
-
828
- parameter = parameter.to(param_casting_dtype)
787
+ parameter = self.shard_tensor(param, dtype=dtype)
829
788
  if to_contiguous:
830
789
  parameter = parameter.contiguous()
831
790
  if self.use_dtensor:
832
791
  parameter = DTensor.from_local(
833
- parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
792
+ parameter,
793
+ self.device_mesh,
794
+ self.shard,
795
+ run_check=False,
796
+ shape=self.empty_param.size(),
797
+ stride=self.empty_param.stride(),
834
798
  )
835
799
  return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
836
800
 
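The behavioral nuance in this hunk: the bias special case is no longer keyed on the `param_type == "bias"` string but on tensor rank, so any 1-D parameter is replicated and higher-rank parameters are sharded on the last dimension. A rough standalone illustration of that decision rule (the `torch.distributed.tensor` import path assumes a recent torch; the slicing details of `get_tensor_shard` are not reproduced):

import torch
from torch.distributed.tensor import Replicate, Shard  # placement types

def rowwise_placement(param: torch.Tensor) -> list:
    # 1-D tensors (typically biases) are replicated; everything else is
    # sharded along its last dimension, mirroring shard_tensor above.
    return [Replicate()] if param.dim() == 1 else [Shard(-1)]

print(rowwise_placement(torch.empty(64)))       # [Replicate()]
print(rowwise_placement(torch.empty(64, 128)))  # [Shard(dim=-1)]
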
@@ -886,33 +850,50 @@ class RowwiseParallel(TensorParallelLayer):

 class PackedRowwiseParallel(RowwiseParallel):
     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        device_mesh = device_mesh or self.device_mesh
-        empty_param = self.empty_param
-        rank = rank if rank is not None else self.rank
-        return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
+        return parameter.to(device=device, dtype=dtype)

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
-        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -1)
-        parameter = parameter.to(param_casting_dtype)
+        parameter = self.shard_tensor(param, dtype=dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
+            parameter = DTensor.from_local(parameter, self.device_mesh, [Shard(-1)], run_check=False)
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


+class LocalRowwiseParallel(RowwiseParallel):
+    """
+    Rowwise parallel with use_dtensor=False for local tensor operations.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(use_dtensor=False, **kwargs)
+
+
+class LocalPackedRowwiseParallel(PackedRowwiseParallel):
+    """
+    Packed rowwise parallel with use_dtensor=False for local tensor operations.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(use_dtensor=False, **kwargs)
+
+
+class RowwiseParallelReplicate(RowwiseParallel):
+    """
+    Rowwise parallel with input layouts replicated.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(input_layouts=Replicate(), **kwargs)
+
+
 class SequenceParallel(TensorParallelLayer):
     """
     SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with
@@ -970,18 +951,13 @@ class SequenceParallel(TensorParallelLayer):

     def shard_tensor(
         self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
+        param: torch.Tensor,
         tensor_idx=None,
-    ):
-        parameter = param[...].to(param_casting_dtype)
-        shard = [Replicate()]
-        self.shard = shard
-        return parameter, shard
+        device=None,
+        dtype=None,
+    ) -> torch.Tensor:
+        self.shard = [Replicate()]
+        return param[...].to(device=device, dtype=dtype)

     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
@@ -999,16 +975,15 @@ class SequenceParallel(TensorParallelLayer):
         )  # maybe we have to replicate ? because next layer is not sharded
         return outputs.to_local()  # if use_local_output else outputs

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
-        parameter = param[...]
-        parameter = parameter.to(param_casting_dtype)
+        parameter = self.shard_tensor(param, dtype=dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
+            parameter = DTensor.from_local(parameter, self.device_mesh, [Replicate()], run_check=False)
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


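Every `partition_tensor` in these hunks now ends with the same tail: build the local shard with `shard_tensor`, optionally make it contiguous, and, when `use_dtensor` is set, wrap it with `DTensor.from_local` using the placement cached in `self.shard`. A standalone sketch of that wrap on an already-initialized 2-rank mesh (`init_device_mesh` and `DTensor.from_local` are public torch.distributed APIs; the shapes are made up):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Assumes torch.distributed is already initialized with 2 ranks (e.g. via torchrun).
mesh = init_device_mesh("cuda", (2,))

local_shard = torch.randn(64, 128, device="cuda")  # this rank's half of a (64, 256) weight
# run_check=False skips cross-rank verification; shape/stride describe the
# *global* tensor so the DTensor metadata matches the unsharded parameter.
dist_param = DTensor.from_local(
    local_shard, mesh, [Shard(-1)], run_check=False, shape=(64, 256), stride=(256, 1)
)
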
@@ -1022,41 +997,23 @@ class GroupedGemmParallel(TensorParallelLayer):
         self.use_dtensor = False

     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        empty_param = self.empty_param
-        ep_rank = self.rank
-        device_mesh = self.device_mesh
-
-        global_num_experts = empty_param.shape[0]
-        if global_num_experts % device_mesh.size() != 0:
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        global_num_experts = self.empty_param.shape[0]
+        if global_num_experts % self.device_mesh.size() != 0:
             raise ValueError(
-                f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
+                f"Global number of experts must be divisible by number of devices: {global_num_experts} % {self.device_mesh.size()} != 0"
             )
-        local_num_experts = global_num_experts // device_mesh.size()
-        parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
+        local_num_experts = global_num_experts // self.device_mesh.size()
+        parameter = param[self.rank * local_num_experts : (self.rank + 1) * local_num_experts]
         self.shard = None
-        return parameter, None
+        return parameter.to(device=device, dtype=dtype)

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        ep_rank = rank
-        global_num_experts = empty_param.shape[0]
-        if global_num_experts % device_mesh.size() != 0:
-            raise ValueError(
-                f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
-            )
-        local_num_experts = global_num_experts // device_mesh.size()
-        param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
+        parameter = self.shard_tensor(param, dtype=dtype)
         if to_contiguous:
-            param = param.contiguous()
-        return param
+            parameter = parameter.contiguous()
+        return parameter


 class RouterParallel(TensorParallelLayer):
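
`GroupedGemmParallel` is expert parallelism rather than tensor parallelism: each rank keeps a contiguous block of `global_num_experts // world_size` experts from the stacked expert weight, with no DTensor wrapping. A toy re-enactment of the slicing (pure tensor math, no process group needed):

import torch

# Toy stand-in: 8 experts, each a (4, 4) weight, stacked on dim 0.
expert_weights = torch.randn(8, 4, 4)
world_size, rank = 2, 1  # pretend we are EP rank 1 of 2

assert expert_weights.shape[0] % world_size == 0  # mirrors the ValueError check
local = expert_weights.shape[0] // world_size
shard = expert_weights[rank * local : (rank + 1) * local]  # experts 4..7
print(shard.shape)  # torch.Size([4, 4, 4])
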
@@ -1064,10 +1021,10 @@ class RouterParallel(TensorParallelLayer):
     Allows to reshape the router scores to support running expert parallel.
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, use_dtensor: bool = False, *args, **kwargs):
         super().__init__(**kwargs)
         self.args = args
-        self.use_dtensor = False
+        self.use_dtensor = use_dtensor

     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
@@ -1118,7 +1075,7 @@ class RouterParallel(TensorParallelLayer):
                 f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
             )
         num_local_experts = mod.num_experts // ep_size
-        router_scores, router_indices = outputs
+        router_logits, router_scores, router_indices = outputs
         router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
         router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
         # As -1 % 1 is 0, we can only use mask fill when num_local_experts is 1
@@ -1129,28 +1086,20 @@ class RouterParallel(TensorParallelLayer):
         router_indices = router_indices.masked_fill(
             router_indices == -1, num_local_experts
         )  # masking class for one hot
-        return router_scores, router_indices
+        return router_logits, router_scores, router_indices

     def shard_tensor(
-        self,
-        param,
-        param_type=None,
-        param_casting_dtype=None,
-        to_contiguous=None,
-        rank=None,
-        device_mesh=None,
-        tensor_idx=None,
-    ):
-        parameter = param[...].to(param_casting_dtype)
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
         self.shard = None
-        return parameter, None
+        return param[...].to(device=device, dtype=dtype)

-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
         # TODO: i'd like for this to be the default
-        param = param[...].to(param_casting_dtype)
+        parameter = self.shard_tensor(param, dtype=dtype)
         if to_contiguous:
-            param = param.contiguous()
-        return param
+            parameter = parameter.contiguous()
+        return parameter

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         # TODO: need an abstract Parallel class that is different from TensorParallelLayer
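
The router output hook now passes `router_logits` through untouched while restricting `router_scores` and `router_indices` to the experts owned by the current EP rank: foreign indices are marked `-1`, survivors are re-based locally, and the `-1` sentinels are mapped to `num_local_experts` as a throwaway one-hot class. A toy re-enactment of that index bookkeeping (it assumes the intermediate `% num_local_experts` re-basing that the `-1 % 1` comment above alludes to):

import torch

num_local_experts, ep_rank = 2, 1  # rank 1 owns global experts 2 and 3
router_indices = torch.tensor([[0, 3], [2, 1]])  # top-k expert ids per token

# Drop indices routed to other ranks, then re-base the survivors locally.
masked = router_indices.masked_fill(router_indices // num_local_experts != ep_rank, -1)
local = masked % num_local_experts          # -1 % 2 == 1, a bogus survivor...
local = local.masked_fill(masked == -1, num_local_experts)  # ...sent to class 2
print(local)  # tensor([[2, 1], [0, 2]])
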
@@ -1169,13 +1118,13 @@ class ParallelInterface(GeneralInterface):
        {
            "colwise": ColwiseParallel(),
            "rowwise": RowwiseParallel(),
-            "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
-            "rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
-            "local_colwise": ColwiseParallel(use_dtensor=False),
-            "local_rowwise": RowwiseParallel(use_dtensor=False),
+            "colwise_rep": ColwiseParallelReplicate(),
+            "rowwise_rep": RowwiseParallelReplicate(),
+            "local_colwise": LocalColwiseParallel(),
+            "local_rowwise": LocalRowwiseParallel(),
            "local": IsolatedParallel(),
            "gather": GatherParallel(),
-            "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
+            "local_packed_rowwise": LocalPackedRowwiseParallel(),
            "sequence_parallel": SequenceParallel(),
            "replicate": ReplicateParallel(),
            "grouped_gemm": GroupedGemmParallel(),
@@ -1286,13 +1235,10 @@ def shard_and_distribute_module(

     if current_shard_plan is not None:
         try:
-            tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
-            tp_layer.empty_param = empty_param
-            tp_layer.device_mesh = device_mesh
-            tp_layer.rank = rank
-            param = tp_layer.partition_tensor(
-                param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
+            tp_layer = ALL_PARALLEL_STYLES[current_shard_plan](
+                empty_param=empty_param, device_mesh=device_mesh, rank=rank
             )
+            param = tp_layer.partition_tensor(param, param_casting_dtype, is_contiguous)
         except NotImplementedError as e:
             print(
                 f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
transformers/integrations/vptq.py
@@ -14,13 +14,11 @@
 "VPTQ (Vector Post-Training Quantization) integration file"

 from ..quantizers.quantizers_utils import should_convert_module
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_torch_available, logging


-if is_accelerate_available():
-    from accelerate import init_empty_weights
-
 if is_torch_available():
+    import torch
     import torch.nn as nn

 logger = logging.get_logger(__name__)
@@ -48,7 +46,7 @@ def replace_with_vptq_linear(model, modules_to_not_convert: list[str] | None = N
     for module_name, module in model.named_modules():
         if not should_convert_module(module_name, modules_to_not_convert):
             continue
-        with init_empty_weights():
+        with torch.device("meta"):
             if isinstance(module, nn.Linear):
                 layer_params = config_for_layers.get(module_name, None) or shared_layer_config.get(
                     module_name.rsplit(".")[1], None
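
The `accelerate.init_empty_weights` context is replaced by plain `torch.device("meta")`, dropping the accelerate dependency here: used as a context manager, it makes newly constructed tensors default to the meta device, so modules get correct shapes and dtypes without allocating storage. A small self-contained demonstration:

import torch
import torch.nn as nn

# Inside this context, new tensors default to the meta device: the module is
# created with correct shapes/dtypes but no real storage is allocated.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([4096, 4096])
# Real weights must be materialized later, e.g. via to_empty() + loading:
layer = layer.to_empty(device="cpu")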