transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
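
A listing like the one below can be reproduced locally by downloading both wheels and diffing their Python sources. The sketch below is illustrative only: it assumes the two wheels have already been fetched (for example with `pip download transformers==5.0.0rc3 --no-deps -d wheels/` and `pip download transformers==5.1.0 --no-deps -d wheels/`) and that the `OLD_WHEEL`/`NEW_WHEEL` paths match where they landed. It uses only the standard-library `zipfile` and `difflib` modules and prints a per-file +added/-removed summary.

```python
import difflib
import zipfile

# Assumed local paths; adjust to wherever `pip download` placed the wheels.
OLD_WHEEL = "wheels/transformers-5.0.0rc3-py3-none-any.whl"
NEW_WHEEL = "wheels/transformers-5.1.0-py3-none-any.whl"


def read_py_members(path: str) -> dict[str, list[str]]:
    """Return {member name: decoded lines} for every .py file inside the wheel."""
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines(keepends=True)
            for name in zf.namelist()
            if name.endswith(".py")
        }


old, new = read_py_members(OLD_WHEEL), read_py_members(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    diff = difflib.unified_diff(old.get(name, []), new.get(name, []), name, name)
    added = removed = 0
    for line in diff:
        # Count changed lines, skipping the "+++"/"---" file headers.
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"{name} +{added} -{removed}")
```

The counts printed this way will not match a registry viewer exactly, since such viewers may also diff non-Python files and packaging metadata, but they give the same per-module picture as the listing below.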
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ import math
 import operator
 import os
 import re
-from functools import partial, reduce
+from functools import reduce
 
 from ..distributed import DistributedConfig
 from ..utils import is_torch_greater_or_equal, logging
@@ -33,9 +33,6 @@ if is_torch_available():
     # Cache this result as it's a C FFI call which can be pretty time-consuming
     _torch_distributed_available = torch.distributed.is_available()
 
-    if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-        from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
-
 
 logger = logging.get_logger(__name__)
 
@@ -68,10 +65,6 @@ def initialize_tensor_parallelism(
 
     backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
     backend = backend_map.get(device_type)
-    if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
-        backend = "ccl"
-    if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
-        backend = "ccl"
 
     torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
     current_device = getattr(torch, device_type)
@@ -116,32 +109,6 @@ def initialize_tensor_parallelism(
     return device_map, device_mesh, tp_size
 
 
-def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
-    """
-    Convert block count or proportions to block sizes.
-
-    This function accepts
-
-    - The number of blocks (int), in which case the block size is
-      total_size//blocks; or
-    - A list of block sizes (list[int]).
-
-    In the second case, if sum(blocks) < total_size, the ratios between
-    the block sizes will be preserved. For instance, if blocks is
-    [2, 1, 1] and total_size is 1024, the returned block sizes are
-    [512, 256, 256].
-    """
-    if isinstance(blocks, list):
-        total_blocks = sum(blocks)
-        assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
-        part_size = total_size // total_blocks
-        return [part_size * block for block in blocks]
-    else:
-        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
-        single_size = total_size // blocks
-        return [single_size] * blocks
-
-
 def replace_layer_number_by_wildcard(name: str) -> str:
     """
     Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
@@ -170,6 +137,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
     return None
 
 
+# =============================================================================
+# Tensor Sharding Utilities
+# =============================================================================
+
+
 if is_torch_available():
     str_to_dtype = {
         "BOOL": torch.bool,
@@ -186,6 +158,32 @@ if is_torch_available():
     }
 
 
+def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
+    """
+    Convert block count or proportions to block sizes.
+
+    This function accepts
+
+    - The number of blocks (int), in which case the block size is
+      total_size//blocks; or
+    - A list of block sizes (list[int]).
+
+    In the second case, if sum(blocks) < total_size, the ratios between
+    the block sizes will be preserved. For instance, if blocks is
+    [2, 1, 1] and total_size is 1024, the returned block sizes are
+    [512, 256, 256].
+    """
+    if isinstance(blocks, list):
+        total_blocks = sum(blocks)
+        assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
+        part_size = total_size // total_blocks
+        return [part_size * block for block in blocks]
+    else:
+        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+        single_size = total_size // blocks
+        return [single_size] * blocks
+
+
 def get_packed_weights(param, empty_param, device_mesh, rank, dim):
     """
     When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
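For readers skimming the diff, here is a quick standalone check of what the relocated helper returns for its two accepted inputs. The mini-reimplementation below is only for the example and is not part of the package.

    def blocks_to_block_sizes(total_size, blocks):
        # hypothetical condensed version of _blocks_to_block_sizes above
        if isinstance(blocks, list):
            part_size = total_size // sum(blocks)
            return [part_size * block for block in blocks]
        return [total_size // blocks] * blocks

    assert blocks_to_block_sizes(1024, [2, 1, 1]) == [512, 256, 256]  # proportions preserved
    assert blocks_to_block_sizes(1024, 4) == [256, 256, 256, 256]     # plain block count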
@@ -372,19 +370,20 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int
         dim (int): Dimension along which to shard the tensor.
     """
     param_dim = empty_param.ndim
-    # Flatten the mesh to get the total number of devices
     mesh_shape = device_mesh.shape
     world_size = reduce(operator.mul, mesh_shape)
+    # Get param shape: works for both torch.Tensor and safetensors TensorInfo
+    param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
     if dim < 0:
         dim = param_dim + dim
-    if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
-        dim = 0
-    elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
+    if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
         dim = 0
+    elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
+        dim = 1
 
-    shard_size = math.ceil(empty_param.size(dim) / world_size)
+    shard_size = math.ceil(param_shape[dim] / world_size)
     start = rank * shard_size
-    end = min(start + shard_size, empty_param.size(dim))
+    end = min(start + shard_size, param_shape[dim])
 
     if dim >= param_dim:
         raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
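The index arithmetic that the rewritten function now applies to `param_shape` is easy to check in isolation. A minimal sketch, illustrative only and not part of the diff:

    import math

    def shard_bounds(dim_size, world_size, rank):
        # ceil-divide the dimension; the last rank's slice is clamped to the end
        shard_size = math.ceil(dim_size / world_size)
        start = rank * shard_size
        return start, min(start + shard_size, dim_size)

    # a dimension of 10 over 4 ranks gives shards of 3, 3, 3 and 1 rows
    assert [shard_bounds(10, 4, r) for r in range(4)] == [(0, 3), (3, 6), (6, 9), (9, 10)]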
@@ -401,9 +400,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int
     # actually we still shard dim=0 does not change
     # so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
     # tensor on a certain device (with the input tensor_index)
-    dimensions = param.get_shape()
-
-    if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
+    if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
         # special case we don't "shard" just send this entire tensor to the correct rank.
         if start <= tensor_idx < end:
             # this tensor does need to be materialized on this device:
@@ -411,17 +408,214 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int
411
408
  else:
412
409
  return torch.empty([], dtype=torch.int64, device=rank)
413
410
 
414
- slice_indices = [slice(None)] * len(param.get_shape())
411
+ slice_indices = [slice(None)] * len(param_shape)
415
412
 
416
- if start < param.get_shape()[dim]:
413
+ if start < param_shape[dim]:
417
414
  slice_indices[dim] = slice(start, end)
418
415
  param = param[tuple(slice_indices)]
419
416
  if isinstance(param, list): # TODO handle the modulelist case!
420
417
  param = [p[:] for p in param]
421
418
  return param
422
419
 
423
- dimensions[dim] = 0
424
- return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
420
+ param_shape[dim] = 0
421
+ return torch.empty(tuple(param_shape), dtype=torch.int64) # empty allocates memory....
422
+
423
+
424
+ def _split_along_last_dim(x, world_size):
425
+ """Split tensor along last dimension into world_size chunks."""
426
+ return torch.chunk(x, world_size, dim=-1)
427
+
428
+
429
+ # =============================================================================
430
+ # Distributed Communication Primitives
431
+ # =============================================================================
432
+ #
433
+ # Naming convention:
434
+ # - Functions describe their FORWARD behavior
435
+ # - Backward behavior is the "conjugate" operation for gradient flow
436
+ #
437
+ # Available operations:
438
+ # ┌────────────────────┬─────────────────────┬─────────────────────┐
439
+ # │ Function │ Forward │ Backward │
440
+ # ├────────────────────┼─────────────────────┼─────────────────────┤
441
+ # │ all_reduce │ all-reduce (sum) │ identity │
442
+ # │ all_reduce_backward│ identity │ all-reduce (sum) │
443
+ # │ all_gather │ all-gather │ split (local chunk) │
444
+ # │ split │ split (local chunk) │ all-gather │
445
+ # │ reduce_scatter │ reduce-scatter │ all-gather │
446
+ # └────────────────────┴─────────────────────┴─────────────────────┘
447
+ # ===================
448
+
449
+
450
+ class _AllReduceBackward(torch.autograd.Function):
451
+ """Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""
452
+
453
+ @staticmethod
454
+ def forward(ctx, x, device_mesh):
455
+ ctx.device_mesh = device_mesh
456
+ return x
457
+
458
+ @staticmethod
459
+ def backward(ctx, grad_output):
460
+ device_mesh = ctx.device_mesh
461
+ if device_mesh.size() == 1:
462
+ return grad_output, None
463
+ dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
464
+ return grad_output, None
465
+
466
+
467
+ class _AllReduceForward(torch.autograd.Function):
468
+ """All-reduce forward, identity backward. Used after rowwise layers (g in Megatron)."""
469
+
470
+ @staticmethod
471
+ def forward(ctx, x, device_mesh):
472
+ if device_mesh.size() == 1:
473
+ return x
474
+ dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
475
+ return x
476
+
477
+ @staticmethod
478
+ def backward(ctx, grad_output):
479
+ return grad_output, None
480
+
481
+
482
+ class _AllGather(torch.autograd.Function):
483
+ """All-gather forward, split backward. Gathers sharded outputs."""
484
+
485
+ @staticmethod
486
+ def forward(ctx, x, device_mesh):
487
+ ctx.device_mesh = device_mesh
488
+ world_size = device_mesh.size()
489
+
490
+ if world_size == 1:
491
+ return x
492
+
493
+ last_dim = x.dim() - 1
494
+ rank = device_mesh.get_local_rank()
495
+ group = device_mesh.get_group()
496
+
497
+ x = x.contiguous()
498
+ tensor_list = [torch.empty_like(x) for _ in range(world_size)]
499
+ tensor_list[rank] = x
500
+ dist.all_gather(tensor_list, x, group=group)
501
+ return torch.cat(tensor_list, dim=last_dim).contiguous()
502
+
503
+ @staticmethod
504
+ def backward(ctx, grad_output):
505
+ device_mesh = ctx.device_mesh
506
+ world_size = device_mesh.size()
507
+
508
+ if world_size == 1:
509
+ return grad_output, None
510
+
511
+ rank = device_mesh.get_local_rank()
512
+ chunks = _split_along_last_dim(grad_output, world_size)
513
+ return chunks[rank].contiguous(), None
514
+
515
+
516
+ class _Split(torch.autograd.Function):
517
+ """Split forward, all-gather backward. Scatters replicated input."""
518
+
519
+ @staticmethod
520
+ def forward(ctx, x, device_mesh):
521
+ ctx.device_mesh = device_mesh
522
+ world_size = device_mesh.size()
523
+
524
+ if world_size == 1:
525
+ return x
526
+
527
+ rank = device_mesh.get_local_rank()
528
+ chunks = _split_along_last_dim(x, world_size)
529
+ return chunks[rank].contiguous()
530
+
531
+ @staticmethod
532
+ def backward(ctx, grad_output):
533
+ device_mesh = ctx.device_mesh
534
+ world_size = device_mesh.size()
535
+
536
+ if world_size == 1:
537
+ return grad_output, None
538
+
539
+ last_dim = grad_output.dim() - 1
540
+ rank = device_mesh.get_local_rank()
541
+ group = device_mesh.get_group()
542
+
543
+ grad_output = grad_output.contiguous()
544
+ tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
545
+ tensor_list[rank] = grad_output
546
+ dist.all_gather(tensor_list, grad_output, group=group)
547
+ return torch.cat(tensor_list, dim=last_dim).contiguous(), None
548
+
549
+
550
+ class _ReduceScatter(torch.autograd.Function):
551
+ """Reduce-scatter forward, all-gather backward. For sequence parallel."""
552
+
553
+ @staticmethod
554
+ def forward(ctx, x, device_mesh):
555
+ ctx.device_mesh = device_mesh
556
+ world_size = device_mesh.size()
557
+
558
+ if world_size == 1:
559
+ return x
560
+
561
+ last_dim = x.dim() - 1
562
+ group = device_mesh.get_group()
563
+
564
+ input_chunks = list(x.chunk(world_size, dim=last_dim))
565
+ output_shape = list(x.shape)
566
+ output_shape[last_dim] //= world_size
567
+ output = torch.empty(output_shape, dtype=x.dtype, device=x.device)
568
+
569
+ dist.reduce_scatter(output, input_chunks, op=dist.ReduceOp.SUM, group=group)
570
+ return output
571
+
572
+ @staticmethod
573
+ def backward(ctx, grad_output):
574
+ device_mesh = ctx.device_mesh
575
+ world_size = device_mesh.size()
576
+
577
+ if world_size == 1:
578
+ return grad_output, None
579
+
580
+ last_dim = grad_output.dim() - 1
581
+ rank = device_mesh.get_local_rank()
582
+ group = device_mesh.get_group()
583
+
584
+ grad_output = grad_output.contiguous()
585
+ tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
586
+ tensor_list[rank] = grad_output
587
+ dist.all_gather(tensor_list, grad_output, group=group)
588
+ return torch.cat(tensor_list, dim=last_dim).contiguous(), None
589
+
590
+
591
+ # =============================================================================
592
+ # Convenience wrappers
593
+ # =============================================================================
594
+
595
+
596
+ def all_reduce_backward(x, device_mesh):
597
+ """Identity forward, all-reduce backward. Use before colwise layers."""
598
+ return _AllReduceBackward.apply(x, device_mesh)
599
+
600
+
601
+ def all_reduce_forward(x, device_mesh):
602
+ """All-reduce forward, identity backward. Use after rowwise layers."""
603
+ return _AllReduceForward.apply(x, device_mesh)
604
+
605
+
606
+ def all_gather(x, device_mesh):
607
+ """All-gather forward, split backward."""
608
+ return _AllGather.apply(x, device_mesh)
609
+
610
+
611
+ def split(x, device_mesh):
612
+ """Split forward, all-gather backward."""
613
+ return _Split.apply(x, device_mesh)
614
+
615
+
616
+ def reduce_scatter(x, device_mesh):
617
+ """Reduce-scatter forward, all-gather backward."""
618
+ return _ReduceScatter.apply(x, device_mesh)
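As a rough usage sketch of the new wrappers (with a single process every collective reduces to the identity, so this runs locally; the import path is assumed from the file being diffed and may differ):

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    # assumed import path for the wrappers defined above
    from transformers.integrations.tensor_parallel import all_reduce_backward, all_reduce_forward

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))

    x = torch.randn(2, 8, requires_grad=True)
    h = all_reduce_backward(x, mesh)                       # identity now, all-reduce in backward (before colwise)
    out = all_reduce_forward(h @ torch.randn(8, 4), mesh)  # all-reduce now, identity in backward (after rowwise)
    out.sum().backward()                                   # gradients flow through both autograd Functions
    dist.destroy_process_group()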
425
619
 
426
620
 
427
621
  def distribute_module(
@@ -434,224 +628,163 @@ def distribute_module(
434
628
  Copy pasted from torch's function but we remove the communications (partitioning)
435
629
  as well as buffer registering that is similarly not efficient.
436
630
  """
437
- if len(module._forward_pre_hooks) == 0:
438
- if input_fn is not None:
439
- module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
440
- if output_fn is not None:
441
- module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
631
+ if input_fn is not None:
632
+ module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
633
+ if output_fn is not None:
634
+ module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
442
635
  return module
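A toy illustration, not taken from the diff, of what the simplified distribute_module amounts to: the input_fn/output_fn callbacks become ordinary forward pre-/post-hooks on the module.

    import torch
    from torch import nn

    def toy_distribute_module(module, device_mesh, input_fn=None, output_fn=None):
        # mirrors the simplified helper above: hooks only, no partitioning or buffer registration
        if input_fn is not None:
            module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
        if output_fn is not None:
            module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
        return module

    layer = toy_distribute_module(nn.Linear(4, 4), device_mesh=None,
                                  output_fn=lambda mod, out, mesh: out * 2.0)
    _ = layer(torch.ones(1, 4))  # the forward hook rescales the output after the linear runs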
443
636
 
444
637
 
445
638
  class TensorParallelLayer:
446
- """
447
- General tensor parallel layer for transformers.
448
- """
639
+ """General tensor parallel layer for transformers"""
449
640
 
450
- use_dtensor = True
451
641
  device_mesh = None
452
642
  rank = None
453
-
454
- # Used to compare the shape of the original tensor
455
643
  empty_param = None
456
644
 
457
- # Used to init the corresponding DTensor
458
- shard = None
459
-
460
645
  def __init__(self, device_mesh=None, rank=None, empty_param=None):
461
646
  self.rank = rank
462
647
  self.device_mesh = device_mesh
463
648
  self.empty_param = empty_param
464
649
 
465
650
  @staticmethod
466
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
651
+ def _prepare_input_fn(mod, inputs, device_mesh): ...
467
652
 
468
653
  @staticmethod
469
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ...
654
+ def _prepare_output_fn(mod, outputs, device_mesh): ...
470
655
 
471
656
  def shard_tensor(
472
657
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
473
658
  ) -> torch.Tensor:
474
659
  raise NotImplementedError
475
660
 
476
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
477
- raise NotImplementedError
478
-
479
- def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
480
- if self.use_dtensor:
481
- distribute_module(
482
- module,
483
- device_mesh,
484
- partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
485
- partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
486
- )
487
-
488
-
489
- # use_dtensor needs to be set to false for nn.Parameter when you want to view, chunk, slice
490
- # you name it. Whatever you want to do that is a bit unconventional, you need local tensors
491
- class GatherParallel(TensorParallelLayer):
492
- """
493
- Simple class used to define the hooks to add to a layer when we just want to gather the outputs
494
- """
495
-
496
- def __init__(
497
- self,
498
- input_layouts: Placement | None = None,
499
- output_layouts: Placement | None = None,
500
- use_local_output: bool = True,
501
- **kwargs,
502
- ):
503
- super().__init__(**kwargs)
504
- self.input_layouts = (input_layouts or Replicate(),)
505
- self.output_layouts = output_layouts
506
- self.desired_input_layouts = (Replicate(),)
507
- self.use_local_output = use_local_output
508
-
509
- @staticmethod
510
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
511
- mod.expert_parallel_group = device_mesh.get_group()
512
- if inputs and isinstance(inputs[0], DTensor):
513
- inputs = inputs[0].to_local()
514
- return inputs
515
-
516
- @staticmethod
517
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
518
- if isinstance(outputs, torch.Tensor):
519
- dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False)
520
- else:
521
- dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
522
- return outputs
523
-
524
- def shard_tensor(
525
- self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
526
- ) -> torch.Tensor:
527
- self.shard = [Replicate()]
528
- return param[...].to(device=device, dtype=dtype)
529
-
530
661
  def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
531
662
  distribute_module(
532
663
  module,
533
664
  device_mesh,
534
- partial(self._prepare_input_fn, None, None),
535
- partial(self._prepare_output_fn, None, None),
665
+ self._prepare_input_fn,
666
+ self._prepare_output_fn,
536
667
  )
537
668
 
669
+ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
670
+ """
671
+ Compute the expected shape after TP sharding for a given full shape.
672
+
673
+ Args:
674
+ full_shape: The full (unsharded) parameter shape
538
675
 
539
- class IsolatedParallel(TensorParallelLayer):
676
+ Returns:
677
+ The expected sharded shape for this rank
678
+ """
679
+ # Default: no sharding, return full shape
680
+ return tuple(full_shape)
681
+
682
+
683
+ class ColwiseParallel(TensorParallelLayer):
540
684
  """
541
- This class is used to isolate computation in a TP layer from the rest of the world.
542
- Parameters need to be LOCAL, so not dtensors
685
+ Column-wise parallel: weight is sharded on dim -2 (output features).
686
+ Forward: input replicated -> output sharded on last dim.
687
+ If gather_output=True, output is all-gathered to produce full tensor.
543
688
  """
544
689
 
545
- @staticmethod
546
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh=None):
547
- # annotate module input placements/sharding with input_layouts
548
- input_tensor = inputs[0]
549
- if isinstance(input_tensor, DTensor):
550
- input_tensor = input_tensor.to_local()
551
- return input_tensor
690
+ def __init__(self, gather_output: bool = False, **kwargs):
691
+ super().__init__(**kwargs)
692
+ self.gather_output = gather_output
552
693
 
553
- @staticmethod
554
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh=None):
555
- # TODO: figure out dynamo support for instance method and switch this to instance method
694
+ def _prepare_input_fn(self, mod, inputs, device_mesh):
695
+ input_tensor = inputs[0] if inputs else inputs
696
+ return all_reduce_backward(input_tensor, device_mesh)
697
+
698
+ def _prepare_output_fn(self, mod, outputs, device_mesh):
699
+ if self.gather_output:
700
+ return all_gather(outputs, device_mesh)
556
701
  return outputs
557
702
 
558
703
  def shard_tensor(
559
704
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
560
705
  ) -> torch.Tensor:
561
- parameter = param[...].to(device=device, dtype=dtype)
562
- if self.device_mesh is not None:
563
- parameter = parameter / self.device_mesh.size()
564
- self.shard = None
565
- return parameter
566
-
567
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
568
- parameter = self.shard_tensor(param, dtype=dtype)
569
- if to_contiguous:
570
- parameter = parameter.contiguous()
571
- # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
572
- return parameter
706
+ # If only 1 dim, shard this one (usually it's a `bias`)
707
+ dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
708
+ if dim == 1:
709
+ parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
710
+ else:
711
+ parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
712
+ return parameter.to(device=device, dtype=dtype)
573
713
 
574
- def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
575
- distribute_module(
576
- module,
577
- device_mesh,
578
- partial(self._prepare_input_fn, None, None),
579
- partial(self._prepare_output_fn, None, None),
580
- )
714
+ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
715
+ world_size = self.device_mesh.size()
716
+ shape = list(full_shape)
717
+ # Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
718
+ dim = -1 if len(shape) == 1 else -2
719
+ dim = len(shape) + dim if dim < 0 else dim
720
+ shard_size = math.ceil(shape[dim] / world_size)
721
+ start = self.rank * shard_size
722
+ end = min(start + shard_size, shape[dim])
723
+ shape[dim] = end - start
724
+ return tuple(shape)
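Purely as an illustration of get_expected_sharded_shape for the colwise case; the helper and numbers below are hypothetical and not part of the diff.

    import math

    def colwise_expected_shape(full_shape, world_size, rank):
        shape = list(full_shape)
        dim = len(shape) - 1 if len(shape) == 1 else len(shape) - 2  # bias shards on -1, weight on -2
        shard = math.ceil(shape[dim] / world_size)
        start, end = rank * shard, min((rank + 1) * shard, shape[dim])
        shape[dim] = end - start
        return tuple(shape)

    assert colwise_expected_shape((4096, 1024), world_size=4, rank=0) == (1024, 1024)  # weight
    assert colwise_expected_shape((4096,), world_size=4, rank=3) == (1024,)            # bias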
581
725
 
582
726
 
583
- class ReplicateParallel(TensorParallelLayer):
727
+ class RowwiseParallel(TensorParallelLayer):
584
728
  """
585
- This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
729
+ Row-wise parallel: weight is sharded on dim -1 (input features).
730
+ Forward: input (optionally split) -> output partial -> all-reduce to replicate.
731
+
732
+ Args:
733
+ split_input: If True, splits replicated input before matmul. Use when input
734
+ comes from a non-parallelizable operation (chunk/slice).
735
+ Default False (expects pre-sharded input from colwise layer).
586
736
  """
587
737
 
588
- def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
738
+ def __init__(self, split_input: bool = False, **kwargs):
589
739
  super().__init__(**kwargs)
590
- self.input_layouts = (Replicate(),)
591
- self.output_layouts = (Replicate(),)
592
- self.desired_input_layouts = (Replicate(),)
593
- self.use_local_output = use_local_output
594
- self.use_dtensor = use_dtensor
740
+ self.split_input = split_input
595
741
 
596
- @staticmethod
597
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
598
- # TODO: figure out dynamo support for instance method and switch this to instance method
599
- # annotate module input placements/sharding with input_layouts
600
- input_tensor = inputs[0]
601
- if not isinstance(input_tensor, DTensor):
602
- input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
742
+ def _prepare_input_fn(self, mod, inputs, device_mesh):
743
+ if hasattr(mod, "bias") and mod.bias is not None:
744
+ mod._bias = mod.bias
745
+ mod.bias = None
746
+
747
+ input_tensor = inputs[0] if inputs else inputs
603
748
 
749
+ if self.split_input:
750
+ # Input is replicated, split it to match sharded weight
751
+ return split(input_tensor, device_mesh)
604
752
  return input_tensor
605
753
 
606
- @staticmethod
607
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
608
- return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
754
+ def _prepare_output_fn(self, mod, outputs, device_mesh):
755
+ outputs = all_reduce_forward(outputs, device_mesh)
756
+ if hasattr(mod, "_bias") and mod._bias is not None:
757
+ outputs = outputs + mod._bias
758
+ return outputs
609
759
 
610
760
  def shard_tensor(
611
761
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
612
762
  ) -> torch.Tensor:
613
- self.shard = [Replicate()]
614
- return param[...].to(device=device, dtype=dtype)
615
-
616
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
617
- parameter = self.shard_tensor(param, dtype=dtype)
618
- if self.use_dtensor:
619
- parameter = DTensor.from_local(parameter, self.device_mesh, self.shard, run_check=False)
620
- return parameter
621
-
763
+ # If only 1 dim, it should not be sharded (usually it's a `bias`)
764
+ dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
765
+ if dim == 1:
766
+ parameter = param[...]
767
+ else:
768
+ parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
769
+ return parameter.to(device=device, dtype=dtype)
622
770
 
623
- class ColwiseParallel(TensorParallelLayer):
624
- """
625
- General tensor parallel layer for transformers.
626
- """
771
+ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
772
+ # 1D tensors (bias) are NOT sharded in rowwise
773
+ if len(full_shape) == 1:
774
+ return tuple(full_shape)
775
+ world_size = self.device_mesh.size()
776
+ shape = list(full_shape)
777
+ dim = -1
778
+ dim = len(shape) + dim if dim < 0 else dim
779
+ shard_size = math.ceil(shape[dim] / world_size)
780
+ start = self.rank * shard_size
781
+ end = min(start + shard_size, shape[dim])
782
+ shape[dim] = end - start
783
+ return tuple(shape)
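And a small numeric check, again illustrative only, of why the rowwise output needs the all-reduce in _prepare_output_fn: sharding the contraction dimension turns the matmul into per-rank partial sums.

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 8)
    w = torch.randn(8, 4)  # plain matmul layout; the contraction dim (8) is the one rowwise shards
    full = x @ w
    partials = [x[:, i * 4:(i + 1) * 4] @ w[i * 4:(i + 1) * 4] for i in range(2)]  # two "ranks"
    assert torch.allclose(full, sum(partials), atol=1e-5)  # summing partials == the all-reduce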
627
784
 
628
- def __init__(
629
- self,
630
- input_layouts: Placement | None = None,
631
- output_layouts: Placement | None = None,
632
- use_local_output: bool = True,
633
- use_dtensor=True,
634
- **kwargs,
635
- ):
636
- super().__init__(**kwargs)
637
- self.input_layouts = (input_layouts or Replicate(),)
638
- self.output_layouts = (output_layouts or Shard(-1),)
639
- self.desired_input_layouts = (Replicate(),)
640
- self.use_local_output = use_local_output
641
- self.use_dtensor = use_dtensor
642
785
 
643
- @staticmethod
644
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
645
- # TODO: figure out dynamo support for instance method and switch this to instance method
646
- # annotate module input placements/sharding with input_layouts
647
- input_tensor = inputs[0]
648
- if not isinstance(input_tensor, DTensor):
649
- input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
650
-
651
- # transform the input layouts to the desired layouts of ColwiseParallel
652
- if input_layouts != desired_input_layouts:
653
- input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
654
- return input_tensor
786
+ class PackedColwiseParallel(ColwiseParallel):
787
+ """Packed column-wise parallel for fused weights like gate_up_proj."""
655
788
 
656
789
  def shard_tensor(
657
790
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
@@ -659,333 +792,144 @@ class ColwiseParallel(TensorParallelLayer):
659
792
  # If only 1 dim, shard this one (usually it's a `bias`)
660
793
  dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
661
794
  if dim == 1:
662
- parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1, tensor_idx)
663
- shard = [Shard(-1)]
795
+ parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
664
796
  else:
665
- shard = [Shard(-2)]
666
- parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2, tensor_idx)
667
- self.shard = shard
797
+ expected_shape = self.get_expected_sharded_shape(self.empty_param.shape)
798
+ if dim < len(expected_shape):
799
+ # Input is unpacked (e.g., gate_proj that will be concatenated to gate_up_proj)
800
+ # Use regular tensor shard - concatenation will happen after
801
+ parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
802
+ else:
803
+ # Input is already packed, use packed sharding
804
+ parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
668
805
  return parameter.to(device=device, dtype=dtype)
669
806
 
670
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
671
- # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
672
- # means Colwise as Linear is input * weight^T + bias, where
673
- # weight would become Shard(1)
674
- parameter = self.shard_tensor(param, dtype=dtype)
675
- if to_contiguous:
676
- parameter = parameter.contiguous()
677
- if self.use_dtensor:
678
- parameter = DTensor.from_local(
679
- parameter,
680
- self.device_mesh,
681
- self.shard,
682
- run_check=False,
683
- shape=self.empty_param.size(),
684
- stride=self.empty_param.stride(),
685
- )
686
- return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
687
-
688
- @staticmethod
689
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
690
- # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
691
- if outputs.placements != output_layouts:
692
- outputs = outputs.redistribute(placements=output_layouts, async_op=False)
693
- # back to local tensor
694
- return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
695
807
 
808
+ class PackedRowwiseParallel(RowwiseParallel):
809
+ """Packed row-wise parallel for fused weights like gate_up_proj."""
696
810
 
697
- class PackedColwiseParallel(ColwiseParallel):
698
811
  def shard_tensor(
699
812
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
700
813
  ) -> torch.Tensor:
701
- parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
814
+ # If only 1 dim, it should not be sharded (usually it's a `bias`)
815
+ dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
816
+ if dim == 1:
817
+ parameter = param[...]
818
+ else:
819
+ # Check if input tensor is unpacked (shape mismatch with expected packed size)
820
+ # This happens when using MergeModulelist + Concatenate for fused weights like gate_up_proj
821
+ param_shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
822
+ expected_packed_dim = self.empty_param.shape[-1] if self.empty_param.dim() >= 1 else 0
823
+ actual_dim = param_shape[-1] if len(param_shape) >= 1 else 0
824
+
825
+ if actual_dim < expected_packed_dim:
826
+ # Input is unpacked, use regular tensor shard
827
+ parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
828
+ else:
829
+ # Input is already packed, use packed sharding
830
+ parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
702
831
  return parameter.to(device=device, dtype=dtype)
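To see why the packed variants exist at all, here is a toy picture, not from the diff, of a gate_up_proj-style packed weight: each rank needs the matching slice of both halves, which a plain contiguous chunk would not give it.

    import torch

    gate = torch.arange(0, 4).reshape(4, 1).float()
    up = torch.arange(10, 14).reshape(4, 1).float()
    packed = torch.cat([gate, up], dim=0)                          # rows 0-3 are gate, rows 4-7 are up

    rank0_correct = torch.cat([packed[0:2], packed[4:6]], dim=0)   # half of gate AND half of up
    rank0_naive = packed[:4]                                       # a contiguous chunk is gate only
    assert not torch.equal(rank0_correct, rank0_naive)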
703
832
 
704
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
705
- # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
706
- # means Colwise as Linear is input * weight^T + bias, where
707
- # weight would become Shard(1)
708
- parameter = self.shard_tensor(param, dtype=dtype)
709
- if to_contiguous:
710
- parameter = parameter.contiguous()
711
- if self.use_dtensor:
712
- parameter = DTensor.from_local(parameter, self.device_mesh, [Shard(-2)], run_check=False)
713
- return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
714
833
 
834
+ class EmbeddingParallel(TensorParallelLayer):
835
+ """EmbeddingParallel: shards embedding table, handles masked lookups for vocab parallelism."""
715
836
 
716
- class LocalColwiseParallel(ColwiseParallel):
717
- """
718
- Colwise parallel with use_dtensor=False for local tensor operations.
719
- """
837
+ def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
838
+ super().__init__(**kwargs)
839
+ self.embedding_dim_sharding = embedding_dim_sharding
720
840
 
721
- def __init__(self, **kwargs):
722
- super().__init__(use_dtensor=False, **kwargs)
841
+ def _prepare_input_fn(self, mod, inputs, device_mesh):
842
+ input_tensor = inputs[0] if inputs else inputs
723
843
 
844
+ # For vocab-parallel (dim 0), we need to handle masking and offsetting
845
+ if self.embedding_dim_sharding == 0:
846
+ rank = device_mesh.get_local_rank()
724
847
 
725
- class ColwiseParallelReplicate(ColwiseParallel):
726
- """
727
- Colwise parallel with output layouts replicated.
728
- """
848
+ # Get vocab range for this rank
849
+ # Use weight.shape[0] to get the actual local (sharded) size, not num_embeddings
850
+ # which may not be updated after sharding
851
+ per_partition_size = mod.weight.shape[0]
852
+ vocab_start_index = rank * per_partition_size
853
+ vocab_end_index = vocab_start_index + per_partition_size
729
854
 
730
- def __init__(self, **kwargs):
731
- super().__init__(output_layouts=Replicate(), **kwargs)
855
+ # Build mask for out-of-vocabulary tokens
856
+ input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
857
+ mod._input_mask = input_mask
732
858
 
859
+ # Offset input to local indices and mask invalid ones
860
+ masked_input = input_tensor.clone() - vocab_start_index
861
+ masked_input[input_mask] = 0 # Set to valid local index
733
862
 
734
- class RowwiseParallel(TensorParallelLayer):
735
- """
736
- Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
737
- Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.
738
- (i.e. MLP, Attention)
739
-
740
- Keyword Args:
741
- input_layouts (Placement, optional):
742
- The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
743
- become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
744
- output_layouts (Placement, optional):
745
- The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
746
- with the user desired layout. If not specified, the output tensor is replicated.
747
- use_local_output (bool, optional):
748
- Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
749
- Returns:
750
- A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
751
- """
863
+ return masked_input
752
864
 
753
- def __init__(
754
- self,
755
- input_layouts: Placement | None = None,
756
- output_layouts: Placement | None = None,
757
- use_local_output: bool = True,
758
- use_dtensor: bool = True,
759
- **kwargs,
760
- ):
761
- super().__init__(**kwargs)
762
- self.input_layouts = (input_layouts or Shard(-1),)
763
- self.output_layouts = (output_layouts or Replicate(),)
764
- self.use_local_output = use_local_output
765
- self.use_dtensor = use_dtensor
865
+ return input_tensor
866
+
867
+ def _prepare_output_fn(self, mod, outputs, device_mesh):
868
+ # For vocab-parallel (dim 0), zero out embeddings for out-of-range tokens before all-reduce
869
+ if self.embedding_dim_sharding == 0 and hasattr(mod, "_input_mask"):
870
+ input_mask = mod._input_mask
871
+ # Use multiplication instead of in-place assignment to preserve gradients
872
+ mask_expanded = input_mask.unsqueeze(-1).expand_as(outputs)
873
+ outputs = outputs * (~mask_expanded).float()
874
+ del mod._input_mask
875
+
876
+ return all_reduce_forward(outputs, device_mesh)
766
877
 
767
878
  def shard_tensor(
768
879
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
769
880
  ) -> torch.Tensor:
770
- # If only 1 dim, it should not be sharded (usually it's a `bias`)
881
+ # If only 1 dim, shard this one (usually it's a `bias`)
771
882
  dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
772
883
  if dim == 1:
773
- shard = [Replicate()]
774
- parameter = param[...]
884
+ parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
775
885
  else:
776
886
  parameter = get_tensor_shard(
777
- param, self.empty_param, self.device_mesh, self.rank, -1, tensor_idx=tensor_idx
778
- )
779
- shard = [Shard(-1)]
780
- self.shard = shard
781
- return parameter.to(device=device, dtype=dtype)
782
-
783
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
784
- # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
785
- # means Rowwise as nn.Linear is input * weight^T + bias, where
786
- # weight would become Shard(0)
787
- parameter = self.shard_tensor(param, dtype=dtype)
788
- if to_contiguous:
789
- parameter = parameter.contiguous()
790
- if self.use_dtensor:
791
- parameter = DTensor.from_local(
792
- parameter,
887
+ param,
888
+ self.empty_param,
793
889
  self.device_mesh,
794
- self.shard,
795
- run_check=False,
796
- shape=self.empty_param.size(),
797
- stride=self.empty_param.stride(),
798
- )
799
- return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
800
-
801
- @staticmethod
802
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
803
- if hasattr(mod, "bias") and mod.bias is not None:
804
- mod._bias = mod.bias.to_local()
805
- mod.bias = None
806
-
807
- input_tensor = inputs[0]
808
- if not isinstance(input_tensor, DTensor):
809
- input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
810
-
811
- if input_layouts != desired_input_layouts:
812
- input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
813
- return input_tensor
814
-
815
- @staticmethod
816
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
817
- # Rowwise sharding produces partial output, depending on output layouts:
818
- # 1. to replicate -> allreduce
819
- # 2. to shard -> reduce_scatter
820
- if outputs.placements != output_layouts:
821
- outputs = outputs.redistribute(placements=output_layouts, async_op=True)
822
- outputs = outputs.to_local() # otherwise the `+=` op will gather
823
- if hasattr(mod, "_bias"):
824
- outputs = outputs + mod._bias
825
- # back to local tensor if use_local_output is True
826
- return outputs
827
-
828
- def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
829
- module._distribute_module_applied = True
830
- if self.use_dtensor:
831
- if isinstance(module, nn.Linear):
832
- # rowwise linear runtime sharding requires input tensor shard on last dim
833
- self.desired_input_layouts: tuple[Placement, ...] = (Shard(-1),)
834
- elif isinstance(module, nn.Embedding):
835
- # rowwise embedding runtime sharding requires input tensor replicated
836
- self.desired_input_layouts = (Replicate(),)
837
- elif isinstance(module, nn.Parameter):
838
- # rowwise embedding runtime sharding requires input tensor replicated
839
- self.desired_input_layouts = (Shard(-1),)
840
- else:
841
- raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!")
842
-
843
- distribute_module(
844
- module,
845
- device_mesh,
846
- partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
847
- partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
890
+ self.rank,
891
+ self.embedding_dim_sharding,
848
892
  )
849
-
850
-
851
- class PackedRowwiseParallel(RowwiseParallel):
852
- def shard_tensor(
853
- self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
854
- ) -> torch.Tensor:
855
- parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
856
893
  return parameter.to(device=device, dtype=dtype)
857
894
 
858
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
859
- # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
860
- # means Colwise as Linear is input * weight^T + bias, where
861
- # weight would become Shard(1)
862
- parameter = self.shard_tensor(param, dtype=dtype)
863
- if to_contiguous:
864
- parameter = parameter.contiguous()
865
- if self.use_dtensor:
866
- parameter = DTensor.from_local(parameter, self.device_mesh, [Shard(-1)], run_check=False)
867
- return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
868
-
869
-
870
- class LocalRowwiseParallel(RowwiseParallel):
871
- """
872
- Rowwise parallel with use_dtensor=False for local tensor operations.
873
- """
874
-
875
- def __init__(self, **kwargs):
876
- super().__init__(use_dtensor=False, **kwargs)
877
-
878
-
879
- class LocalPackedRowwiseParallel(PackedRowwiseParallel):
880
- """
881
- Packed rowwise parallel with use_dtensor=False for local tensor operations.
882
- """
883
-
884
- def __init__(self, **kwargs):
885
- super().__init__(use_dtensor=False, **kwargs)
886
-
887
-
888
- class RowwiseParallelReplicate(RowwiseParallel):
889
- """
890
- Rowwise parallel with input layouts replicated.
891
- """
892
-
893
- def __init__(self, **kwargs):
894
- super().__init__(input_layouts=Replicate(), **kwargs)
895
+ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
896
+ world_size = self.device_mesh.size()
897
+ shape = list(full_shape)
898
+ # EmbeddingParallel shards on self.embedding_dim_sharding (default 0)
899
+ # 1D tensors (bias) shard on dim -1
900
+ dim = -1 if len(shape) == 1 else self.embedding_dim_sharding
901
+ dim = len(shape) + dim if dim < 0 else dim
902
+ shard_size = math.ceil(shape[dim] / world_size)
903
+ start = self.rank * shard_size
904
+ end = min(start + shard_size, shape[dim])
905
+ shape[dim] = end - start
906
+ return tuple(shape)
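A compact, self-contained sketch (hypothetical values, not part of the diff) of the mask-and-offset trick EmbeddingParallel uses for vocab parallelism, for a rank that owns vocabulary rows [4, 8):

    import torch

    vocab_start, vocab_end = 4, 8
    local_table = torch.arange(4 * 3, dtype=torch.float32).reshape(4, 3)  # this rank's 4 rows

    input_ids = torch.tensor([1, 5, 7, 9])
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)   # tokens owned by other ranks
    local_ids = input_ids - vocab_start
    local_ids[mask] = 0                                           # any valid local index will do
    out = local_table[local_ids] * (~mask).unsqueeze(-1).float()  # zero the rows we do not own
    # all_reduce_forward(out, mesh) would then sum in the rows produced by the owning ranks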
895
907
 
896
908
 
897
909
  class SequenceParallel(TensorParallelLayer):
898
910
  """
899
- SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with
900
- input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
901
- `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
902
-
903
- This style implements the operation that is described in the paper
904
- `Reducing Activation Recomputation in Large Transformer Models <https://huggingface.co/papers/2205.05198>`__
905
-
906
- If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
907
- on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
908
- passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
909
- redistribute the input to be sharded on the sequence dimension.
910
-
911
- The output of the ``nn.Module`` will be sharded on the sequence dimension.
912
-
913
- Keyword Args:
914
- sequence_dim (int, optional):
915
- The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
916
- become a DTensor that is sharded on the sequence dimension, default: 1.
917
- use_local_output (bool, optional):
918
- Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
919
- Returns:
920
- A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
921
-
922
- Example::
923
- >>> # xdoctest: +SKIP(failing)
924
- >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
925
- >>> from torch.distributed.device_mesh import init_device_mesh
926
- >>> ...
927
- >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
928
- >>> tp_mesh = init_device_mesh("cuda", (8,))
929
- >>>
930
- >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
931
- >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
932
- >>>
933
- >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
934
- >>> ...
935
-
936
- .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
937
- ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
938
- inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
939
- to ensure that they are replicated.
911
+ Sequence Parallel: input/output sharded on sequence dimension.
912
+ Weights are replicated.
940
913
  """
941
914
 
942
915
  def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
943
916
  super().__init__(**kwargs)
944
- self.input_layouts = (Replicate(),)
945
- self.desired_input_layouts = (Shard(1),)
946
- self.output_layouts = (Replicate(),)
947
- self.use_local_output = use_local_output
948
- self.use_dtensor = True
949
- self.sequence_sharding = (Shard(sequence_dim),)
950
- self.use_local_output = use_local_output
917
+ self.sequence_dim = sequence_dim
918
+
919
+ def _prepare_input_fn(self, mod, inputs, device_mesh):
920
+ input_tensor = inputs[0] if inputs else inputs
921
+ # For sequence parallel, input is sharded on sequence dim
922
+ # All-gather before running the layer, then reduce-scatter its output afterwards
923
+ return all_gather(input_tensor, device_mesh)
924
+
925
+ def _prepare_output_fn(self, mod, outputs, device_mesh):
926
+ return reduce_scatter(outputs, device_mesh)
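The `all_gather` and `reduce_scatter` helpers called here are not shown in this hunk. As a rough sketch of how such a pair is typically made autograd-aware (the class name is hypothetical, not from this file): all-gather the sequence shards in forward, reduce-scatter the gradient back to per-rank shards in backward.

import torch
import torch.distributed as dist

class _AllGatherSequence(torch.autograd.Function):
    # Forward: gather the per-rank sequence shards into the full sequence.
    # Backward: reduce-scatter the gradient so each rank keeps only the summed slice for its shard.
    @staticmethod
    def forward(ctx, x_local, dim=1):
        ctx.dim = dim
        world_size = dist.get_world_size()
        shards = [torch.empty_like(x_local) for _ in range(world_size)]
        dist.all_gather(shards, x_local.contiguous())
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        grad_shards = [g.contiguous() for g in grad_output.chunk(world_size, dim=ctx.dim)]
        grad_local = torch.empty_like(grad_shards[dist.get_rank()])
        dist.reduce_scatter(grad_local, grad_shards)
        return grad_local, None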
951
927
 
952
928
  def shard_tensor(
953
- self,
954
- param: torch.Tensor,
955
- tensor_idx=None,
956
- device=None,
957
- dtype=None,
929
+ self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
958
930
  ) -> torch.Tensor:
959
- self.shard = [Replicate()]
960
931
  return param[...].to(device=device, dtype=dtype)
961
932
 
962
- @staticmethod
963
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
964
- input_tensor = inputs[0]
965
- if not isinstance(input_tensor, DTensor):
966
- input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
967
- if input_layouts != desired_input_layouts:
968
- input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
969
- return input_tensor
970
-
971
- @staticmethod
972
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
973
- outputs = outputs.redistribute(
974
- placements=(Replicate(),), async_op=True
975
- ) # maybe we have to replicate ? because next layer is not sharded
976
- return outputs.to_local() # if use_local_output else outputs
977
-
978
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
979
- # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
980
- # means Colwise as Linear is input * weight^T + bias, where
981
- # weight would become Shard(1)
982
- parameter = self.shard_tensor(param, dtype=dtype)
983
- if to_contiguous:
984
- parameter = parameter.contiguous()
985
- if self.use_dtensor:
986
- parameter = DTensor.from_local(parameter, self.device_mesh, [Replicate()], run_check=False)
987
- return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
988
-
989
933
 
990
934
  class GroupedGemmParallel(TensorParallelLayer):
991
935
  """
@@ -994,7 +938,6 @@ class GroupedGemmParallel(TensorParallelLayer):
994
938
 
995
939
  def __init__(self, **kwargs):
996
940
  super().__init__(**kwargs)
997
- self.use_dtensor = False
998
941
 
999
942
  def shard_tensor(
1000
943
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
@@ -1005,15 +948,30 @@ class GroupedGemmParallel(TensorParallelLayer):
1005
948
  f"Global number of experts must be divisible by number of devices: {global_num_experts} % {self.device_mesh.size()} != 0"
1006
949
  )
1007
950
  local_num_experts = global_num_experts // self.device_mesh.size()
1008
- parameter = param[self.rank * local_num_experts : (self.rank + 1) * local_num_experts]
1009
- self.shard = None
1010
- return parameter.to(device=device, dtype=dtype)
1011
-
1012
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
1013
- parameter = self.shard_tensor(param, dtype=dtype)
1014
- if to_contiguous:
1015
- parameter = parameter.contiguous()
1016
- return parameter
951
+ shard_size = local_num_experts
952
+ if isinstance(device, torch.device):
953
+ device = device.index if device.index is not None else 0
954
+ start = device * shard_size
955
+ end = (device + 1) * shard_size
956
+ # Special case: we don't "shard"; we send this entire tensor whole to the correct rank.
957
+ shape = param.get_shape() if not isinstance(param, torch.Tensor) else param.shape
958
+ if tensor_idx is not None and start <= tensor_idx < end:
959
+ # this tensor does need to be materialized on this device:
960
+ return param[:].to(device=device)
961
+ elif tensor_idx is None: # a bias or a weight, but already merged
962
+ return param[start:end].to(device=device, dtype=dtype)
963
+ elif len(shape) >= 1 and tensor_idx is not None:
964
+ return None
965
+ else: # bias case
966
+ return param[:].to(device=device, dtype=dtype)
967
+
968
+ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
969
+ # GroupedGemm shards on dim 0 (experts dimension)
970
+ world_size = self.device_mesh.size()
971
+ shape = list(full_shape)
972
+ local_num_experts = shape[0] // world_size
973
+ shape[0] = local_num_experts
974
+ return tuple(shape)
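Concretely, the expert dimension is simply divided by the mesh size; with made-up sizes, 128 experts over an 8-way mesh leave 16 experts per rank:

full_shape = (128, 4096, 22016)   # (num_experts, hidden, 2 * intermediate) -- illustrative sizes
world_size = 8
local_shape = (full_shape[0] // world_size,) + full_shape[1:]
assert local_shape == (16, 4096, 22016)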
1017
975
 
1018
976
 
1019
977
  class RouterParallel(TensorParallelLayer):
@@ -1021,20 +979,15 @@ class RouterParallel(TensorParallelLayer):
1021
979
 Allows reshaping the router scores to support running expert parallelism.
1022
980
  """
1023
981
 
1024
- def __init__(self, use_dtensor: bool = False, *args, **kwargs):
982
+ def __init__(self, **kwargs):
1025
983
  super().__init__(**kwargs)
1026
- self.args = args
1027
- self.use_dtensor = use_dtensor
1028
984
 
1029
985
  @staticmethod
1030
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
1031
- input_tensor = inputs[0]
1032
- if isinstance(input_tensor, DTensor):
1033
- raise NotImplementedError("RouterParallel does not support DTensor input for now")
1034
- return input_tensor
986
+ def _prepare_input_fn(mod, inputs, device_mesh):
987
+ return inputs[0] if inputs else inputs
1035
988
 
1036
989
  @staticmethod
1037
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
990
+ def _prepare_output_fn(mod, outputs, device_mesh):
1038
991
  """
1039
992
  Imagine if you had 4 tokens, top_k = 4, and 128experts.
1040
993
  With EP = 8. The num_local_expert should be 128/8 = 16
@@ -1076,6 +1029,7 @@ class RouterParallel(TensorParallelLayer):
1076
1029
  )
1077
1030
  num_local_experts = mod.num_experts // ep_size
1078
1031
  router_logits, router_scores, router_indices = outputs
1032
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)
1079
1033
  router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
1080
1034
  router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
1081
1035
 # As -1 % 1 is 0, fmod would erase the -1 sentinel, so we can only use masked_fill when num_local_experts is 1
@@ -1083,32 +1037,54 @@ class RouterParallel(TensorParallelLayer):
1083
1037
  router_indices = torch.fmod(router_indices, num_local_experts)
1084
1038
  else:
1085
1039
  router_indices = router_indices.masked_fill(router_indices > 0, 0).masked_fill(router_indices < 0, -1)
1086
- router_indices = router_indices.masked_fill(
1087
- router_indices == -1, num_local_experts
1088
- ) # masking class for one hot
1040
+ router_indices = router_indices.masked_fill(router_indices == -1, num_local_experts)
1089
1041
  return router_logits, router_scores, router_indices
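A worked example of the index remapping above, using the numbers from the docstring (128 experts, EP = 8, so 16 local experts) and an arbitrary pick of global expert ids for one token; on EP rank 1 only global experts 16-31 are local, and everything else is redirected to the extra "masked" slot num_local_experts:

import torch

num_experts, ep_size, ep_rank = 128, 8, 1
num_local_experts = num_experts // ep_size                      # 16
router_indices = torch.tensor([[17, 40, 3, 120]])               # top_k = 4 global expert ids

masked = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
local = torch.fmod(masked, num_local_experts)                   # global 17 -> local slot 1
local = local.masked_fill(local == -1, num_local_experts)       # non-local picks -> slot 16
print(local)                                                    # tensor([[ 1, 16, 16, 16]])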
1090
1042
 
1091
1043
  def shard_tensor(
1092
1044
  self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
1093
1045
  ) -> torch.Tensor:
1094
- self.shard = None
1095
1046
  return param[...].to(device=device, dtype=dtype)
1096
1047
 
1097
- def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
1098
- # TODO: i'd like for this to be the default
1099
- parameter = self.shard_tensor(param, dtype=dtype)
1100
- if to_contiguous:
1101
- parameter = parameter.contiguous()
1102
- return parameter
1103
1048
 
1104
- def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
1105
- # TODO: need an abstract Parallel class that is different from TensorParallelLayer
1106
- distribute_module(
1107
- module,
1108
- device_mesh,
1109
- partial(self._prepare_input_fn, None, None),
1110
- partial(self._prepare_output_fn, None, None),
1111
- )
1049
+ class MoeTensorParalellExperts(TensorParallelLayer):
1050
+ """
1051
+ Note: for tensor parallel, this experts layer handles gradient sync:
1052
+ - all_reduce_backward on hidden_states (for colwise gate_up_proj gradient)
1053
+ - all_reduce_backward on top_k_weights (for router gradient)
1054
+ - all_reduce_forward on output (for partial expert outputs)
1055
+ """
1056
+
1057
+ def __init__(self, **kwargs):
1058
+ super().__init__(**kwargs)
1059
+
1060
+ @staticmethod
1061
+ def _prepare_input_fn(mod, inputs, device_mesh):
1062
+ # inputs = (hidden_states, top_k_index, top_k_weights)
1063
+ hidden_states = inputs[0]
1064
+ top_k_index = inputs[1]
1065
+ top_k_weights = inputs[2]
1066
+
1067
+ # all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient
1068
+ hidden_states = all_reduce_backward(hidden_states, device_mesh)
1069
+
1070
+ # all_reduce_backward on routing weights for correct router gradient
1071
+ # This is needed because ∂L/∂routing_weights = ∂L/∂output * partial_expert_output
1072
+ # and partial_expert_output is different on each GPU before all-reduce
1073
+ top_k_weights = all_reduce_backward(top_k_weights, device_mesh)
1074
+
1075
+ return (hidden_states, top_k_index, top_k_weights)
1076
+
1077
+ @staticmethod
1078
+ def _prepare_output_fn(mod, outputs, device_mesh):
1079
+ # all_reduce_forward to sum partial expert outputs across GPUs
1080
+ return all_reduce_forward(outputs, device_mesh)
1081
+
1082
+ def shard_tensor(
1083
+ self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
1084
+ ) -> torch.Tensor:
1085
+ # This class doesn't shard tensors - sharding is handled by packed_colwise/rowwise
1086
+ # on the individual weight tensors (gate_up_proj/down_proj)
1087
+ return param[...].to(device=device, dtype=dtype)
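The `all_reduce_backward` / `all_reduce_forward` helpers referenced above are not shown in this hunk; a typical implementation of that pair looks roughly like the following sketch (class names hypothetical): identity in forward with a summed gradient, and a summed forward with an identity gradient.

import torch
import torch.distributed as dist

class _AllReduceBackward(torch.autograd.Function):
    # Identity in forward; sum gradients from all ranks in backward.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad

class _AllReduceForward(torch.autograd.Function):
    # Sum partial activations from all ranks in forward; identity in backward.
    @staticmethod
    def forward(ctx, x):
        out = x.clone()
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output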
1112
1088
 
1113
1089
 
1114
1090
  class ParallelInterface(GeneralInterface):
@@ -1116,69 +1092,152 @@ class ParallelInterface(GeneralInterface):
1116
1092
  # a new instance is created (in order to locally override a given entry)
1117
1093
  _global_mapping = (
1118
1094
  {
1095
+ "embedding_rowwise": EmbeddingParallel(embedding_dim_sharding=0),
1096
+ "colwise_gather_output": ColwiseParallel(gather_output=True),
1119
1097
  "colwise": ColwiseParallel(),
1120
1098
  "rowwise": RowwiseParallel(),
1121
- "colwise_rep": ColwiseParallelReplicate(),
1122
- "rowwise_rep": RowwiseParallelReplicate(),
1123
- "local_colwise": LocalColwiseParallel(),
1124
- "local_rowwise": LocalRowwiseParallel(),
1125
- "local": IsolatedParallel(),
1126
- "gather": GatherParallel(),
1127
- "local_packed_rowwise": LocalPackedRowwiseParallel(),
1099
+ "rowwise_split_input": RowwiseParallel(split_input=True),
1100
+ "packed_colwise": PackedColwiseParallel(),
1101
+ "packed_rowwise": PackedRowwiseParallel(),
1128
1102
  "sequence_parallel": SequenceParallel(),
1129
- "replicate": ReplicateParallel(),
1130
1103
  "grouped_gemm": GroupedGemmParallel(),
1131
1104
  "ep_router": RouterParallel(),
1105
+ "moe_tp_experts": MoeTensorParalellExperts(),
1132
1106
  }
1133
- if is_torch_greater_or_equal("2.5") and _torch_distributed_available
1107
+ if is_torch_available() and _torch_distributed_available
1134
1108
  else {}
1135
1109
  )
1136
1110
 
1111
+ # Map plan names to sharding dimensions for weights
1112
+ # For weights: colwise shards dim -2, rowwise shards dim -1
1113
+ # For embedding: rowwise shards dim 0 (vocab), colwise shards dim -2 (hidden)
1114
+ plan_to_weight_dim: dict[str, int | None] = {
1115
+ "colwise": -2,
1116
+ "colwise_gather_output": -2,
1117
+ "packed_colwise": -2,
1118
+ "rowwise": -1,
1119
+ "rowwise_split_input": -1,
1120
+ "packed_rowwise": -1,
1121
+ "embedding_rowwise": 0,
1122
+ "sequence_parallel": None,
1123
+ }
1124
+
1125
+ # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
1126
+ plan_to_bias_dim: dict[str, int | None] = {
1127
+ "colwise": -1,
1128
+ "colwise_gather_output": -1,
1129
+ "packed_colwise": -1,
1130
+ "rowwise": None,
1131
+ "rowwise_split_input": None,
1132
+ "packed_rowwise": None,
1133
+ "embedding_rowwise": None,
1134
+ "sequence_parallel": None,
1135
+ }
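For a plain nn.Linear, whose weight is stored as (out_features, in_features), these dimensions work out as follows (the sizes are illustrative):

def sharded(shape, dim, tp_size):
    s = list(shape)
    s[dim] //= tp_size
    return tuple(s)

weight = (11008, 4096)                          # (out_features, in_features)
assert sharded(weight, -2, 4) == (2752, 4096)   # "colwise": split the output features
assert sharded(weight, -1, 4) == (11008, 1024)  # "rowwise": split the input features
# bias: sharded to (2752,) for "colwise", kept whole for "rowwise" (outputs are all-reduced)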
1136
+
1137
1137
 
1138
1138
  ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
1139
1139
 
1140
1140
 
1141
- def convert_local_tensor_to_dtensor(
1142
- parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
1143
- ) -> DTensor:
1141
+ # =============================================================================
1142
+ # High-Level API Functions
1143
+ # =============================================================================
1144
+
1145
+
1146
+ def gather_full_tensor(local_tensor: torch.Tensor, shard_dim: int, device_mesh) -> torch.Tensor:
1144
1147
  """
1145
- Converts a local variant of weights to a DTensor with corresponding placements. Shouldn't be done ever except of before saving the model.
1148
+ All-gather a sharded tensor along the specified dimension to reconstruct the full tensor.
1149
+
1150
+ Args:
1151
+ local_tensor: The local shard of the tensor on this rank
1152
+ shard_dim: The dimension along which the tensor was sharded
1153
+ device_mesh: The device mesh for distributed communication
1154
+
1155
+ Returns:
1156
+ The full reconstructed tensor (same on all ranks)
1146
1157
  """
1147
- _, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
1148
- tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
1149
- if not tp_style:
1150
- return parameter
1151
-
1152
- if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
1153
- return parameter
1154
- # TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
1155
- if tp_style == "local_packed_rowwise":
1156
- placements = [Shard(-1)]
1157
- elif tp_style == "local_rowwise":
1158
- if param_type == "bias":
1159
- placements = [Replicate()]
1160
- else:
1161
- placements = [Shard(-1)]
1162
- elif tp_style == "local_colwise":
1163
- if param_type == "bias":
1164
- placements = [Shard(-1)]
1165
- else:
1166
- placements = [Shard(-2)]
1167
- return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
1158
+ world_size = device_mesh.size()
1159
+
1160
+ # Normalize negative dimension
1161
+ if shard_dim < 0:
1162
+ shard_dim = local_tensor.ndim + shard_dim
1168
1163
 
1164
+ # Gather all shards
1165
+ gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
1166
+ dist.all_gather(gathered_tensors, local_tensor.contiguous())
1169
1167
 
1170
- def replace_state_dict_local_with_dtensor(
1168
+ # Concatenate along the shard dimension
1169
+ return torch.cat(gathered_tensors, dim=shard_dim)
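A hedged usage sketch, assuming `torch.distributed` is already initialized, every rank holds an equally sized shard, and `model`, `mesh`, and the shapes are placeholders:

# On each of 4 TP ranks, `local_weight` is a (2752, 4096) colwise shard of an (11008, 4096) weight.
local_weight = model.model.layers[0].self_attn.q_proj.weight
full_weight = gather_full_tensor(local_weight, shard_dim=-2, device_mesh=mesh)
assert full_weight.shape == (11008, 4096)       # identical full tensor on every rank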
1170
+
1171
+
1172
+ def gather_state_dict_for_save(
1171
1173
  state_dict: dict[str, torch.Tensor],
1172
1174
  tp_plan: dict[str, str],
1173
1175
  device_mesh,
1176
+ tp_size: int,
1174
1177
  ) -> dict[str, torch.Tensor]:
1175
1178
  """
1176
- Replaces all tensors that were sharded with `local_*` strategy with DTensor to make determining their proper size possible.
1179
+ Gather sharded tensors to reconstruct full tensors for saving.
1180
+
1181
+ This function all-gathers each sharded tensor along its shard dimension
1182
+ to reconstruct the full unsharded tensor for checkpoint saving.
1183
+
1184
+ Args:
1185
+ state_dict: The model state dict with local sharded tensors
1186
+ tp_plan: The tensor parallel plan mapping layer patterns to shard styles
1187
+ device_mesh: The device mesh for distributed communication
1188
+ tp_size: The tensor parallel world size
1189
+
1190
+ Returns:
1191
+ State dict with full (gathered) tensors
1177
1192
  """
1178
- for key, value in state_dict.items():
1179
- if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
1180
- state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
1181
- return state_dict
1193
+ # Use the global mappings from ParallelInterface (can be extended by users)
1194
+ plan_to_weight_dim = ALL_PARALLEL_STYLES.plan_to_weight_dim
1195
+ plan_to_bias_dim = ALL_PARALLEL_STYLES.plan_to_bias_dim
1196
+
1197
+ result = {}
1198
+ for key, tensor in state_dict.items():
1199
+ # Find the matching TP plan for this parameter
1200
+ param_name = key.rsplit(".", 1)[0] if "." in key else key
1201
+ param_type = key.rsplit(".", 1)[1] if "." in key else None
1202
+ generic_param_name = re.sub(r"\d+", "*", param_name)
1203
+ # Also check the full key for nn.Parameter (e.g., MoE experts without .weight suffix)
1204
+ generic_full_key = re.sub(r"\d+", "*", key)
1205
+
1206
+ # Check if this parameter has a TP plan
1207
+ current_plan = None
1208
+ if generic_full_key in tp_plan:
1209
+ # Full key match (e.g., "model.layers.*.mlp.experts.gate_up_proj" for MoE experts)
1210
+ current_plan = tp_plan[generic_full_key]
1211
+ elif generic_param_name in tp_plan:
1212
+ current_plan = tp_plan[generic_param_name]
1213
+ elif "." in generic_param_name:
1214
+ parent_param_name = generic_param_name.rsplit(".", 1)[0]
1215
+ if parent_param_name in tp_plan:
1216
+ current_plan = tp_plan[parent_param_name]
1217
+
1218
+ if current_plan is None or current_plan not in plan_to_weight_dim:
1219
+ # Not sharded, keep as-is
1220
+ result[key] = tensor
1221
+ continue
1222
+
1223
+ # Determine sharding dimension based on param type
1224
+ if param_type == "bias":
1225
+ shard_dim = plan_to_bias_dim.get(current_plan)
1226
+ else:
1227
+ shard_dim = plan_to_weight_dim.get(current_plan)
1228
+
1229
+ if shard_dim is None:
1230
+ # Replicated, keep as-is
1231
+ result[key] = tensor
1232
+ continue
1233
+
1234
+ # Gather full tensor and handle packed weights repacking
1235
+ full_tensor = gather_full_tensor(tensor, shard_dim, device_mesh)
1236
+ if current_plan in ("packed_colwise", "packed_rowwise"):
1237
+ full_tensor = repack_weights(full_tensor, shard_dim, tp_size, 2)
1238
+ result[key] = full_tensor.contiguous()
1239
+
1240
+ return result
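A sketch of how this could be called at save time on a TP-sharded model; the attribute names follow `distribute_model` further down, and the file name is arbitrary:

import torch
import torch.distributed as dist

local_sd = model.state_dict()                   # each rank holds its local shards
full_sd = gather_state_dict_for_save(
    local_sd,
    tp_plan=model.tp_plan,
    device_mesh=model._device_mesh,
    tp_size=model._tp_size,
)
if dist.get_rank() == 0:                        # all ranks now hold full tensors; write once
    torch.save(full_sd, "pytorch_model.bin")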
1182
1241
 
1183
1242
 
1184
1243
  def add_tensor_parallel_hooks_to_module(
@@ -1207,7 +1266,7 @@ def add_tensor_parallel_hooks_to_module(
1207
1266
 
1208
1267
  def shard_and_distribute_module(
1209
1268
  model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
1210
- ): # TODO: rename to shard_and_distribute_param
1269
+ ):
1211
1270
  r"""
1212
1271
  This function is called in `from_pretrained` when loading a model's checkpoints.
1213
1272
  It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
@@ -1223,7 +1282,7 @@ def shard_and_distribute_module(
1223
1282
  """
1224
1283
  param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
1225
1284
  tp_plan = model.tp_plan or {}
1226
- module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules?
1285
+ module_to_tp = model.get_submodule(param_name)
1227
1286
  rank = int(rank)
1228
1287
  current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
1229
1288
 
@@ -1235,10 +1294,13 @@ def shard_and_distribute_module(
1235
1294
 
1236
1295
  if current_shard_plan is not None:
1237
1296
  try:
1238
- tp_layer = ALL_PARALLEL_STYLES[current_shard_plan](
1239
- empty_param=empty_param, device_mesh=device_mesh, rank=rank
1240
- )
1241
- param = tp_layer.partition_tensor(param, param_casting_dtype, is_contiguous)
1297
+ tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
1298
+ tp_layer.empty_param = empty_param
1299
+ tp_layer.device_mesh = device_mesh
1300
+ tp_layer.rank = rank
1301
+ param = tp_layer.shard_tensor(param, tensor_idx=None, dtype=param_casting_dtype, device=rank)
1302
+ if is_contiguous:
1303
+ param = param.contiguous()
1242
1304
  except NotImplementedError as e:
1243
1305
  print(
1244
1306
  f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
@@ -1251,7 +1313,6 @@ def shard_and_distribute_module(
1251
1313
  if not isinstance(param, torch.nn.Parameter):
1252
1314
  param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
1253
1315
  setattr(module_to_tp, param_type, param)
1254
- # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
1255
1316
  return param
1256
1317
 
1257
1318
 
@@ -1265,20 +1326,18 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
1265
1326
 
1266
1327
  generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
1267
1328
  unsharded_layers = set(generic_keys)
1268
- unused_rules = tp_plan
1329
+ unused_rules = tp_plan.copy()
1269
1330
 
1270
1331
  for key in generic_keys:
1271
1332
  param_name = key.rsplit(".", 1)[0] if "." in key else key
1272
1333
  generic_param_name = re.sub(r"\d+", "*", param_name)
1273
1334
 
1274
1335
  if generic_param_name in tp_plan:
1275
- unused_rules.pop(generic_param_name)
1336
+ unused_rules.pop(generic_param_name, None)
1276
1337
  unsharded_layers.discard(key)
1277
1338
  elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
1278
- unused_rules.pop(parent_param_name)
1339
+ unused_rules.pop(parent_param_name, None)
1279
1340
  unsharded_layers.discard(key)
1280
- else:
1281
- pass # we couldn't find the rule for this parameter, so it's not sharded
1282
1341
 
1283
1342
  if len(unused_rules) > 0:
1284
1343
  logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
@@ -1287,6 +1346,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
1287
1346
 
1288
1347
 
1289
1348
  def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
1349
+ """Distribute a model according to the TP plan."""
1290
1350
  model._tp_size = tp_size
1291
1351
  model._device_mesh = device_mesh
1292
1352
  if distributed_config is not None:
@@ -1297,7 +1357,7 @@ def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
1297
1357
  if isinstance(tp_plan, dict):
1298
1358
  model.tp_plan = tp_plan
1299
1359
  model_plan = model.tp_plan
1300
- if model_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
1360
+ if model_plan is not None and _torch_distributed_available:
1301
1361
  for v in model_plan.values():
1302
1362
  if v not in ALL_PARALLEL_STYLES:
1303
1363
  raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")