transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff shows the changes between two publicly released package versions, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -21,18 +21,16 @@
 
 import math
 from collections.abc import Callable
-from contextlib import nullcontext
 from typing import Optional
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ... import initialization as init
 from ...activations import ACT2FN
-from ...integrations import use_kernel_func_from_hub
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
+from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -43,158 +41,13 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, is_flash_attn_2_available, logging
-from ...utils.generic import maybe_autocast
-from ...utils.import_utils import is_triton_available
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring
+from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
 from .configuration_modernbert import ModernBertConfig
 
 
-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
-    from flash_attn.layers.rotary import RotaryEmbedding
-    from flash_attn.ops.triton.rotary import apply_rotary
-else:
-    RotaryEmbedding = object
-
-
-logger = logging.get_logger(__name__)
-
-
-class ApplyRotaryEmbUnpad(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        qkv,
-        cos,
-        sin,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-    ):
-        # (total_nnz, 3, nheads, headdim)
-        qkv = qkv.contiguous()
-        total_nnz, _three, _nheads, headdim = qkv.shape
-        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
-        qk = qkv[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            qk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            interleaved=False,
-            inplace=True,
-        )
-
-        ctx.save_for_backward(cos, sin, cu_seqlens)
-        ctx.max_seqlen = max_seqlen
-        return qkv
-
-    @staticmethod
-    def backward(ctx, do):
-        cos, sin, cu_seqlens = ctx.saved_tensors
-        do = do.contiguous()
-        total_nnz, _three, _nheads, headdim = do.shape
-        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        dqk = do[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            dqk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=ctx.max_seqlen,
-            interleaved=False,
-            inplace=True,
-            conjugate=True,
-        )
-
-        return do, None, None, None, None, None, None
-
-
-def apply_rotary_unpadded(
-    qkv,
-    cos,
-    sin,
-    cu_seqlens: torch.Tensor | None = None,
-    max_seqlen: int | None = None,
-):
-    """
-    Arguments:
-        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
-        cos, sin: (seqlen_rotary, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        inplace: if True, apply rotary embedding in-place.
-        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
-            Most commonly used in inference when we have KV cache.
-        cu_seqlens: (batch + 1,) or None
-        max_seqlen: int
-    Return:
-        out: (total_nnz, dim)
-    rotary_dim must be <= headdim
-    Apply rotary embedding to the first rotary_dim of x.
-    """
-    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
-
-
-class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
-    """
-    The rotary position embeddings applied directly to unpadded sequences.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        base: float = 10000.0,
-        max_seqlen: int | None = None,
-        device: torch.device | None = None,
-        dtype: torch.dtype | None = None,
-    ):
-        """
-        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
-            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
-            the cos_sin_cache will be recomputed during the forward pass.
-        """
-        super().__init__(dim=dim, base=base, device=device, interleaved=False)
-        self.max_seqlen = max_seqlen
-
-        if max_seqlen is not None and device is not None and dtype is not None:
-            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
-
-    def forward(
-        self,
-        qkv: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_seqlen: int | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        """
-        Apply rotary embedding *inplace* to qkv.
-        qkv: (total_nnz, 3, nheads, headdim)
-        cu_seqlens: (batch + 1,) cumulative sequence lengths
-        max_seqlen: int max seq length in the batch
-        """
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-
-        qkv = apply_rotary_unpadded(
-            qkv,
-            self._cos_cached,
-            self._sin_cached,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-        )
-
-        return qkv
-
-    def extra_repr(self) -> str:
-        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
-
-
 class ModernBertEmbeddings(nn.Module):
     """
     Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -207,21 +60,13 @@ class ModernBertEmbeddings(nn.Module):
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.drop = nn.Dropout(config.embedding_dropout)
 
-    @torch.compile(dynamic=True)
-    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
-        return self.drop(self.norm(self.tok_embeddings(input_ids)))
-
     def forward(
         self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
     ) -> torch.Tensor:
         if inputs_embeds is not None:
             hidden_states = self.drop(self.norm(inputs_embeds))
         else:
-            hidden_states = (
-                self.compiled_embeddings(input_ids)
-                if self.config.reference_compile
-                else self.drop(self.norm(self.tok_embeddings(input_ids)))
-            )
+            hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
         return hidden_states
 
 
@@ -326,6 +171,29 @@ class ModernBertRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
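The new `eager_attention_forward` takes already-split, already-rotated `query`/`key`/`value` in `[batch, heads, seq_len, head_dim]` layout plus an explicit `scaling` argument, in line with the shared attention interface used across the library in 5.x. A minimal sketch of a call, assuming the function above is in scope; the bare `nn.Module()` is a hypothetical stand-in that only supplies the `.training` flag read by the dropout call:

```python
import torch
from torch import nn

# Illustrative shapes: [batch, num_heads, seq_len, head_dim]
batch, num_heads, seq_len, head_dim = 2, 4, 8, 16
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

module = nn.Module()  # stand-in: only `.training` is read by the dropout call
attn_output, attn_weights = eager_attention_forward(
    module,
    query,
    key,
    value,
    attention_mask=None,  # or an additive float mask broadcastable to [batch, 1, seq_len, seq_len]
    scaling=head_dim**-0.5,
)
print(attn_output.shape)   # torch.Size([2, 8, 4, 16]) -- transposed back toward [batch, seq, ...]
print(attn_weights.shape)  # torch.Size([2, 4, 8, 8])
```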
@@ -352,137 +220,15 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    original_dtype = q.dtype
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def eager_attention_forward(
-    module: "ModernBertAttention",
-    qkv: torch.Tensor,
-    attention_mask: torch.Tensor,
-    sliding_window_mask: torch.Tensor,
-    position_ids: torch.LongTensor | None,
-    local_attention: tuple[int, int],
-    bs: int,
-    dim: int,
-    position_embeddings: torch.Tensor,
-    output_attentions: bool | None = False,
-    **_kwargs,
-) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
-    # qkv: [batch_size, seqlen, 3, nheads, headdim]
-    cos, sin = position_embeddings
-    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
-    # query, key, value: [batch_size, heads, seq_len, head_dim]
-    query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
-    scale = module.head_dim**-0.5
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
-
-    if local_attention != (-1, -1):
-        attention_mask = sliding_window_mask
-
-    attn_weights = attn_weights + attention_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    attn_output = attn_output.view(bs, -1, dim)
-    if output_attentions:
-        return (attn_output, attn_weights)
-    return (attn_output,)
-
-
-def flash_attention_forward(
-    module: "ModernBertAttention",
-    qkv: torch.Tensor,
-    rotary_emb: ModernBertUnpaddedRotaryEmbedding,
-    cu_seqlens: torch.Tensor,
-    max_seqlen: int,
-    local_attention: tuple[int, int],
-    bs: int,
-    dim: int,
-    target_dtype: torch.dtype = torch.bfloat16,
-    **_kwargs,
-) -> tuple[torch.Tensor]:
-    # (total_seqlen, 3, nheads, headdim)
-    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
-
-    convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
-    if convert_dtype:
-        # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
-        # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
-        orig_dtype = qkv.dtype
-        qkv = qkv.to(target_dtype)
-
-        attn = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            dropout_p=module.attention_dropout if module.training else 0.0,
-            deterministic=module.deterministic_flash_attn,
-            window_size=local_attention,
-        )
-        attn = attn.to(orig_dtype)  # type: ignore
-    else:
-        attn = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            dropout_p=module.attention_dropout if module.training else 0.0,
-            deterministic=module.deterministic_flash_attn,
-            window_size=local_attention,
-        )
-    return (attn.view(bs, dim),)
-
-
-def sdpa_attention_forward(
-    module: "ModernBertAttention",
-    qkv: torch.Tensor,
-    attention_mask: torch.Tensor,
-    sliding_window_mask: torch.Tensor,
-    position_ids: torch.LongTensor | None,
-    local_attention: tuple[int, int],
-    bs: int,
-    dim: int,
-    position_embeddings: torch.Tensor,
-    **_kwargs,
-) -> tuple[torch.Tensor]:
-    # qkv: [batch_size, seqlen, 3, nheads, headdim]
-    cos, sin = position_embeddings
-    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
-    # query, key, value: [batch_size, heads, seq_len, head_dim]
-    query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
-    if local_attention != (-1, -1):
-        attention_mask = sliding_window_mask
-
-    attn_output = (
-        F.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            dropout_p=module.attention_dropout if module.training else 0.0,
-            attn_mask=attention_mask,
-        )
-        .transpose(1, 2)
-        .contiguous()
-    )
-    attn_output = attn_output.view(bs, -1, dim)
-    return (attn_output,)
-
-
-MODERNBERT_ATTENTION_FUNCTION = {
-    "flash_attention_2": flash_attention_forward,
-    "eager": eager_attention_forward,
-    "sdpa": sdpa_attention_forward,
-}
+    q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
+    k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
+    return q_embed.to(original_dtype), k_embed.to(original_dtype)
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class ModernBertAttention(nn.Module):
     """Performs multi-headed self attention on a batch of unpadded sequences.
 
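Note the numerics change in `apply_rotary_pos_emb`: the rotation now happens in float32 and the result is cast back to the input dtype, instead of rotating in half precision directly. A self-contained sketch of the same upcast-rotate-downcast pattern, with illustrative shapes and `rotate_half` restated so the snippet runs on its own:

```python
import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Half-precision activations, float32 rotation table (illustrative sizes)
q = torch.randn(2, 4, 8, 16, dtype=torch.float16)  # [batch, heads, seq, head_dim]
positions = torch.arange(8, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 16, 2, dtype=torch.float32) / 16))
freqs = torch.outer(positions, inv_freq)            # [seq, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)             # [seq, head_dim]
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

# Upcast, rotate in float32, cast back -- mirrors the new implementation
q_embed = ((q.float() * cos) + (rotate_half(q.float()) * sin)).to(q.dtype)
print(q_embed.dtype)  # torch.float16
```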
@@ -493,10 +239,10 @@ class ModernBertAttention(nn.Module):
     See `forward` method for additional details.
     """
 
-    def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
+    def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
-        self.layer_id = layer_id
+        self.layer_idx = layer_idx
 
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -505,29 +251,19 @@ class ModernBertAttention(nn.Module):
 
         self.attention_dropout = config.attention_dropout
         self.deterministic_flash_attn = config.deterministic_flash_attn
-        self.num_heads = config.num_attention_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
-        self.all_head_size = self.head_dim * self.num_heads
-        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
-        layer_type = config.layer_types[layer_id]
+        self.Wqkv = nn.Linear(
+            config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
+        )
 
-        if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2, config.local_attention // 2)
-            max_position_embeddings = config.local_attention
+        if config.layer_types[layer_idx] == "sliding_attention":
+            # config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
+            # +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
+            self.sliding_window = config.sliding_window + 1
         else:
-            self.local_attention = (-1, -1)
-            max_position_embeddings = config.max_position_embeddings
+            self.sliding_window = None
 
-        if config._attn_implementation == "flash_attention_2":
-            rope_parameters_dict = (
-                self.config.rope_parameters[layer_type] if layer_type is not None else self.config.rope_parameters
-            )
-            rope_theta = rope_parameters_dict["rope_theta"]
-            self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
-                dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
-            )
-        else:
-            self.rotary_emb = None
+        self.is_causal = False
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
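The `(window // 2, window // 2)` tuple bookkeeping is gone: a sliding layer now derives a single inclusive boundary from the config, as the comment in the hunk explains. A quick arithmetic sketch, assuming a hypothetical config where the legacy `local_attention` was 128 (so `config.sliding_window` stores 64):

```python
local_attention = 128                  # full local window from the legacy config field
sliding_window = local_attention // 2  # 64 -- the half-window stored in config.sliding_window
fa2_window = sliding_window + 1        # 65 -- flash attention treats the boundary as inclusive
print(fa2_window)
```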
@@ -535,82 +271,75 @@ class ModernBertAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: torch.Tensor | None = None,
-        output_attentions: bool | None = False,
-        **kwargs,
-    ) -> torch.Tensor:
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+        attention_mask: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        input_shape = hidden_states.shape[:-1]
+
         qkv = self.Wqkv(hidden_states)
+        qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
+        query_states, key_states, value_states = qkv.unbind(dim=-3)
 
-        bs = hidden_states.shape[0]
-        if self.config._attn_implementation == "flash_attention_2":
-            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
-        else:
-            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
 
-        attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
+        attention_interface = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
             self,
-            qkv=qkv,
-            rotary_emb=self.rotary_emb,
-            local_attention=self.local_attention,
-            bs=bs,
-            dim=self.all_head_size,
-            position_embeddings=position_embeddings,
-            output_attentions=output_attentions,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.head_dim**-0.5,
+            sliding_window=self.sliding_window,
+            deterministic=self.deterministic_flash_attn,
             **kwargs,
         )
-        hidden_states = attn_outputs[0]
-        hidden_states = self.out_drop(self.Wo(hidden_states))
 
-        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.out_drop(self.Wo(attn_output))
+        return attn_output, attn_weights
 
 
 class ModernBertEncoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
+    def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
-        if layer_id == 0:
+        self.layer_idx = layer_idx
+        if layer_idx == 0:
             self.attn_norm = nn.Identity()
         else:
             self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
+        self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
         self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.mlp = ModernBertMLP(config)
-        self.attention_type = config.layer_types[layer_id]
-
-    @torch.compile(dynamic=True)
-    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return self.mlp(self.mlp_norm(hidden_states))
+        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
-        position_ids: torch.LongTensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
         position_embeddings: torch.Tensor | None = None,
-        output_attentions: bool | None = False,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> torch.Tensor:
-        attn_outputs = self.attn(
+        attn_output, _ = self.attn(
             self.attn_norm(hidden_states),
-            attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
             position_embeddings=position_embeddings,
-            output_attentions=output_attentions,
-        )
-        hidden_states = hidden_states + attn_outputs[0]
-        mlp_output = (
-            self.compiled_mlp(hidden_states)
-            if self.config.reference_compile
-            else self.mlp(self.mlp_norm(hidden_states))
+            attention_mask=attention_mask,
+            **kwargs,
         )
-        hidden_states = hidden_states + mlp_output
-
-        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted
+        hidden_states = hidden_states + attn_output
+        hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
+        return hidden_states
 
 
 @auto_docstring
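The rewritten `ModernBertAttention.forward` no longer branches on the backend to reshape the packed projection; a single `view` + `unbind` covers every implementation, and the backend is picked via `ALL_ATTENTION_FUNCTIONS`. A shape-only sketch of the QKV path, with illustrative sizes:

```python
import torch
from torch import nn

hidden_size, num_heads = 768, 12
head_dim = hidden_size // num_heads
Wqkv = nn.Linear(hidden_size, 3 * head_dim * num_heads)

hidden_states = torch.randn(2, 8, hidden_size)   # [batch, seq, hidden]
input_shape = hidden_states.shape[:-1]

qkv = Wqkv(hidden_states)                        # [batch, seq, 3 * hidden]
qkv = qkv.view(*input_shape, 3, -1, head_dim)    # [batch, seq, 3, heads, head_dim]
query, key, value = qkv.unbind(dim=-3)           # each [batch, seq, heads, head_dim]
query = query.transpose(1, 2)                    # [batch, heads, seq, head_dim]
print(query.shape)  # torch.Size([2, 12, 8, 64])
```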
@@ -621,7 +350,13 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_flex_attn = False
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+    _can_record_outputs = {
+        "hidden_states": ModernBertEncoderLayer,
+        "attentions": ModernBertAttention,
+    }
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
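`_can_record_outputs` declares which submodules' outputs should populate `hidden_states` and `attentions`, replacing the tuples previously threaded through every `forward`. A rough stand-in for the idea using plain forward hooks (an illustration of the mechanism, not the library's actual implementation):

```python
import torch
from torch import nn

# Record each layer's output via hooks, roughly what the declarative
# `_can_record_outputs` mapping automates inside the library.
layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
recorded = []
handles = [layer.register_forward_hook(lambda mod, args, out: recorded.append(out)) for layer in layers]

x = torch.randn(2, 8)
for layer in layers:
    x = layer(x)
print(len(recorded))  # 3
for handle in handles:
    handle.remove()
```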
@@ -683,9 +418,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                 init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
                 init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
-        elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
-            inv_freq = module._compute_inv_freq()
-            init.copy_(module.inv_freq, inv_freq)
 
     def _check_and_adjust_attn_implementation(
         self, attn_implementation: str | None, is_init_check: bool = False
@@ -693,137 +425,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
         """
         Checks and dispatches to the requested attention implementation.
         """
-        # If the user didn't specify anything, try to use flash_attention_2 if available.
+        # If the user didn't specify anything, try to use flash_attention_2.
         # Otherwise we fall back to the default SDPA -> Eager from the super() method.
-        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
-        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
-
         try:
-            attn_implementation = (
-                "flash_attention_2"
-                if attn_implementation is None and self._flash_attn_2_can_dispatch()
-                else attn_implementation
+            requested_attn_implementation = "flash_attention_2" if attn_implementation is None else attn_implementation
+            return super()._check_and_adjust_attn_implementation(
+                attn_implementation=requested_attn_implementation, is_init_check=is_init_check
             )
         except (ValueError, ImportError):
-            pass
-        return super()._check_and_adjust_attn_implementation(
-            attn_implementation=attn_implementation, is_init_check=is_init_check
-        )
-
-    def _maybe_set_compile(self):
-        if self.config.reference_compile is False:
-            return
-
-        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "If `accelerate` split the model across devices, `torch.compile` will not work. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.device.type == "mps":
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.device.type == "cpu":
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.config.reference_compile is None:
-            self.config.reference_compile = is_triton_available()
-
-    def resize_token_embeddings(self, *args, **kwargs):
-        model_embeds = super().resize_token_embeddings(*args, **kwargs)
-
-        if self.config.reference_compile in {True, None}:
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        return model_embeds
-
-
-def _unpad_modernbert_input(
-    inputs: torch.Tensor,
-    attention_mask: torch.Tensor,
-    position_ids: torch.Tensor | None = None,
-    labels: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor | None, torch.Tensor | None]:
-    """
-    Remove padding from input sequences.
-
-    Args:
-        inputs: (batch, seqlen, ...) or (batch, seqlen)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        position_ids: (batch, seqlen), int, position ids
-        labels: (batch, seqlen), int, labels
-
-    Returns:
-        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz)
-        cu_seqlens: (batch + 1), the cumulative sequence lengths
-        max_seqlen_in_batch: int
-        unpadded_position_ids: (total_nnz) or None
-        unpadded_labels: (total_nnz) or None
-    """
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
-    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
-    if inputs.dim() == 2:
-        unpadded_inputs = inputs.flatten()[indices]
-    else:
-        batch, seqlen, *rest = inputs.shape
-        shape = batch * seqlen
-        unpadded_inputs = inputs.view(shape, *rest)[indices]
-
-    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
-    unpadded_labels = labels.flatten()[indices] if labels is not None else None
-
-    return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
-
-
-def _pad_modernbert_output(
-    inputs: torch.Tensor,
-    indices: torch.Tensor,
-    batch: int,
-    seqlen: int,
-) -> torch.Tensor:
-    """
-    Add padding to sequences.
-
-    Args:
-        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz)
-        batch: int, batch size
-        seqlen: int, max sequence length
-
-    Returns:
-        padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
-    """
-    if inputs.dim() == 1:
-        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen)
-    else:
-        _, *rest = inputs.shape
-        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen, *rest)
-
-    return padded_inputs
+            return super()._check_and_adjust_attn_implementation(
+                attn_implementation=attn_implementation, is_init_check=is_init_check
+            )
 
 
 @auto_docstring
@@ -833,7 +445,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
         self.config = config
         self.embeddings = ModernBertEmbeddings(config)
         self.layers = nn.ModuleList(
-            [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
+            [ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.rotary_emb = ModernBertRotaryEmbedding(config=config)
@@ -846,175 +458,53 @@ class ModernBertModel(ModernBertPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.tok_embeddings = value
 
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
-    ) -> tuple[torch.Tensor, ...] | BaseModelOutput:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attentions = () if output_attentions else None
-
-        self._maybe_set_compile()
-
-        if input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
-        if batch_size is None and seq_len is None:
-            if inputs_embeds is not None:
-                batch_size, seq_len = inputs_embeds.shape[:2]
-            else:
-                batch_size, seq_len = input_ids.shape[:2]
+        seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
-        repad = False
-        if self.config._attn_implementation == "flash_attention_2":
-            if indices is None and cu_seqlens is None and max_seqlen is None:
-                repad = True
-                if inputs_embeds is None:
-                    with torch.no_grad():
-                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
-                            inputs=input_ids, attention_mask=attention_mask
-                        )
-                else:
-                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
-                        inputs=inputs_embeds, attention_mask=attention_mask
-                    )
-            if position_ids is None:
-                position_ids = indices.unsqueeze(0)
-        else:
-            if position_ids is None:
-                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
-
-            attention_mask, sliding_window_mask = self._update_attention_mask(
-                attention_mask, output_attentions=output_attentions
-            )
+        if position_ids is None:
+            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
 
         hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
+
+        if not isinstance(attention_mask_mapping := attention_mask, dict):
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": hidden_states,
+                "attention_mask": attention_mask,
+            }
+            attention_mask_mapping = {
+                "full_attention": create_bidirectional_mask(**mask_kwargs),
+                "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
+            }
+
         position_embeddings = {}
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
         for encoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            layer_outputs = encoder_layer(
+            hidden_states = encoder_layer(
                 hidden_states,
-                attention_mask=attention_mask,
-                sliding_window_mask=sliding_window_mask,
-                position_ids=position_ids,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
+                attention_mask=attention_mask_mapping[encoder_layer.attention_type],
                 position_embeddings=position_embeddings[encoder_layer.attention_type],
-                output_attentions=output_attentions,
+                **kwargs,
             )
-            hidden_states = layer_outputs[0]
-            if output_attentions and len(layer_outputs) > 1:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
 
         hidden_states = self.final_norm(hidden_states)
 
-        if repad:
-            hidden_states = _pad_modernbert_output(
-                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
-            )
-            if all_hidden_states is not None:
-                all_hidden_states = tuple(
-                    _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
-                    for hs in all_hidden_states
-                )
-        # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
-        # dimension missing
-        elif (
-            self.config._attn_implementation == "flash_attention_2"
-            and all_hidden_states is not None
-            and all_hidden_states[-1].dim() == 2
-        ):
-            hidden_states = hidden_states.unsqueeze(0)
-            all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
-
-    def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
-        if output_attentions:
-            if self.config._attn_implementation == "sdpa":
-                logger.warning_once(
-                    "Outputting attentions is only supported with the 'eager' attention implementation, "
-                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
-                )
-                self.config._attn_implementation = "eager"
-            elif self.config._attn_implementation != "eager":
-                logger.warning_once(
-                    "Outputting attentions is only supported with the eager attention implementation, "
-                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
-                    " Setting `output_attentions=False`."
-                )
-
-        global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
-
-        # Create position indices
-        rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
-        # Calculate distance between positions
-        distance = torch.abs(rows - rows.T)
-
-        # Create sliding window mask (1 for positions within window, 0 outside)
-        window_mask = (
-            (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
-        )
-        # Combine with existing mask
-        sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
-
-        return global_attention_mask, sliding_window_mask
+        return BaseModelOutput(last_hidden_state=hidden_states)
 
 
 class ModernBertPredictionHead(nn.Module):
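The mask mapping built above replaces the removed `_update_attention_mask`: one additive mask per layer type, created once and indexed by `encoder_layer.attention_type`. A plain-torch sketch of what such bidirectional masks look like for a padded batch, assuming simple `0 / -inf` additive semantics and a hypothetical window of 2 (the real `create_bidirectional_*` helpers also handle backend-specific formats):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = real token, 0 = padding
seq_len = attention_mask.shape[1]
neg_inf = torch.finfo(torch.float32).min

# Full bidirectional mask: any token may attend to any non-padding token
full = torch.zeros(1, 1, seq_len, seq_len)
full = full.masked_fill(attention_mask[:, None, None, :] == 0, neg_inf)

# Sliding-window variant: additionally restrict to |i - j| <= window
window = 2
pos = torch.arange(seq_len)
banded = (pos[None, :] - pos[:, None]).abs() <= window
sliding = full.masked_fill(~banded[None, None], neg_inf)
print(sliding[0, 0])
```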
@@ -1056,84 +546,23 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
     def set_output_embeddings(self, new_embeddings: nn.Linear):
         self.decoder = new_embeddings
 
-    @torch.compile(dynamic=True)
-    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
-        return self.decoder(self.head(output))
-
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | MaskedLMOutput:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
-        if self.config._attn_implementation == "flash_attention_2":
-            if indices is None and cu_seqlens is None and max_seqlen is None:
-                if batch_size is None and seq_len is None:
-                    if inputs_embeds is not None:
-                        batch_size, seq_len = inputs_embeds.shape[:2]
-                    else:
-                        batch_size, seq_len = input_ids.shape[:2]
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-                if attention_mask is None:
-                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
-                if inputs_embeds is None:
-                    with torch.no_grad():
-                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
-                        )
-                else:
-                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
-                    )
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
@@ -1147,35 +576,12 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
             last_hidden_state = last_hidden_state[mask_tokens]
             labels = labels[mask_tokens]
 
-        logits = (
-            self.compiled_head(last_hidden_state)
-            if self.config.reference_compile
-            else self.decoder(self.head(last_hidden_state))
-        )
+        logits = self.decoder(self.head(last_hidden_state))
 
         loss = None
         if labels is not None:
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
 
-        if self.config._attn_implementation == "flash_attention_2":
-            # Logits padding
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
-                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
-            # Hidden states padding
-            if getattr(outputs, "hidden_states", None) is not None:
-                padded_hidden_states = []
-                for hs in outputs.hidden_states:
-                    if hs.dim() == 3 and hs.shape[0] == 1:
-                        hs = hs.squeeze(0)
-                    padded_hidden_states.append(
-                        _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
-                    )
-                outputs.hidden_states = tuple(padded_hidden_states)
-
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return MaskedLMOutput(
             loss=loss,
             logits=logits,
@@ -1203,81 +609,39 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
-        if input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
-        if batch_size is None and seq_len is None:
-            if inputs_embeds is not None:
-                batch_size, seq_len = inputs_embeds.shape[:2]
-            else:
-                batch_size, seq_len = input_ids.shape[:2]
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
         if self.config.classifier_pooling == "cls":
             last_hidden_state = last_hidden_state[:, 0]
         elif self.config.classifier_pooling == "mean":
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
+                )
             last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
                 dim=1, keepdim=True
             )
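Note: the `mean` pooling branch above now synthesizes an all-ones mask when `attention_mask` is `None`, instead of relying on the mask being built earlier in `forward`. A self-contained sketch of that masked mean, with illustrative shapes:

```python
import torch

# Illustrative shapes: batch of 2 sequences, 5 tokens each, hidden size 8.
last_hidden_state = torch.randn(2, 5, 8)
attention_mask = None  # e.g. the caller passed no mask

# Fallback added in this diff: treat every position as valid.
if attention_mask is None:
    attention_mask = torch.ones(last_hidden_state.shape[:2], dtype=torch.bool)

# Masked mean over the sequence dimension, as in the branch above.
pooled = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
print(pooled.shape)  # torch.Size([2, 8])
```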
@@ -1309,10 +673,6 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
             loss_fct = BCEWithLogitsLoss()
             loss = loss_fct(logits, labels)
 
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
@@ -1339,60 +699,27 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | TokenClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
@@ -1405,10 +732,6 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
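Note: none of these heads hand-rolls tuple outputs anymore; `return_dict=False` is expected to travel through `**kwargs: Unpack[TransformersKwargs]` and be honored by the `@can_return_tuple` wrapper. A sketch, assuming an untuned base checkpoint just to show the two output forms:

```python
import torch
from transformers import AutoTokenizer, ModernBertForTokenClassification

# Illustrative checkpoint; substitute a fine-tuned ModernBERT token classifier in practice.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertForTokenClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=5)

inputs = tokenizer("ModernBERT tags tokens.", return_tensors="pt")
with torch.no_grad():
    as_output = model(**inputs)                    # TokenClassifierOutput
    as_tuple = model(**inputs, return_dict=False)  # plain tuple, via @can_return_tuple
print(type(as_output).__name__, type(as_tuple).__name__)
```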
@@ -1430,57 +753,22 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
 
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: torch.Tensor | None,
+        input_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         start_positions: torch.Tensor | None = None,
         end_positions: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
@@ -1496,10 +784,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         if start_positions is not None and end_positions is not None:
             loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
 
-        if not return_dict:
-            output = (start_logits, end_logits) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return QuestionAnsweringModelOutput(
             loss=loss,
             start_logits=start_logits,
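Note: `input_ids` on the question-answering head now defaults to `None` like the other heads, so fully keyword-based calls work. A span-extraction sketch; the checkpoint name is illustrative:

```python
import torch
from transformers import AutoTokenizer, ModernBertForQuestionAnswering

# Illustrative checkpoint; a QA fine-tune would give meaningful spans.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertForQuestionAnswering.from_pretrained("answerdotai/ModernBERT-base")

inputs = tokenizer("Who wrote it?", "It was written by Ada.", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)  # keyword-only call: input_ids defaults to None in the signature

start = out.start_logits.argmax(-1).item()
end = out.end_logits.argmax(-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```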
@@ -1527,45 +811,22 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -1577,22 +838,12 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             else None
         )
 
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]  # shape (num_choices, seq_len, hidden_size)
 
@@ -1624,10 +875,6 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
 
-        if not return_dict:
-            output = (reshaped_logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return MultipleChoiceModelOutput(
             loss=loss,
             logits=reshaped_logits,
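Note: the multiple-choice head still flattens `(batch, num_choices, seq_len)` inputs to `(batch * num_choices, seq_len)` before the encoder and folds the per-sequence logits back into one score per choice; this diff only trims the plumbing around it. A self-contained sketch of that reshape round-trip, with illustrative shapes:

```python
import torch

batch_size, num_choices, seq_len = 2, 4, 16

# (batch, num_choices, seq_len) -> (batch * num_choices, seq_len), as in forward above.
input_ids = torch.randint(0, 1000, (batch_size, num_choices, seq_len))
flat_input_ids = input_ids.view(-1, input_ids.size(-1))  # (8, 16)

# Stand-in for the classifier's one logit per flattened sequence.
flat_logits = torch.randn(batch_size * num_choices, 1)

# Fold back to one score per choice; CrossEntropyLoss compares against (batch,) labels.
reshaped_logits = flat_logits.view(-1, num_choices)  # (2, 4)
labels = torch.tensor([1, 3])
loss = torch.nn.CrossEntropyLoss()(reshaped_logits, labels)
print(reshaped_logits.shape, loss.item())
```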