transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -14,18 +14,17 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import math
17
- from contextlib import nullcontext
18
17
  from typing import Literal, Optional
19
18
 
20
19
  import torch
21
- import torch.nn.functional as F
22
20
  from torch import nn
23
21
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
22
 
25
23
  from ... import initialization as init
26
24
  from ...activations import ACT2FN
27
25
  from ...configuration_utils import PreTrainedConfig, layer_type_validation
28
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
26
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
27
+ from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
29
28
  from ...modeling_layers import GradientCheckpointingLayer
30
29
  from ...modeling_outputs import (
31
30
  BaseModelOutput,
@@ -36,18 +35,12 @@ from ...modeling_outputs import (
36
35
  TokenClassifierOutput,
37
36
  )
38
37
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
39
- from ...modeling_utils import PreTrainedModel
40
- from ...utils import auto_docstring, is_flash_attn_2_available, logging
41
- from ...utils.import_utils import is_triton_available
42
- from ..gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, apply_rotary_pos_emb
43
-
44
-
45
- if is_flash_attn_2_available():
46
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
47
- from flash_attn.layers.rotary import RotaryEmbedding
48
- from flash_attn.ops.triton.rotary import apply_rotary
49
- else:
50
- RotaryEmbedding = object
38
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from ...processing_utils import Unpack
40
+ from ...utils import TransformersKwargs, auto_docstring, logging
41
+ from ...utils.generic import can_return_tuple, check_model_inputs
42
+ from ..align.modeling_align import eager_attention_forward
43
+ from ..gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, rotate_half
51
44
 
52
45
 
53
46
  logger = logging.get_logger(__name__)
@@ -104,10 +97,9 @@ class ModernBertConfig(PreTrainedConfig):
104
97
  The dropout ratio for the attention probabilities.
105
98
  layer_types (`list`, *optional*):
106
99
  Attention pattern for each layer.
107
- rope_parameters (`RopeParameters`, *optional*):
108
- Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
109
- a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
110
- with longer `max_position_embeddings`.
100
+ rope_parameters (`dict`, *optional*):
101
+ Dictionary mapping attention patterns (`"full_attention"`, `"sliding_attention"`) to `RopeParameters`.
102
+ Each value should be a dictionary containing `rope_type` and optional scaling parameters.
111
103
  local_attention (`int`, *optional*, defaults to 128):
112
104
  The window size for local attention.
113
105
  embedding_dropout (`float`, *optional*, defaults to 0.0):
@@ -137,10 +129,9 @@ class ModernBertConfig(PreTrainedConfig):
137
129
  Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
138
130
  the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
139
131
  shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
140
- be faster in some scenarios.
141
- repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
142
- When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
143
- applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
132
+ be faster in some scenarios. This argument is deprecated and will be removed in a future version.
133
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
134
+ Whether to tie weight embeddings
144
135
 
145
136
  Examples:
146
137
 
@@ -161,6 +152,15 @@ class ModernBertConfig(PreTrainedConfig):
161
152
  keys_to_ignore_at_inference = ["past_key_values"]
162
153
  default_theta = {"global": 160_000.0, "local": 10_000.0}
163
154
 
155
+ def __setattr__(self, name, value):
156
+ if name == "reference_compile" and value is not None:
157
+ logger.warning_once(
158
+ "The `reference_compile` argument is deprecated and will be removed in `transformers v5.2.0`"
159
+ "Use `torch.compile()` directly on the model instead."
160
+ )
161
+ value = None
162
+ super().__setattr__(name, value)
163
+
164
164
  def __init__(
165
165
  self,
166
166
  vocab_size: int | None = 50368,
@@ -172,7 +172,7 @@ class ModernBertConfig(PreTrainedConfig):
172
172
  max_position_embeddings: int | None = 8192,
173
173
  initializer_range: float | None = 0.02,
174
174
  initializer_cutoff_factor: float | None = 2.0,
175
- norm_eps: int | None = 1e-5,
175
+ norm_eps: float | None = 1e-5,
176
176
  norm_bias: bool | None = False,
177
177
  pad_token_id: int | None = 50283,
178
178
  eos_token_id: int | None = 50282,
@@ -182,7 +182,7 @@ class ModernBertConfig(PreTrainedConfig):
182
182
  attention_bias: bool | None = False,
183
183
  attention_dropout: float | None = 0.0,
184
184
  layer_types: list[str] | None = None,
185
- rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
185
+ rope_parameters: dict[Literal["full_attention", "sliding_attention"], RopeParameters] | None = None,
186
186
  local_attention: int | None = 128,
187
187
  embedding_dropout: float | None = 0.0,
188
188
  mlp_bias: bool | None = False,
@@ -195,10 +195,16 @@ class ModernBertConfig(PreTrainedConfig):
195
195
  deterministic_flash_attn: bool | None = False,
196
196
  sparse_prediction: bool | None = False,
197
197
  sparse_pred_ignore_index: int | None = -100,
198
- reference_compile: bool | None = None,
199
- repad_logits_with_grad: bool | None = False,
198
+ reference_compile: bool | None = None, # Deprecated
199
+ tie_word_embeddings: bool | None = True,
200
200
  **kwargs,
201
201
  ):
202
+ self.pad_token_id = pad_token_id
203
+ self.bos_token_id = bos_token_id
204
+ self.eos_token_id = eos_token_id
205
+ self.cls_token_id = cls_token_id
206
+ self.sep_token_id = sep_token_id
207
+ self.tie_word_embeddings = tie_word_embeddings
202
208
  self.vocab_size = vocab_size
203
209
  self.max_position_embeddings = max_position_embeddings
204
210
  self.hidden_size = hidden_size
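Illustrative sketch (not from the package source) of what the `reference_compile` deprecation hook above means for callers; the construction below is an assumption for illustration only:

    from transformers import ModernBertConfig

    config = ModernBertConfig(reference_compile=True)  # __setattr__ above warns once and drops the value
    assert config.reference_compile is None            # the hook always stores None
    # torch.compile(model) on the instantiated model is the supported replacement.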
@@ -225,7 +231,6 @@ class ModernBertConfig(PreTrainedConfig):
225
231
  self.sparse_prediction = sparse_prediction
226
232
  self.sparse_pred_ignore_index = sparse_pred_ignore_index
227
233
  self.reference_compile = reference_compile
228
- self.repad_logits_with_grad = repad_logits_with_grad
229
234
 
230
235
  if self.classifier_pooling not in ["cls", "mean"]:
231
236
  raise ValueError(
@@ -245,14 +250,7 @@ class ModernBertConfig(PreTrainedConfig):
245
250
  layer_type_validation(self.layer_types, self.num_hidden_layers)
246
251
 
247
252
  self.rope_parameters = rope_parameters
248
- super().__init__(
249
- pad_token_id=pad_token_id,
250
- bos_token_id=bos_token_id,
251
- eos_token_id=eos_token_id,
252
- cls_token_id=cls_token_id,
253
- sep_token_id=sep_token_id,
254
- **kwargs,
255
- )
253
+ super().__init__(**kwargs)
256
254
 
257
255
  def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
258
256
  rope_scaling = kwargs.pop("rope_scaling", None)
@@ -267,9 +265,15 @@ class ModernBertConfig(PreTrainedConfig):
267
265
  if rope_scaling is not None:
268
266
  self.rope_parameters["full_attention"].update(rope_scaling)
269
267
  self.rope_parameters["sliding_attention"].update(rope_scaling)
268
+
269
+ # Set default values if not present
270
+ if self.rope_parameters.get("full_attention") is None:
271
+ self.rope_parameters["full_attention"] = {"rope_type": "default"}
270
272
  self.rope_parameters["full_attention"].setdefault(
271
273
  "rope_theta", kwargs.pop("global_rope_theta", self.default_theta["global"])
272
274
  )
275
+ if self.rope_parameters.get("sliding_attention") is None:
276
+ self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
273
277
  self.rope_parameters["sliding_attention"].setdefault(
274
278
  "rope_theta", kwargs.pop("local_rope_theta", self.default_theta["local"])
275
279
  )
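Illustrative sketch (not from the package source) of the per-layer-type RoPE mapping that the defaulting above produces, with the theta values taken from `default_theta`; writing the dictionary out explicitly is an assumption for illustration only:

    from transformers import ModernBertConfig

    rope_parameters = {
        "full_attention": {"rope_type": "default", "rope_theta": 160_000.0},
        "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
    }
    config = ModernBertConfig(rope_parameters=rope_parameters)  # equivalent to leaving rope_parameters=None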
@@ -284,211 +288,15 @@ class ModernBertConfig(PreTrainedConfig):
284
288
  output.pop("reference_compile", None)
285
289
  return output
286
290
 
291
+ @property
292
+ def sliding_window(self):
293
+ """Half-window size: `local_attention` is the total window, so we divide by 2."""
294
+ return self.local_attention // 2
287
295
 
288
- def _unpad_modernbert_input(
289
- inputs: torch.Tensor,
290
- attention_mask: torch.Tensor,
291
- position_ids: torch.Tensor | None = None,
292
- labels: torch.Tensor | None = None,
293
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor | None, torch.Tensor | None]:
294
- """
295
- Remove padding from input sequences.
296
-
297
- Args:
298
- inputs: (batch, seqlen, ...) or (batch, seqlen)
299
- attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
300
- position_ids: (batch, seqlen), int, position ids
301
- labels: (batch, seqlen), int, labels
302
-
303
- Returns:
304
- unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
305
- indices: (total_nnz)
306
- cu_seqlens: (batch + 1), the cumulative sequence lengths
307
- max_seqlen_in_batch: int
308
- unpadded_position_ids: (total_nnz) or None
309
- unpadded_labels: (total_nnz) or None
310
- """
311
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
312
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
313
- max_seqlen_in_batch = int(seqlens_in_batch.max().item())
314
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
315
-
316
- if inputs.dim() == 2:
317
- unpadded_inputs = inputs.flatten()[indices]
318
- else:
319
- batch, seqlen, *rest = inputs.shape
320
- shape = batch * seqlen
321
- unpadded_inputs = inputs.view(shape, *rest)[indices]
322
-
323
- unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
324
- unpadded_labels = labels.flatten()[indices] if labels is not None else None
325
-
326
- return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
327
-
328
-
329
- def _pad_modernbert_output(
330
- inputs: torch.Tensor,
331
- indices: torch.Tensor,
332
- batch: int,
333
- seqlen: int,
334
- ) -> torch.Tensor:
335
- """
336
- Add padding to sequences.
337
-
338
- Args:
339
- inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
340
- indices: (total_nnz)
341
- batch: int, batch size
342
- seqlen: int, max sequence length
343
-
344
- Returns:
345
- padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
346
- """
347
- if inputs.dim() == 1:
348
- output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
349
- output[indices] = inputs
350
- padded_inputs = output.view(batch, seqlen)
351
- else:
352
- _, *rest = inputs.shape
353
- output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
354
- output[indices] = inputs
355
- padded_inputs = output.view(batch, seqlen, *rest)
356
-
357
- return padded_inputs
358
-
359
-
360
- class ApplyRotaryEmbUnpad(torch.autograd.Function):
361
- @staticmethod
362
- def forward(
363
- ctx,
364
- qkv,
365
- cos,
366
- sin,
367
- cu_seqlens: torch.Tensor | None = None,
368
- max_seqlen: int | None = None,
369
- ):
370
- # (total_nnz, 3, nheads, headdim)
371
- qkv = qkv.contiguous()
372
- total_nnz, _three, _nheads, headdim = qkv.shape
373
- # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
374
- # we get the same tensor
375
- # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
376
- qk = qkv[:, :2].view(total_nnz, -1, headdim)
377
- apply_rotary(
378
- qk,
379
- cos,
380
- sin,
381
- seqlen_offsets=0,
382
- cu_seqlens=cu_seqlens,
383
- max_seqlen=max_seqlen,
384
- interleaved=False,
385
- inplace=True,
386
- )
387
-
388
- ctx.save_for_backward(cos, sin, cu_seqlens)
389
- ctx.max_seqlen = max_seqlen
390
- return qkv
391
-
392
- @staticmethod
393
- def backward(ctx, do):
394
- cos, sin, cu_seqlens = ctx.saved_tensors
395
- do = do.contiguous()
396
- total_nnz, _three, _nheads, headdim = do.shape
397
- # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
398
- # we get the same tensor
399
- dqk = do[:, :2].view(total_nnz, -1, headdim)
400
- apply_rotary(
401
- dqk,
402
- cos,
403
- sin,
404
- seqlen_offsets=0,
405
- cu_seqlens=cu_seqlens,
406
- max_seqlen=ctx.max_seqlen,
407
- interleaved=False,
408
- inplace=True,
409
- conjugate=True,
410
- )
411
-
412
- return do, None, None, None, None, None, None
413
-
414
-
415
- def apply_rotary_unpadded(
416
- qkv,
417
- cos,
418
- sin,
419
- cu_seqlens: torch.Tensor | None = None,
420
- max_seqlen: int | None = None,
421
- ):
422
- """
423
- Arguments:
424
- qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
425
- cos, sin: (seqlen_rotary, rotary_dim / 2)
426
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
427
- of 1st half and 2nd half (GPT-NeoX style).
428
- inplace: if True, apply rotary embedding in-place.
429
- seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
430
- Most commonly used in inference when we have KV cache.
431
- cu_seqlens: (batch + 1,) or None
432
- max_seqlen: int
433
- Return:
434
- out: (total_nnz, dim)
435
- rotary_dim must be <= headdim
436
- Apply rotary embedding to the first rotary_dim of x.
437
- """
438
- return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
439
-
440
-
441
- class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
442
- """
443
- The rotary position embeddings applied directly to unpadded sequences.
444
- """
445
-
446
- def __init__(
447
- self,
448
- dim: int,
449
- base: float = 10000.0,
450
- max_seqlen: int | None = None,
451
- device: torch.device | None = None,
452
- dtype: torch.dtype | None = None,
453
- ):
454
- """
455
- max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
456
- up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
457
- the cos_sin_cache will be recomputed during the forward pass.
458
- """
459
- super().__init__(dim=dim, base=base, device=device, interleaved=False)
460
- self.max_seqlen = max_seqlen
461
-
462
- if max_seqlen is not None and device is not None and dtype is not None:
463
- self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
464
-
465
- def forward(
466
- self,
467
- qkv: torch.Tensor,
468
- cu_seqlens: torch.Tensor,
469
- max_seqlen: int | None = None,
470
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
471
- """
472
- Apply rotary embedding *inplace* to qkv.
473
- qkv: (total_nnz, 3, nheads, headdim)
474
- cu_seqlens: (batch + 1,) cumulative sequence lengths
475
- max_seqlen: int max seq length in the batch
476
- """
477
- if max_seqlen is not None:
478
- self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
479
-
480
- qkv = apply_rotary_unpadded(
481
- qkv,
482
- self._cos_cached,
483
- self._sin_cached,
484
- cu_seqlens=cu_seqlens,
485
- max_seqlen=max_seqlen,
486
- )
487
-
488
- return qkv
489
-
490
- def extra_repr(self) -> str:
491
- return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
296
+ @sliding_window.setter
297
+ def sliding_window(self, value):
298
+ """Set sliding_window by updating local_attention to 2 * value."""
299
+ self.local_attention = value * 2
492
300
 
493
301
 
494
302
  class ModernBertEmbeddings(nn.Module):
@@ -503,21 +311,13 @@ class ModernBertEmbeddings(nn.Module):
503
311
  self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
504
312
  self.drop = nn.Dropout(config.embedding_dropout)
505
313
 
506
- @torch.compile(dynamic=True)
507
- def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
508
- return self.drop(self.norm(self.tok_embeddings(input_ids)))
509
-
510
314
  def forward(
511
315
  self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
512
316
  ) -> torch.Tensor:
513
317
  if inputs_embeds is not None:
514
318
  hidden_states = self.drop(self.norm(inputs_embeds))
515
319
  else:
516
- hidden_states = (
517
- self.compiled_embeddings(input_ids)
518
- if self.config.reference_compile
519
- else self.drop(self.norm(self.tok_embeddings(input_ids)))
520
- )
320
+ hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
521
321
  return hidden_states
522
322
 
523
323
 
@@ -555,130 +355,34 @@ class ModernBertRotaryEmbedding(Gemma3RotaryEmbedding):
555
355
  return super().compute_default_rope_parameters(config, device, seq_len, layer_type)
556
356
 
557
357
 
558
- def eager_attention_forward(
559
- module: "ModernBertAttention",
560
- qkv: torch.Tensor,
561
- attention_mask: torch.Tensor,
562
- sliding_window_mask: torch.Tensor,
563
- position_ids: torch.LongTensor | None,
564
- local_attention: tuple[int, int],
565
- bs: int,
566
- dim: int,
567
- position_embeddings: torch.Tensor,
568
- output_attentions: bool | None = False,
569
- **_kwargs,
570
- ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
571
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
572
- cos, sin = position_embeddings
573
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
574
- # query, key, value: [batch_size, heads, seq_len, head_dim]
575
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
576
-
577
- scale = module.head_dim**-0.5
578
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
579
-
580
- if local_attention != (-1, -1):
581
- attention_mask = sliding_window_mask
582
-
583
- attn_weights = attn_weights + attention_mask
584
-
585
- # upcast attention to fp32
586
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
587
- attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
588
- attn_output = torch.matmul(attn_weights, value)
589
- attn_output = attn_output.transpose(1, 2).contiguous()
590
- attn_output = attn_output.view(bs, -1, dim)
591
- if output_attentions:
592
- return (attn_output, attn_weights)
593
- return (attn_output,)
594
-
595
-
596
- def flash_attention_forward(
597
- module: "ModernBertAttention",
598
- qkv: torch.Tensor,
599
- rotary_emb: ModernBertUnpaddedRotaryEmbedding,
600
- cu_seqlens: torch.Tensor,
601
- max_seqlen: int,
602
- local_attention: tuple[int, int],
603
- bs: int,
604
- dim: int,
605
- target_dtype: torch.dtype = torch.bfloat16,
606
- **_kwargs,
607
- ) -> tuple[torch.Tensor]:
608
- # (total_seqlen, 3, nheads, headdim)
609
- qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
610
-
611
- convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
612
- if convert_dtype:
613
- # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
614
- # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
615
- orig_dtype = qkv.dtype
616
- qkv = qkv.to(target_dtype)
617
-
618
- attn = flash_attn_varlen_qkvpacked_func(
619
- qkv,
620
- cu_seqlens=cu_seqlens,
621
- max_seqlen=max_seqlen,
622
- dropout_p=module.attention_dropout if module.training else 0.0,
623
- deterministic=module.deterministic_flash_attn,
624
- window_size=local_attention,
625
- )
626
- attn = attn.to(orig_dtype) # type: ignore
627
- else:
628
- attn = flash_attn_varlen_qkvpacked_func(
629
- qkv,
630
- cu_seqlens=cu_seqlens,
631
- max_seqlen=max_seqlen,
632
- dropout_p=module.attention_dropout if module.training else 0.0,
633
- deterministic=module.deterministic_flash_attn,
634
- window_size=local_attention,
635
- )
636
- return (attn.view(bs, dim),)
637
-
638
-
639
- def sdpa_attention_forward(
640
- module: "ModernBertAttention",
641
- qkv: torch.Tensor,
642
- attention_mask: torch.Tensor,
643
- sliding_window_mask: torch.Tensor,
644
- position_ids: torch.LongTensor | None,
645
- local_attention: tuple[int, int],
646
- bs: int,
647
- dim: int,
648
- position_embeddings: torch.Tensor,
649
- **_kwargs,
650
- ) -> tuple[torch.Tensor]:
651
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
652
- cos, sin = position_embeddings
653
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
654
- # query, key, value: [batch_size, heads, seq_len, head_dim]
655
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
656
-
657
- if local_attention != (-1, -1):
658
- attention_mask = sliding_window_mask
659
-
660
- attn_output = (
661
- F.scaled_dot_product_attention(
662
- query,
663
- key,
664
- value,
665
- dropout_p=module.attention_dropout if module.training else 0.0,
666
- attn_mask=attention_mask,
667
- )
668
- .transpose(1, 2)
669
- .contiguous()
670
- )
671
- attn_output = attn_output.view(bs, -1, dim)
672
- return (attn_output,)
673
-
358
+ @use_kernel_func_from_hub("rotary_pos_emb")
359
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
360
+ """Applies Rotary Position Embedding to the query and key tensors.
674
361
 
675
- MODERNBERT_ATTENTION_FUNCTION = {
676
- "flash_attention_2": flash_attention_forward,
677
- "eager": eager_attention_forward,
678
- "sdpa": sdpa_attention_forward,
679
- }
362
+ Args:
363
+ q (`torch.Tensor`): The query tensor.
364
+ k (`torch.Tensor`): The key tensor.
365
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
366
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
367
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
368
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
369
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
370
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
371
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
372
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
373
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
374
+ Returns:
375
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
376
+ """
377
+ original_dtype = q.dtype
378
+ cos = cos.unsqueeze(unsqueeze_dim)
379
+ sin = sin.unsqueeze(unsqueeze_dim)
380
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
381
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
382
+ return q_embed.to(original_dtype), k_embed.to(original_dtype)
680
383
 
681
384
 
385
+ @use_kernelized_func(apply_rotary_pos_emb)
682
386
  class ModernBertAttention(nn.Module):
683
387
  """Performs multi-headed self attention on a batch of unpadded sequences.
684
388
 
@@ -689,10 +393,10 @@ class ModernBertAttention(nn.Module):
689
393
  See `forward` method for additional details.
690
394
  """
691
395
 
692
- def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
396
+ def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
693
397
  super().__init__()
694
398
  self.config = config
695
- self.layer_id = layer_id
399
+ self.layer_idx = layer_idx
696
400
 
697
401
  if config.hidden_size % config.num_attention_heads != 0:
698
402
  raise ValueError(
@@ -701,29 +405,19 @@ class ModernBertAttention(nn.Module):
701
405
 
702
406
  self.attention_dropout = config.attention_dropout
703
407
  self.deterministic_flash_attn = config.deterministic_flash_attn
704
- self.num_heads = config.num_attention_heads
705
408
  self.head_dim = config.hidden_size // config.num_attention_heads
706
- self.all_head_size = self.head_dim * self.num_heads
707
- self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
708
- layer_type = config.layer_types[layer_id]
409
+ self.Wqkv = nn.Linear(
410
+ config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
411
+ )
709
412
 
710
- if layer_id % config.global_attn_every_n_layers != 0:
711
- self.local_attention = (config.local_attention // 2, config.local_attention // 2)
712
- max_position_embeddings = config.local_attention
413
+ if config.layer_types[layer_idx] == "sliding_attention":
414
+ # config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
415
+ # +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
416
+ self.sliding_window = config.sliding_window + 1
713
417
  else:
714
- self.local_attention = (-1, -1)
715
- max_position_embeddings = config.max_position_embeddings
418
+ self.sliding_window = None
716
419
 
717
- if config._attn_implementation == "flash_attention_2":
718
- rope_parameters_dict = (
719
- self.config.rope_parameters[layer_type] if layer_type is not None else self.config.rope_parameters
720
- )
721
- rope_theta = rope_parameters_dict["rope_theta"]
722
- self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
723
- dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
724
- )
725
- else:
726
- self.rotary_emb = None
420
+ self.is_causal = False
727
421
 
728
422
  self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
729
423
  self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
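Illustrative sketch (not from the package source) of the half-window arithmetic used above:

    from transformers import ModernBertConfig

    config = ModernBertConfig(local_attention=128)
    assert config.sliding_window == 64  # read-only view: local_attention // 2
    # ModernBertAttention then passes config.sliding_window + 1 = 65 to the attention backend,
    # because flash attention treats the window boundaries as inclusive.
    config.sliding_window = 32          # setter writes local_attention = 2 * value
    assert config.local_attention == 64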
@@ -731,82 +425,75 @@ class ModernBertAttention(nn.Module):
731
425
  def forward(
732
426
  self,
733
427
  hidden_states: torch.Tensor,
734
- position_embeddings: torch.Tensor | None = None,
735
- output_attentions: bool | None = False,
736
- **kwargs,
737
- ) -> torch.Tensor:
428
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
429
+ attention_mask: torch.Tensor | None = None,
430
+ **kwargs: Unpack[TransformersKwargs],
431
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
432
+ input_shape = hidden_states.shape[:-1]
433
+
738
434
  qkv = self.Wqkv(hidden_states)
435
+ qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
436
+ query_states, key_states, value_states = qkv.unbind(dim=-3)
739
437
 
740
- bs = hidden_states.shape[0]
741
- if self.config._attn_implementation == "flash_attention_2":
742
- qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
743
- else:
744
- qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
438
+ query_states = query_states.transpose(1, 2)
439
+ key_states = key_states.transpose(1, 2)
440
+ value_states = value_states.transpose(1, 2)
441
+
442
+ cos, sin = position_embeddings
443
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
444
+
445
+ attention_interface = eager_attention_forward
446
+ if self.config._attn_implementation != "eager":
447
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
745
448
 
746
- attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
449
+ attn_output, attn_weights = attention_interface(
747
450
  self,
748
- qkv=qkv,
749
- rotary_emb=self.rotary_emb,
750
- local_attention=self.local_attention,
751
- bs=bs,
752
- dim=self.all_head_size,
753
- position_embeddings=position_embeddings,
754
- output_attentions=output_attentions,
451
+ query_states,
452
+ key_states,
453
+ value_states,
454
+ attention_mask,
455
+ dropout=self.attention_dropout if self.training else 0.0,
456
+ scaling=self.head_dim**-0.5,
457
+ sliding_window=self.sliding_window,
458
+ deterministic=self.deterministic_flash_attn,
755
459
  **kwargs,
756
460
  )
757
- hidden_states = attn_outputs[0]
758
- hidden_states = self.out_drop(self.Wo(hidden_states))
759
461
 
760
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
462
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
463
+ attn_output = self.out_drop(self.Wo(attn_output))
464
+ return attn_output, attn_weights
761
465
 
762
466
 
763
467
  class ModernBertEncoderLayer(GradientCheckpointingLayer):
764
- def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
468
+ def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
765
469
  super().__init__()
766
470
  self.config = config
767
- if layer_id == 0:
471
+ self.layer_idx = layer_idx
472
+ if layer_idx == 0:
768
473
  self.attn_norm = nn.Identity()
769
474
  else:
770
475
  self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
771
- self.attn = ModernBertAttention(config=config, layer_id=layer_id)
476
+ self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
772
477
  self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
773
478
  self.mlp = ModernBertMLP(config)
774
- self.attention_type = config.layer_types[layer_id]
775
-
776
- @torch.compile(dynamic=True)
777
- def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
778
- return self.mlp(self.mlp_norm(hidden_states))
479
+ self.attention_type = config.layer_types[layer_idx]
779
480
 
780
481
  def forward(
781
482
  self,
782
483
  hidden_states: torch.Tensor,
783
484
  attention_mask: torch.Tensor | None = None,
784
- sliding_window_mask: torch.Tensor | None = None,
785
- position_ids: torch.LongTensor | None = None,
786
- cu_seqlens: torch.Tensor | None = None,
787
- max_seqlen: int | None = None,
788
485
  position_embeddings: torch.Tensor | None = None,
789
- output_attentions: bool | None = False,
486
+ **kwargs: Unpack[TransformersKwargs],
790
487
  ) -> torch.Tensor:
791
- attn_outputs = self.attn(
488
+ attn_output, _ = self.attn(
792
489
  self.attn_norm(hidden_states),
793
- attention_mask=attention_mask,
794
- sliding_window_mask=sliding_window_mask,
795
- position_ids=position_ids,
796
- cu_seqlens=cu_seqlens,
797
- max_seqlen=max_seqlen,
798
490
  position_embeddings=position_embeddings,
799
- output_attentions=output_attentions,
800
- )
801
- hidden_states = hidden_states + attn_outputs[0]
802
- mlp_output = (
803
- self.compiled_mlp(hidden_states)
804
- if self.config.reference_compile
805
- else self.mlp(self.mlp_norm(hidden_states))
491
+ attention_mask=attention_mask,
492
+ **kwargs,
806
493
  )
807
- hidden_states = hidden_states + mlp_output
808
-
809
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
494
+ hidden_states = hidden_states + attn_output
495
+ hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
496
+ return hidden_states
810
497
 
811
498
 
812
499
  @auto_docstring
@@ -817,7 +504,13 @@ class ModernBertPreTrainedModel(PreTrainedModel):
817
504
  _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
818
505
  _supports_flash_attn = True
819
506
  _supports_sdpa = True
820
- _supports_flex_attn = False
507
+ _supports_flex_attn = True
508
+ _supports_attention_backend = True
509
+
510
+ _can_record_outputs = {
511
+ "hidden_states": ModernBertEncoderLayer,
512
+ "attentions": ModernBertAttention,
513
+ }
821
514
 
822
515
  @torch.no_grad()
823
516
  def _init_weights(self, module: nn.Module):
@@ -879,9 +572,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
879
572
  curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
880
573
  init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
881
574
  init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
882
- elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
883
- inv_freq = module._compute_inv_freq()
884
- init.copy_(module.inv_freq, inv_freq)
885
575
 
886
576
  def _check_and_adjust_attn_implementation(
887
577
  self, attn_implementation: str | None, is_init_check: bool = False
@@ -889,65 +579,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
889
579
  """
890
580
  Checks and dispatches to the requested attention implementation.
891
581
  """
892
- # If the user didn't specify anything, try to use flash_attention_2 if available.
582
+ # If the user didn't specify anything, try to use flash_attention_2.
893
583
  # Otherwise we fall back to the default SDPA -> Eager from the super() method.
894
- # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
895
- # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
896
-
897
584
  try:
898
- attn_implementation = (
899
- "flash_attention_2"
900
- if attn_implementation is None and self._flash_attn_2_can_dispatch()
901
- else attn_implementation
585
+ requested_attn_implementation = "flash_attention_2" if attn_implementation is None else attn_implementation
586
+ return super()._check_and_adjust_attn_implementation(
587
+ attn_implementation=requested_attn_implementation, is_init_check=is_init_check
902
588
  )
903
589
  except (ValueError, ImportError):
904
- pass
905
- return super()._check_and_adjust_attn_implementation(
906
- attn_implementation=attn_implementation, is_init_check=is_init_check
907
- )
908
-
909
- def _maybe_set_compile(self):
910
- if self.config.reference_compile is False:
911
- return
912
-
913
- if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
914
- if self.config.reference_compile:
915
- logger.warning_once(
916
- "If `accelerate` split the model across devices, `torch.compile` will not work. "
917
- "Falling back to non-compiled mode."
918
- )
919
- self.config.reference_compile = False
920
-
921
- if self.device.type == "mps":
922
- if self.config.reference_compile:
923
- logger.warning_once(
924
- "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
925
- "Falling back to non-compiled mode."
926
- )
927
- self.config.reference_compile = False
928
-
929
- if self.device.type == "cpu":
930
- if self.config.reference_compile:
931
- logger.warning_once(
932
- "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
933
- "Falling back to non-compiled mode."
934
- )
935
- self.config.reference_compile = False
936
-
937
- if self.config.reference_compile is None:
938
- self.config.reference_compile = is_triton_available()
939
-
940
- def resize_token_embeddings(self, *args, **kwargs):
941
- model_embeds = super().resize_token_embeddings(*args, **kwargs)
942
-
943
- if self.config.reference_compile in {True, None}:
944
- if self.config.reference_compile:
945
- logger.warning_once(
946
- "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
947
- )
948
- self.config.reference_compile = False
949
-
950
- return model_embeds
590
+ return super()._check_and_adjust_attn_implementation(
591
+ attn_implementation=attn_implementation, is_init_check=is_init_check
592
+ )
951
593
 
952
594
 
953
595
  @auto_docstring
@@ -957,7 +599,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
957
599
  self.config = config
958
600
  self.embeddings = ModernBertEmbeddings(config)
959
601
  self.layers = nn.ModuleList(
960
- [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
602
+ [ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
961
603
  )
962
604
  self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
963
605
  self.rotary_emb = ModernBertRotaryEmbedding(config=config)
@@ -970,175 +612,53 @@ class ModernBertModel(ModernBertPreTrainedModel):
970
612
  def set_input_embeddings(self, value):
971
613
  self.embeddings.tok_embeddings = value
972
614
 
615
+ @check_model_inputs
973
616
  @auto_docstring
974
617
  def forward(
975
618
  self,
976
619
  input_ids: torch.LongTensor | None = None,
977
620
  attention_mask: torch.Tensor | None = None,
978
- sliding_window_mask: torch.Tensor | None = None,
979
621
  position_ids: torch.LongTensor | None = None,
980
622
  inputs_embeds: torch.Tensor | None = None,
981
- indices: torch.Tensor | None = None,
982
- cu_seqlens: torch.Tensor | None = None,
983
- max_seqlen: int | None = None,
984
- batch_size: int | None = None,
985
- seq_len: int | None = None,
986
- output_attentions: bool | None = None,
987
- output_hidden_states: bool | None = None,
988
- return_dict: bool | None = None,
989
- **kwargs,
990
- ) -> tuple[torch.Tensor, ...] | BaseModelOutput:
991
- r"""
992
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
993
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
994
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
995
- far-away tokens in the local attention layers when not using Flash Attention.
996
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
997
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
998
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
999
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1000
- max_seqlen (`int`, *optional*):
1001
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1002
- batch_size (`int`, *optional*):
1003
- Batch size of the input sequences. Used to pad the output tensors.
1004
- seq_len (`int`, *optional*):
1005
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1006
- """
1007
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1008
- output_hidden_states = (
1009
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1010
- )
1011
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1012
-
623
+ **kwargs: Unpack[TransformersKwargs],
624
+ ) -> BaseModelOutput:
1013
625
  if (input_ids is None) ^ (inputs_embeds is not None):
1014
626
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1015
627
 
1016
- all_hidden_states = () if output_hidden_states else None
1017
- all_self_attentions = () if output_attentions else None
1018
-
1019
- self._maybe_set_compile()
1020
-
1021
- if input_ids is not None:
1022
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1023
-
1024
- if batch_size is None and seq_len is None:
1025
- if inputs_embeds is not None:
1026
- batch_size, seq_len = inputs_embeds.shape[:2]
1027
- else:
1028
- batch_size, seq_len = input_ids.shape[:2]
628
+ seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
1029
629
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1030
630
 
1031
- if attention_mask is None:
1032
- attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1033
-
1034
- repad = False
1035
- if self.config._attn_implementation == "flash_attention_2":
1036
- if indices is None and cu_seqlens is None and max_seqlen is None:
1037
- repad = True
1038
- if inputs_embeds is None:
1039
- with torch.no_grad():
1040
- input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
1041
- inputs=input_ids, attention_mask=attention_mask
1042
- )
1043
- else:
1044
- inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
1045
- inputs=inputs_embeds, attention_mask=attention_mask
1046
- )
1047
- if position_ids is None:
1048
- position_ids = indices.unsqueeze(0)
1049
- else:
1050
- if position_ids is None:
1051
- position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
1052
-
1053
- attention_mask, sliding_window_mask = self._update_attention_mask(
1054
- attention_mask, output_attentions=output_attentions
1055
- )
631
+ if position_ids is None:
632
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
1056
633
 
1057
634
  hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
635
+
636
+ if not isinstance(attention_mask_mapping := attention_mask, dict):
637
+ mask_kwargs = {
638
+ "config": self.config,
639
+ "input_embeds": hidden_states,
640
+ "attention_mask": attention_mask,
641
+ }
642
+ attention_mask_mapping = {
643
+ "full_attention": create_bidirectional_mask(**mask_kwargs),
644
+ "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
645
+ }
646
+
1058
647
  position_embeddings = {}
1059
648
  for layer_type in self.config.layer_types:
1060
649
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
1061
650
 
1062
651
  for encoder_layer in self.layers:
1063
- if output_hidden_states:
1064
- all_hidden_states = all_hidden_states + (hidden_states,)
1065
-
1066
- layer_outputs = encoder_layer(
652
+ hidden_states = encoder_layer(
1067
653
  hidden_states,
1068
- attention_mask=attention_mask,
1069
- sliding_window_mask=sliding_window_mask,
1070
- position_ids=position_ids,
1071
- cu_seqlens=cu_seqlens,
1072
- max_seqlen=max_seqlen,
654
+ attention_mask=attention_mask_mapping[encoder_layer.attention_type],
1073
655
  position_embeddings=position_embeddings[encoder_layer.attention_type],
1074
- output_attentions=output_attentions,
656
+ **kwargs,
1075
657
  )
1076
- hidden_states = layer_outputs[0]
1077
- if output_attentions and len(layer_outputs) > 1:
1078
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
1079
-
1080
- if output_hidden_states:
1081
- all_hidden_states = all_hidden_states + (hidden_states,)
1082
658
 
1083
659
  hidden_states = self.final_norm(hidden_states)
1084
660
 
1085
- if repad:
1086
- hidden_states = _pad_modernbert_output(
1087
- inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
1088
- )
1089
- if all_hidden_states is not None:
1090
- all_hidden_states = tuple(
1091
- _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
1092
- for hs in all_hidden_states
1093
- )
1094
- # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
1095
- # dimension missing
1096
- elif (
1097
- self.config._attn_implementation == "flash_attention_2"
1098
- and all_hidden_states is not None
1099
- and all_hidden_states[-1].dim() == 2
1100
- ):
1101
- hidden_states = hidden_states.unsqueeze(0)
1102
- all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
1103
-
1104
- if not return_dict:
1105
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
1106
- return BaseModelOutput(
1107
- last_hidden_state=hidden_states,
1108
- hidden_states=all_hidden_states,
1109
- attentions=all_self_attentions,
1110
- )
1111
-
1112
- def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
1113
- if output_attentions:
1114
- if self.config._attn_implementation == "sdpa":
1115
- logger.warning_once(
1116
- "Outputting attentions is only supported with the 'eager' attention implementation, "
1117
- 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
1118
- )
1119
- self.config._attn_implementation = "eager"
1120
- elif self.config._attn_implementation != "eager":
1121
- logger.warning_once(
1122
- "Outputting attentions is only supported with the eager attention implementation, "
1123
- f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
1124
- " Setting `output_attentions=False`."
1125
- )
1126
-
1127
- global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
1128
-
1129
- # Create position indices
1130
- rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
1131
- # Calculate distance between positions
1132
- distance = torch.abs(rows - rows.T)
1133
-
1134
- # Create sliding window mask (1 for positions within window, 0 outside)
1135
- window_mask = (
1136
- (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
1137
- )
1138
- # Combine with existing mask
1139
- sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
1140
-
1141
- return global_attention_mask, sliding_window_mask
661
+ return BaseModelOutput(last_hidden_state=hidden_states)
1142
662
 
1143
663
 
1144
664
  class ModernBertPredictionHead(nn.Module):
@@ -1180,84 +700,23 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
     def set_output_embeddings(self, new_embeddings: nn.Linear):
         self.decoder = new_embeddings

-    @torch.compile(dynamic=True)
-    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
-        return self.decoder(self.head(output))
-
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | MaskedLMOutput:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
-        if self.config._attn_implementation == "flash_attention_2":
-            if indices is None and cu_seqlens is None and max_seqlen is None:
-                if batch_size is None and seq_len is None:
-                    if inputs_embeds is not None:
-                        batch_size, seq_len = inputs_embeds.shape[:2]
-                    else:
-                        batch_size, seq_len = input_ids.shape[:2]
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-                if attention_mask is None:
-                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
-                if inputs_embeds is None:
-                    with torch.no_grad():
-                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
-                        )
-                else:
-                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
-                    )
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]

@@ -1271,35 +730,12 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
             last_hidden_state = last_hidden_state[mask_tokens]
             labels = labels[mask_tokens]

-        logits = (
-            self.compiled_head(last_hidden_state)
-            if self.config.reference_compile
-            else self.decoder(self.head(last_hidden_state))
-        )
+        logits = self.decoder(self.head(last_hidden_state))

         loss = None
         if labels is not None:
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

-        if self.config._attn_implementation == "flash_attention_2":
-            # Logits padding
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
-                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
-            # Hidden states padding
-            if getattr(outputs, "hidden_states", None) is not None:
-                padded_hidden_states = []
-                for hs in outputs.hidden_states:
-                    if hs.dim() == 3 and hs.shape[0] == 1:
-                        hs = hs.squeeze(0)
-                    padded_hidden_states.append(
-                        _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
-                    )
-                outputs.hidden_states = tuple(padded_hidden_states)
-
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return MaskedLMOutput(
             loss=loss,
             logits=logits,
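After this hunk, the ModernBert task heads no longer take the Flash-Attention unpadding arguments (`indices`, `cu_seqlens`, `max_seqlen`, `batch_size`, `seq_len`) or explicit `output_*`/`return_dict` parameters; anything beyond the standard inputs flows through `**kwargs: Unpack[TransformersKwargs]`, and `@can_return_tuple` covers tuple outputs. A minimal usage sketch under that assumption (the checkpoint name is illustrative, not prescribed by this diff):

```python
import torch
from transformers import AutoTokenizer, ModernBertForMaskedLM

# Illustrative checkpoint; any ModernBERT masked-LM checkpoint should behave the same way.
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertForMaskedLM.from_pretrained(model_id)

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
    # Padding/unpadding is handled internally; only standard kwargs are passed through.
    outputs = model(**inputs, output_hidden_states=True)

mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = outputs.logits[0, mask_index].argmax(dim=-1)
print(tokenizer.decode(predicted_id))
```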
@@ -1327,81 +763,39 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
-        if input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
-        if batch_size is None and seq_len is None:
-            if inputs_embeds is not None:
-                batch_size, seq_len = inputs_embeds.shape[:2]
-            else:
-                batch_size, seq_len = input_ids.shape[:2]
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]

         if self.config.classifier_pooling == "cls":
             last_hidden_state = last_hidden_state[:, 0]
         elif self.config.classifier_pooling == "mean":
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
+                )
             last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
                 dim=1, keepdim=True
             )
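The added lines in this hunk give mean pooling a fallback all-ones mask when `attention_mask` is `None`, so the pooled vector is always a mask-weighted average over the sequence dimension. A tiny standalone sketch of that computation, with made-up shapes:

```python
import torch

# Illustrative shapes: batch of 2 sequences, 5 tokens each, hidden size 4.
last_hidden_state = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.bool)

# Zero out padded positions, then divide by the number of real tokens per sequence.
masked = last_hidden_state * attention_mask.unsqueeze(-1)
pooled = masked.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
print(pooled.shape)  # torch.Size([2, 4])
```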
@@ -1433,10 +827,6 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)

-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
@@ -1463,60 +853,27 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | TokenClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]

@@ -1529,10 +886,6 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
@@ -1554,57 +907,22 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):

         self.post_init()

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: torch.Tensor | None,
+        input_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         start_positions: torch.Tensor | None = None,
         end_positions: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]

@@ -1620,10 +938,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         if start_positions is not None and end_positions is not None:
             loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

-        if not return_dict:
-            output = (start_logits, end_logits) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return QuestionAnsweringModelOutput(
             loss=loss,
             start_logits=start_logits,
@@ -1651,45 +965,22 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
-        sliding_window_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
-        indices: torch.Tensor | None = None,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        batch_size: int | None = None,
-        seq_len: int | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -1701,22 +992,12 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             else None
         )

-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]  # shape (num_choices, seq_len, hidden_size)

@@ -1748,10 +1029,6 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)

-        if not return_dict:
-            output = (reshaped_logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return MultipleChoiceModelOutput(
             loss=loss,
             logits=reshaped_logits,
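For context on the multiple-choice path shown above: choices are folded into the batch dimension before the encoder, and the per-sequence scores are reshaped back to `(batch_size, num_choices)` for the cross-entropy over choices. A small sketch of that flatten/unflatten pattern, with illustrative shapes and a stand-in for the classifier output:

```python
import torch

# Illustrative shapes: batch of 2 examples, 4 answer choices, 7 tokens each.
batch_size, num_choices, seq_len = 2, 4, 7
input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

# Fold choices into the batch dimension before the encoder ...
flat_input_ids = input_ids.view(-1, input_ids.size(-1))      # (8, 7)

# ... pretend per-sequence scores came back from the classifier head ...
flat_logits = torch.randn(flat_input_ids.size(0), 1)         # (8, 1), stand-in for real logits

# ... and unfold back to (batch_size, num_choices) for the loss over choices.
reshaped_logits = flat_logits.view(-1, num_choices)          # (2, 4)
print(reshaped_logits.shape)
```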