transformers 5.0.0rc2__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -14,18 +14,17 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import math
17
- from contextlib import nullcontext
18
- from typing import Literal, Optional, Union
17
+ from typing import Literal, Optional
19
18
 
20
19
  import torch
21
- import torch.nn.functional as F
22
20
  from torch import nn
23
21
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
22
 
25
23
  from ... import initialization as init
26
24
  from ...activations import ACT2FN
27
25
  from ...configuration_utils import PreTrainedConfig, layer_type_validation
28
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
26
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
27
+ from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
29
28
  from ...modeling_layers import GradientCheckpointingLayer
30
29
  from ...modeling_outputs import (
31
30
  BaseModelOutput,
@@ -36,18 +35,12 @@ from ...modeling_outputs import (
36
35
  TokenClassifierOutput,
37
36
  )
38
37
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
39
- from ...modeling_utils import PreTrainedModel
40
- from ...utils import auto_docstring, is_flash_attn_2_available, logging
41
- from ...utils.import_utils import is_triton_available
42
- from ..gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, apply_rotary_pos_emb
43
-
44
-
45
- if is_flash_attn_2_available():
46
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
47
- from flash_attn.layers.rotary import RotaryEmbedding
48
- from flash_attn.ops.triton.rotary import apply_rotary
49
- else:
50
- RotaryEmbedding = object
38
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from ...processing_utils import Unpack
40
+ from ...utils import TransformersKwargs, auto_docstring, logging
41
+ from ...utils.generic import can_return_tuple, check_model_inputs
42
+ from ..align.modeling_align import eager_attention_forward
43
+ from ..gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, rotate_half
51
44
 
52
45
 
53
46
  logger = logging.get_logger(__name__)
@@ -104,10 +97,9 @@ class ModernBertConfig(PreTrainedConfig):
104
97
  The dropout ratio for the attention probabilities.
105
98
  layer_types (`list`, *optional*):
106
99
  Attention pattern for each layer.
107
- rope_parameters (`RopeParameters`, *optional*):
108
- Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
109
- a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
110
- with longer `max_position_embeddings`.
100
+ rope_parameters (`dict`, *optional*):
101
+ Dictionary mapping attention patterns (`"full_attention"`, `"sliding_attention"`) to `RopeParameters`.
102
+ Each value should be a dictionary containing `rope_type` and optional scaling parameters.
111
103
  local_attention (`int`, *optional*, defaults to 128):
112
104
  The window size for local attention.
113
105
  embedding_dropout (`float`, *optional*, defaults to 0.0):
@@ -137,10 +129,9 @@ class ModernBertConfig(PreTrainedConfig):
137
129
  Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
138
130
  the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
139
131
  shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
140
- be faster in some scenarios.
141
- repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
142
- When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
143
- applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
132
+ be faster in some scenarios. This argument is deprecated and will be removed in a future version.
133
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
134
+ Whether to tie weight embeddings
144
135
 
145
136
  Examples:
146
137
 
@@ -161,44 +152,59 @@ class ModernBertConfig(PreTrainedConfig):
161
152
  keys_to_ignore_at_inference = ["past_key_values"]
162
153
  default_theta = {"global": 160_000.0, "local": 10_000.0}
163
154
 
155
+ def __setattr__(self, name, value):
156
+ if name == "reference_compile" and value is not None:
157
+ logger.warning_once(
158
+ "The `reference_compile` argument is deprecated and will be removed in `transformers v5.2.0`"
159
+ "Use `torch.compile()` directly on the model instead."
160
+ )
161
+ value = None
162
+ super().__setattr__(name, value)
163
+
164
164
  def __init__(
165
165
  self,
166
- vocab_size: Optional[int] = 50368,
167
- hidden_size: Optional[int] = 768,
168
- intermediate_size: Optional[int] = 1152,
169
- num_hidden_layers: Optional[int] = 22,
170
- num_attention_heads: Optional[int] = 12,
171
- hidden_activation: Optional[str] = "gelu",
172
- max_position_embeddings: Optional[int] = 8192,
173
- initializer_range: Optional[float] = 0.02,
174
- initializer_cutoff_factor: Optional[float] = 2.0,
175
- norm_eps: Optional[int] = 1e-5,
176
- norm_bias: Optional[bool] = False,
177
- pad_token_id: Optional[int] = 50283,
178
- eos_token_id: Optional[int] = 50282,
179
- bos_token_id: Optional[int] = 50281,
180
- cls_token_id: Optional[int] = 50281,
181
- sep_token_id: Optional[int] = 50282,
182
- attention_bias: Optional[bool] = False,
183
- attention_dropout: Optional[float] = 0.0,
184
- layer_types: Optional[list[str]] = None,
185
- rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
186
- local_attention: Optional[int] = 128,
187
- embedding_dropout: Optional[float] = 0.0,
188
- mlp_bias: Optional[bool] = False,
189
- mlp_dropout: Optional[float] = 0.0,
190
- decoder_bias: Optional[bool] = True,
166
+ vocab_size: int | None = 50368,
167
+ hidden_size: int | None = 768,
168
+ intermediate_size: int | None = 1152,
169
+ num_hidden_layers: int | None = 22,
170
+ num_attention_heads: int | None = 12,
171
+ hidden_activation: str | None = "gelu",
172
+ max_position_embeddings: int | None = 8192,
173
+ initializer_range: float | None = 0.02,
174
+ initializer_cutoff_factor: float | None = 2.0,
175
+ norm_eps: float | None = 1e-5,
176
+ norm_bias: bool | None = False,
177
+ pad_token_id: int | None = 50283,
178
+ eos_token_id: int | None = 50282,
179
+ bos_token_id: int | None = 50281,
180
+ cls_token_id: int | None = 50281,
181
+ sep_token_id: int | None = 50282,
182
+ attention_bias: bool | None = False,
183
+ attention_dropout: float | None = 0.0,
184
+ layer_types: list[str] | None = None,
185
+ rope_parameters: dict[Literal["full_attention", "sliding_attention"], RopeParameters] | None = None,
186
+ local_attention: int | None = 128,
187
+ embedding_dropout: float | None = 0.0,
188
+ mlp_bias: bool | None = False,
189
+ mlp_dropout: float | None = 0.0,
190
+ decoder_bias: bool | None = True,
191
191
  classifier_pooling: Literal["cls", "mean"] = "cls",
192
- classifier_dropout: Optional[float] = 0.0,
193
- classifier_bias: Optional[bool] = False,
194
- classifier_activation: Optional[str] = "gelu",
195
- deterministic_flash_attn: Optional[bool] = False,
196
- sparse_prediction: Optional[bool] = False,
197
- sparse_pred_ignore_index: Optional[int] = -100,
198
- reference_compile: Optional[bool] = None,
199
- repad_logits_with_grad: Optional[bool] = False,
192
+ classifier_dropout: float | None = 0.0,
193
+ classifier_bias: bool | None = False,
194
+ classifier_activation: str | None = "gelu",
195
+ deterministic_flash_attn: bool | None = False,
196
+ sparse_prediction: bool | None = False,
197
+ sparse_pred_ignore_index: int | None = -100,
198
+ reference_compile: bool | None = None, # Deprecated
199
+ tie_word_embeddings: bool | None = True,
200
200
  **kwargs,
201
201
  ):
202
+ self.pad_token_id = pad_token_id
203
+ self.bos_token_id = bos_token_id
204
+ self.eos_token_id = eos_token_id
205
+ self.cls_token_id = cls_token_id
206
+ self.sep_token_id = sep_token_id
207
+ self.tie_word_embeddings = tie_word_embeddings
202
208
  self.vocab_size = vocab_size
203
209
  self.max_position_embeddings = max_position_embeddings
204
210
  self.hidden_size = hidden_size
@@ -225,7 +231,6 @@ class ModernBertConfig(PreTrainedConfig):
225
231
  self.sparse_prediction = sparse_prediction
226
232
  self.sparse_pred_ignore_index = sparse_pred_ignore_index
227
233
  self.reference_compile = reference_compile
228
- self.repad_logits_with_grad = repad_logits_with_grad
229
234
 
230
235
  if self.classifier_pooling not in ["cls", "mean"]:
231
236
  raise ValueError(
@@ -245,14 +250,7 @@ class ModernBertConfig(PreTrainedConfig):
245
250
  layer_type_validation(self.layer_types, self.num_hidden_layers)
246
251
 
247
252
  self.rope_parameters = rope_parameters
248
- super().__init__(
249
- pad_token_id=pad_token_id,
250
- bos_token_id=bos_token_id,
251
- eos_token_id=eos_token_id,
252
- cls_token_id=cls_token_id,
253
- sep_token_id=sep_token_id,
254
- **kwargs,
255
- )
253
+ super().__init__(**kwargs)
256
254
 
257
255
  def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
258
256
  rope_scaling = kwargs.pop("rope_scaling", None)
@@ -267,9 +265,15 @@ class ModernBertConfig(PreTrainedConfig):
267
265
  if rope_scaling is not None:
268
266
  self.rope_parameters["full_attention"].update(rope_scaling)
269
267
  self.rope_parameters["sliding_attention"].update(rope_scaling)
268
+
269
+ # Set default values if not present
270
+ if self.rope_parameters.get("full_attention") is None:
271
+ self.rope_parameters["full_attention"] = {"rope_type": "default"}
270
272
  self.rope_parameters["full_attention"].setdefault(
271
273
  "rope_theta", kwargs.pop("global_rope_theta", self.default_theta["global"])
272
274
  )
275
+ if self.rope_parameters.get("sliding_attention") is None:
276
+ self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
273
277
  self.rope_parameters["sliding_attention"].setdefault(
274
278
  "rope_theta", kwargs.pop("local_rope_theta", self.default_theta["local"])
275
279
  )
@@ -284,211 +288,15 @@ class ModernBertConfig(PreTrainedConfig):
284
288
  output.pop("reference_compile", None)
285
289
  return output
286
290
 
291
+ @property
292
+ def sliding_window(self):
293
+ """Half-window size: `local_attention` is the total window, so we divide by 2."""
294
+ return self.local_attention // 2
287
295
 
288
- def _unpad_modernbert_input(
289
- inputs: torch.Tensor,
290
- attention_mask: torch.Tensor,
291
- position_ids: Optional[torch.Tensor] = None,
292
- labels: Optional[torch.Tensor] = None,
293
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
294
- """
295
- Remove padding from input sequences.
296
-
297
- Args:
298
- inputs: (batch, seqlen, ...) or (batch, seqlen)
299
- attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
300
- position_ids: (batch, seqlen), int, position ids
301
- labels: (batch, seqlen), int, labels
302
-
303
- Returns:
304
- unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
305
- indices: (total_nnz)
306
- cu_seqlens: (batch + 1), the cumulative sequence lengths
307
- max_seqlen_in_batch: int
308
- unpadded_position_ids: (total_nnz) or None
309
- unpadded_labels: (total_nnz) or None
310
- """
311
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
312
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
313
- max_seqlen_in_batch = int(seqlens_in_batch.max().item())
314
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
315
-
316
- if inputs.dim() == 2:
317
- unpadded_inputs = inputs.flatten()[indices]
318
- else:
319
- batch, seqlen, *rest = inputs.shape
320
- shape = batch * seqlen
321
- unpadded_inputs = inputs.view(shape, *rest)[indices]
322
-
323
- unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
324
- unpadded_labels = labels.flatten()[indices] if labels is not None else None
325
-
326
- return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
327
-
328
-
329
- def _pad_modernbert_output(
330
- inputs: torch.Tensor,
331
- indices: torch.Tensor,
332
- batch: int,
333
- seqlen: int,
334
- ) -> torch.Tensor:
335
- """
336
- Add padding to sequences.
337
-
338
- Args:
339
- inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
340
- indices: (total_nnz)
341
- batch: int, batch size
342
- seqlen: int, max sequence length
343
-
344
- Returns:
345
- padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
346
- """
347
- if inputs.dim() == 1:
348
- output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
349
- output[indices] = inputs
350
- padded_inputs = output.view(batch, seqlen)
351
- else:
352
- _, *rest = inputs.shape
353
- output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
354
- output[indices] = inputs
355
- padded_inputs = output.view(batch, seqlen, *rest)
356
-
357
- return padded_inputs
358
-
359
-
360
- class ApplyRotaryEmbUnpad(torch.autograd.Function):
361
- @staticmethod
362
- def forward(
363
- ctx,
364
- qkv,
365
- cos,
366
- sin,
367
- cu_seqlens: Optional[torch.Tensor] = None,
368
- max_seqlen: Optional[int] = None,
369
- ):
370
- # (total_nnz, 3, nheads, headdim)
371
- qkv = qkv.contiguous()
372
- total_nnz, _three, _nheads, headdim = qkv.shape
373
- # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
374
- # we get the same tensor
375
- # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
376
- qk = qkv[:, :2].view(total_nnz, -1, headdim)
377
- apply_rotary(
378
- qk,
379
- cos,
380
- sin,
381
- seqlen_offsets=0,
382
- cu_seqlens=cu_seqlens,
383
- max_seqlen=max_seqlen,
384
- interleaved=False,
385
- inplace=True,
386
- )
387
-
388
- ctx.save_for_backward(cos, sin, cu_seqlens)
389
- ctx.max_seqlen = max_seqlen
390
- return qkv
391
-
392
- @staticmethod
393
- def backward(ctx, do):
394
- cos, sin, cu_seqlens = ctx.saved_tensors
395
- do = do.contiguous()
396
- total_nnz, _three, _nheads, headdim = do.shape
397
- # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
398
- # we get the same tensor
399
- dqk = do[:, :2].view(total_nnz, -1, headdim)
400
- apply_rotary(
401
- dqk,
402
- cos,
403
- sin,
404
- seqlen_offsets=0,
405
- cu_seqlens=cu_seqlens,
406
- max_seqlen=ctx.max_seqlen,
407
- interleaved=False,
408
- inplace=True,
409
- conjugate=True,
410
- )
411
-
412
- return do, None, None, None, None, None, None
413
-
414
-
415
- def apply_rotary_unpadded(
416
- qkv,
417
- cos,
418
- sin,
419
- cu_seqlens: Optional[torch.Tensor] = None,
420
- max_seqlen: Optional[int] = None,
421
- ):
422
- """
423
- Arguments:
424
- qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
425
- cos, sin: (seqlen_rotary, rotary_dim / 2)
426
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
427
- of 1st half and 2nd half (GPT-NeoX style).
428
- inplace: if True, apply rotary embedding in-place.
429
- seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
430
- Most commonly used in inference when we have KV cache.
431
- cu_seqlens: (batch + 1,) or None
432
- max_seqlen: int
433
- Return:
434
- out: (total_nnz, dim)
435
- rotary_dim must be <= headdim
436
- Apply rotary embedding to the first rotary_dim of x.
437
- """
438
- return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
439
-
440
-
441
- class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
442
- """
443
- The rotary position embeddings applied directly to unpadded sequences.
444
- """
445
-
446
- def __init__(
447
- self,
448
- dim: int,
449
- base: float = 10000.0,
450
- max_seqlen: Optional[int] = None,
451
- device: Optional[torch.device] = None,
452
- dtype: Optional[torch.dtype] = None,
453
- ):
454
- """
455
- max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
456
- up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
457
- the cos_sin_cache will be recomputed during the forward pass.
458
- """
459
- super().__init__(dim=dim, base=base, device=device, interleaved=False)
460
- self.max_seqlen = max_seqlen
461
-
462
- if max_seqlen is not None and device is not None and dtype is not None:
463
- self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
464
-
465
- def forward(
466
- self,
467
- qkv: torch.Tensor,
468
- cu_seqlens: torch.Tensor,
469
- max_seqlen: Optional[int] = None,
470
- ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
471
- """
472
- Apply rotary embedding *inplace* to qkv.
473
- qkv: (total_nnz, 3, nheads, headdim)
474
- cu_seqlens: (batch + 1,) cumulative sequence lengths
475
- max_seqlen: int max seq length in the batch
476
- """
477
- if max_seqlen is not None:
478
- self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
479
-
480
- qkv = apply_rotary_unpadded(
481
- qkv,
482
- self._cos_cached,
483
- self._sin_cached,
484
- cu_seqlens=cu_seqlens,
485
- max_seqlen=max_seqlen,
486
- )
487
-
488
- return qkv
489
-
490
- def extra_repr(self) -> str:
491
- return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
296
+ @sliding_window.setter
297
+ def sliding_window(self, value):
298
+ """Set sliding_window by updating local_attention to 2 * value."""
299
+ self.local_attention = value * 2
492
300
 
493
301
 
494
302
  class ModernBertEmbeddings(nn.Module):
@@ -503,21 +311,13 @@ class ModernBertEmbeddings(nn.Module):
503
311
  self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
504
312
  self.drop = nn.Dropout(config.embedding_dropout)
505
313
 
506
- @torch.compile(dynamic=True)
507
- def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
508
- return self.drop(self.norm(self.tok_embeddings(input_ids)))
509
-
510
314
  def forward(
511
- self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
315
+ self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
512
316
  ) -> torch.Tensor:
513
317
  if inputs_embeds is not None:
514
318
  hidden_states = self.drop(self.norm(inputs_embeds))
515
319
  else:
516
- hidden_states = (
517
- self.compiled_embeddings(input_ids)
518
- if self.config.reference_compile
519
- else self.drop(self.norm(self.tok_embeddings(input_ids)))
520
- )
320
+ hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
521
321
  return hidden_states
522
322
 
523
323
 
@@ -547,138 +347,42 @@ class ModernBertRotaryEmbedding(Gemma3RotaryEmbedding):
547
347
 
548
348
  @staticmethod
549
349
  def compute_default_rope_parameters(
550
- config: Optional[ModernBertConfig] = None,
350
+ config: ModernBertConfig | None = None,
551
351
  device: Optional["torch.device"] = None,
552
- seq_len: Optional[int] = None,
553
- layer_type: Optional[str] = None,
352
+ seq_len: int | None = None,
353
+ layer_type: str | None = None,
554
354
  ) -> tuple["torch.Tensor", float]:
555
355
  return super().compute_default_rope_parameters(config, device, seq_len, layer_type)
556
356
 
557
357
 
558
- def eager_attention_forward(
559
- module: "ModernBertAttention",
560
- qkv: torch.Tensor,
561
- attention_mask: torch.Tensor,
562
- sliding_window_mask: torch.Tensor,
563
- position_ids: Optional[torch.LongTensor],
564
- local_attention: tuple[int, int],
565
- bs: int,
566
- dim: int,
567
- position_embeddings: torch.Tensor,
568
- output_attentions: Optional[bool] = False,
569
- **_kwargs,
570
- ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
571
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
572
- cos, sin = position_embeddings
573
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
574
- # query, key, value: [batch_size, heads, seq_len, head_dim]
575
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
576
-
577
- scale = module.head_dim**-0.5
578
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
579
-
580
- if local_attention != (-1, -1):
581
- attention_mask = sliding_window_mask
582
-
583
- attn_weights = attn_weights + attention_mask
584
-
585
- # upcast attention to fp32
586
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
587
- attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
588
- attn_output = torch.matmul(attn_weights, value)
589
- attn_output = attn_output.transpose(1, 2).contiguous()
590
- attn_output = attn_output.view(bs, -1, dim)
591
- if output_attentions:
592
- return (attn_output, attn_weights)
593
- return (attn_output,)
594
-
595
-
596
- def flash_attention_forward(
597
- module: "ModernBertAttention",
598
- qkv: torch.Tensor,
599
- rotary_emb: ModernBertUnpaddedRotaryEmbedding,
600
- cu_seqlens: torch.Tensor,
601
- max_seqlen: int,
602
- local_attention: tuple[int, int],
603
- bs: int,
604
- dim: int,
605
- target_dtype: torch.dtype = torch.bfloat16,
606
- **_kwargs,
607
- ) -> tuple[torch.Tensor]:
608
- # (total_seqlen, 3, nheads, headdim)
609
- qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
610
-
611
- convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
612
- if convert_dtype:
613
- # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
614
- # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
615
- orig_dtype = qkv.dtype
616
- qkv = qkv.to(target_dtype)
617
-
618
- attn = flash_attn_varlen_qkvpacked_func(
619
- qkv,
620
- cu_seqlens=cu_seqlens,
621
- max_seqlen=max_seqlen,
622
- dropout_p=module.attention_dropout if module.training else 0.0,
623
- deterministic=module.deterministic_flash_attn,
624
- window_size=local_attention,
625
- )
626
- attn = attn.to(orig_dtype) # type: ignore
627
- else:
628
- attn = flash_attn_varlen_qkvpacked_func(
629
- qkv,
630
- cu_seqlens=cu_seqlens,
631
- max_seqlen=max_seqlen,
632
- dropout_p=module.attention_dropout if module.training else 0.0,
633
- deterministic=module.deterministic_flash_attn,
634
- window_size=local_attention,
635
- )
636
- return (attn.view(bs, dim),)
637
-
638
-
639
- def sdpa_attention_forward(
640
- module: "ModernBertAttention",
641
- qkv: torch.Tensor,
642
- attention_mask: torch.Tensor,
643
- sliding_window_mask: torch.Tensor,
644
- position_ids: Optional[torch.LongTensor],
645
- local_attention: tuple[int, int],
646
- bs: int,
647
- dim: int,
648
- position_embeddings: torch.Tensor,
649
- **_kwargs,
650
- ) -> tuple[torch.Tensor]:
651
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
652
- cos, sin = position_embeddings
653
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
654
- # query, key, value: [batch_size, heads, seq_len, head_dim]
655
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
656
-
657
- if local_attention != (-1, -1):
658
- attention_mask = sliding_window_mask
659
-
660
- attn_output = (
661
- F.scaled_dot_product_attention(
662
- query,
663
- key,
664
- value,
665
- dropout_p=module.attention_dropout if module.training else 0.0,
666
- attn_mask=attention_mask,
667
- )
668
- .transpose(1, 2)
669
- .contiguous()
670
- )
671
- attn_output = attn_output.view(bs, -1, dim)
672
- return (attn_output,)
673
-
358
+ @use_kernel_func_from_hub("rotary_pos_emb")
359
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
360
+ """Applies Rotary Position Embedding to the query and key tensors.
674
361
 
675
- MODERNBERT_ATTENTION_FUNCTION = {
676
- "flash_attention_2": flash_attention_forward,
677
- "eager": eager_attention_forward,
678
- "sdpa": sdpa_attention_forward,
679
- }
362
+ Args:
363
+ q (`torch.Tensor`): The query tensor.
364
+ k (`torch.Tensor`): The key tensor.
365
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
366
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
367
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
368
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
369
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
370
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
371
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
372
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
373
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
374
+ Returns:
375
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
376
+ """
377
+ original_dtype = q.dtype
378
+ cos = cos.unsqueeze(unsqueeze_dim)
379
+ sin = sin.unsqueeze(unsqueeze_dim)
380
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
381
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
382
+ return q_embed.to(original_dtype), k_embed.to(original_dtype)
680
383
 
681
384
 
385
+ @use_kernelized_func(apply_rotary_pos_emb)
682
386
  class ModernBertAttention(nn.Module):
683
387
  """Performs multi-headed self attention on a batch of unpadded sequences.
684
388
 
@@ -689,10 +393,10 @@ class ModernBertAttention(nn.Module):
689
393
  See `forward` method for additional details.
690
394
  """
691
395
 
692
- def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
396
+ def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
693
397
  super().__init__()
694
398
  self.config = config
695
- self.layer_id = layer_id
399
+ self.layer_idx = layer_idx
696
400
 
697
401
  if config.hidden_size % config.num_attention_heads != 0:
698
402
  raise ValueError(
@@ -701,29 +405,19 @@ class ModernBertAttention(nn.Module):
701
405
 
702
406
  self.attention_dropout = config.attention_dropout
703
407
  self.deterministic_flash_attn = config.deterministic_flash_attn
704
- self.num_heads = config.num_attention_heads
705
408
  self.head_dim = config.hidden_size // config.num_attention_heads
706
- self.all_head_size = self.head_dim * self.num_heads
707
- self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
708
- layer_type = config.layer_types[layer_id]
409
+ self.Wqkv = nn.Linear(
410
+ config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
411
+ )
709
412
 
710
- if layer_id % config.global_attn_every_n_layers != 0:
711
- self.local_attention = (config.local_attention // 2, config.local_attention // 2)
712
- max_position_embeddings = config.local_attention
413
+ if config.layer_types[layer_idx] == "sliding_attention":
414
+ # config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
415
+ # +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
416
+ self.sliding_window = config.sliding_window + 1
713
417
  else:
714
- self.local_attention = (-1, -1)
715
- max_position_embeddings = config.max_position_embeddings
418
+ self.sliding_window = None
716
419
 
717
- if config._attn_implementation == "flash_attention_2":
718
- rope_parameters_dict = (
719
- self.config.rope_parameters[layer_type] if layer_type is not None else self.config.rope_parameters
720
- )
721
- rope_theta = rope_parameters_dict["rope_theta"]
722
- self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
723
- dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
724
- )
725
- else:
726
- self.rotary_emb = None
420
+ self.is_causal = False
727
421
 
728
422
  self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
729
423
  self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
@@ -731,82 +425,75 @@ class ModernBertAttention(nn.Module):
731
425
  def forward(
732
426
  self,
733
427
  hidden_states: torch.Tensor,
734
- position_embeddings: Optional[torch.Tensor] = None,
735
- output_attentions: Optional[bool] = False,
736
- **kwargs,
737
- ) -> torch.Tensor:
428
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
429
+ attention_mask: torch.Tensor | None = None,
430
+ **kwargs: Unpack[TransformersKwargs],
431
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
432
+ input_shape = hidden_states.shape[:-1]
433
+
738
434
  qkv = self.Wqkv(hidden_states)
435
+ qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
436
+ query_states, key_states, value_states = qkv.unbind(dim=-3)
739
437
 
740
- bs = hidden_states.shape[0]
741
- if self.config._attn_implementation == "flash_attention_2":
742
- qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
743
- else:
744
- qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
438
+ query_states = query_states.transpose(1, 2)
439
+ key_states = key_states.transpose(1, 2)
440
+ value_states = value_states.transpose(1, 2)
441
+
442
+ cos, sin = position_embeddings
443
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
444
+
445
+ attention_interface = eager_attention_forward
446
+ if self.config._attn_implementation != "eager":
447
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
745
448
 
746
- attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
449
+ attn_output, attn_weights = attention_interface(
747
450
  self,
748
- qkv=qkv,
749
- rotary_emb=self.rotary_emb,
750
- local_attention=self.local_attention,
751
- bs=bs,
752
- dim=self.all_head_size,
753
- position_embeddings=position_embeddings,
754
- output_attentions=output_attentions,
451
+ query_states,
452
+ key_states,
453
+ value_states,
454
+ attention_mask,
455
+ dropout=self.attention_dropout if self.training else 0.0,
456
+ scaling=self.head_dim**-0.5,
457
+ sliding_window=self.sliding_window,
458
+ deterministic=self.deterministic_flash_attn,
755
459
  **kwargs,
756
460
  )
757
- hidden_states = attn_outputs[0]
758
- hidden_states = self.out_drop(self.Wo(hidden_states))
759
461
 
760
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
462
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
463
+ attn_output = self.out_drop(self.Wo(attn_output))
464
+ return attn_output, attn_weights
761
465
 
762
466
 
763
467
  class ModernBertEncoderLayer(GradientCheckpointingLayer):
764
- def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
468
+ def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
765
469
  super().__init__()
766
470
  self.config = config
767
- if layer_id == 0:
471
+ self.layer_idx = layer_idx
472
+ if layer_idx == 0:
768
473
  self.attn_norm = nn.Identity()
769
474
  else:
770
475
  self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
771
- self.attn = ModernBertAttention(config=config, layer_id=layer_id)
476
+ self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
772
477
  self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
773
478
  self.mlp = ModernBertMLP(config)
774
- self.attention_type = config.layer_types[layer_id]
775
-
776
- @torch.compile(dynamic=True)
777
- def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
778
- return self.mlp(self.mlp_norm(hidden_states))
479
+ self.attention_type = config.layer_types[layer_idx]
779
480
 
780
481
  def forward(
781
482
  self,
782
483
  hidden_states: torch.Tensor,
783
- attention_mask: Optional[torch.Tensor] = None,
784
- sliding_window_mask: Optional[torch.Tensor] = None,
785
- position_ids: Optional[torch.LongTensor] = None,
786
- cu_seqlens: Optional[torch.Tensor] = None,
787
- max_seqlen: Optional[int] = None,
788
- position_embeddings: Optional[torch.Tensor] = None,
789
- output_attentions: Optional[bool] = False,
484
+ attention_mask: torch.Tensor | None = None,
485
+ position_embeddings: torch.Tensor | None = None,
486
+ **kwargs: Unpack[TransformersKwargs],
790
487
  ) -> torch.Tensor:
791
- attn_outputs = self.attn(
488
+ attn_output, _ = self.attn(
792
489
  self.attn_norm(hidden_states),
793
- attention_mask=attention_mask,
794
- sliding_window_mask=sliding_window_mask,
795
- position_ids=position_ids,
796
- cu_seqlens=cu_seqlens,
797
- max_seqlen=max_seqlen,
798
490
  position_embeddings=position_embeddings,
799
- output_attentions=output_attentions,
800
- )
801
- hidden_states = hidden_states + attn_outputs[0]
802
- mlp_output = (
803
- self.compiled_mlp(hidden_states)
804
- if self.config.reference_compile
805
- else self.mlp(self.mlp_norm(hidden_states))
491
+ attention_mask=attention_mask,
492
+ **kwargs,
806
493
  )
807
- hidden_states = hidden_states + mlp_output
808
-
809
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
494
+ hidden_states = hidden_states + attn_output
495
+ hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
496
+ return hidden_states
810
497
 
811
498
 
812
499
  @auto_docstring
@@ -817,7 +504,13 @@ class ModernBertPreTrainedModel(PreTrainedModel):
817
504
  _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
818
505
  _supports_flash_attn = True
819
506
  _supports_sdpa = True
820
- _supports_flex_attn = False
507
+ _supports_flex_attn = True
508
+ _supports_attention_backend = True
509
+
510
+ _can_record_outputs = {
511
+ "hidden_states": ModernBertEncoderLayer,
512
+ "attentions": ModernBertAttention,
513
+ }
821
514
 
822
515
  @torch.no_grad()
823
516
  def _init_weights(self, module: nn.Module):
@@ -879,75 +572,24 @@ class ModernBertPreTrainedModel(PreTrainedModel):
879
572
  curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
880
573
  init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
881
574
  init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
882
- elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
883
- inv_freq = module._compute_inv_freq()
884
- init.copy_(module.inv_freq, inv_freq)
885
575
 
886
576
  def _check_and_adjust_attn_implementation(
887
- self, attn_implementation: Optional[str], is_init_check: bool = False
577
+ self, attn_implementation: str | None, is_init_check: bool = False
888
578
  ) -> str:
889
579
  """
890
580
  Checks and dispatches to hhe requested attention implementation.
891
581
  """
892
- # If the user didn't specify anything, try to use flash_attention_2 if available.
582
+ # If the user didn't specify anything, try to use flash_attention_2.
893
583
  # Otherwise we fall back to the default SDPA -> Eager from the super() method.
894
- # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
895
- # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
896
-
897
584
  try:
898
- attn_implementation = (
899
- "flash_attention_2"
900
- if attn_implementation is None and self._flash_attn_2_can_dispatch()
901
- else attn_implementation
585
+ requested_attn_implementation = "flash_attention_2" if attn_implementation is None else attn_implementation
586
+ return super()._check_and_adjust_attn_implementation(
587
+ attn_implementation=requested_attn_implementation, is_init_check=is_init_check
902
588
  )
903
589
  except (ValueError, ImportError):
904
- pass
905
- return super()._check_and_adjust_attn_implementation(
906
- attn_implementation=attn_implementation, is_init_check=is_init_check
907
- )
908
-
909
- def _maybe_set_compile(self):
910
- if self.config.reference_compile is False:
911
- return
912
-
913
- if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
914
- if self.config.reference_compile:
915
- logger.warning_once(
916
- "If `accelerate` split the model across devices, `torch.compile` will not work. "
917
- "Falling back to non-compiled mode."
918
- )
919
- self.config.reference_compile = False
920
-
921
- if self.device.type == "mps":
922
- if self.config.reference_compile:
923
- logger.warning_once(
924
- "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
925
- "Falling back to non-compiled mode."
926
- )
927
- self.config.reference_compile = False
928
-
929
- if self.device.type == "cpu":
930
- if self.config.reference_compile:
931
- logger.warning_once(
932
- "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
933
- "Falling back to non-compiled mode."
934
- )
935
- self.config.reference_compile = False
936
-
937
- if self.config.reference_compile is None:
938
- self.config.reference_compile = is_triton_available()
939
-
940
- def resize_token_embeddings(self, *args, **kwargs):
941
- model_embeds = super().resize_token_embeddings(*args, **kwargs)
942
-
943
- if self.config.reference_compile in {True, None}:
944
- if self.config.reference_compile:
945
- logger.warning_once(
946
- "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
947
- )
948
- self.config.reference_compile = False
949
-
950
- return model_embeds
590
+ return super()._check_and_adjust_attn_implementation(
591
+ attn_implementation=attn_implementation, is_init_check=is_init_check
592
+ )
951
593
 
952
594
 
953
595
  @auto_docstring
@@ -957,7 +599,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
957
599
  self.config = config
958
600
  self.embeddings = ModernBertEmbeddings(config)
959
601
  self.layers = nn.ModuleList(
960
- [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
602
+ [ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
961
603
  )
962
604
  self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
963
605
  self.rotary_emb = ModernBertRotaryEmbedding(config=config)
@@ -970,175 +612,53 @@ class ModernBertModel(ModernBertPreTrainedModel):
970
612
  def set_input_embeddings(self, value):
971
613
  self.embeddings.tok_embeddings = value
972
614
 
615
+ @check_model_inputs
973
616
  @auto_docstring
974
617
  def forward(
975
618
  self,
976
- input_ids: Optional[torch.LongTensor] = None,
977
- attention_mask: Optional[torch.Tensor] = None,
978
- sliding_window_mask: Optional[torch.Tensor] = None,
979
- position_ids: Optional[torch.LongTensor] = None,
980
- inputs_embeds: Optional[torch.Tensor] = None,
981
- indices: Optional[torch.Tensor] = None,
982
- cu_seqlens: Optional[torch.Tensor] = None,
983
- max_seqlen: Optional[int] = None,
984
- batch_size: Optional[int] = None,
985
- seq_len: Optional[int] = None,
986
- output_attentions: Optional[bool] = None,
987
- output_hidden_states: Optional[bool] = None,
988
- return_dict: Optional[bool] = None,
989
- **kwargs,
990
- ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
991
- r"""
992
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
993
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
994
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
995
- far-away tokens in the local attention layers when not using Flash Attention.
996
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
997
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
998
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
999
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1000
- max_seqlen (`int`, *optional*):
1001
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1002
- batch_size (`int`, *optional*):
1003
- Batch size of the input sequences. Used to pad the output tensors.
1004
- seq_len (`int`, *optional*):
1005
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1006
- """
1007
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1008
- output_hidden_states = (
1009
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1010
- )
1011
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1012
-
619
+ input_ids: torch.LongTensor | None = None,
620
+ attention_mask: torch.Tensor | None = None,
621
+ position_ids: torch.LongTensor | None = None,
622
+ inputs_embeds: torch.Tensor | None = None,
623
+ **kwargs: Unpack[TransformersKwargs],
624
+ ) -> BaseModelOutput:
1013
625
  if (input_ids is None) ^ (inputs_embeds is not None):
1014
626
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1015
627
 
1016
- all_hidden_states = () if output_hidden_states else None
1017
- all_self_attentions = () if output_attentions else None
1018
-
1019
- self._maybe_set_compile()
1020
-
1021
- if input_ids is not None:
1022
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1023
-
1024
- if batch_size is None and seq_len is None:
1025
- if inputs_embeds is not None:
1026
- batch_size, seq_len = inputs_embeds.shape[:2]
1027
- else:
1028
- batch_size, seq_len = input_ids.shape[:2]
628
+ seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
1029
629
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1030
630
 
1031
- if attention_mask is None:
1032
- attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1033
-
1034
- repad = False
1035
- if self.config._attn_implementation == "flash_attention_2":
1036
- if indices is None and cu_seqlens is None and max_seqlen is None:
1037
- repad = True
1038
- if inputs_embeds is None:
1039
- with torch.no_grad():
1040
- input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
1041
- inputs=input_ids, attention_mask=attention_mask
1042
- )
1043
- else:
1044
- inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
1045
- inputs=inputs_embeds, attention_mask=attention_mask
1046
- )
1047
- if position_ids is None:
1048
- position_ids = indices.unsqueeze(0)
1049
- else:
1050
- if position_ids is None:
1051
- position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
1052
-
1053
- attention_mask, sliding_window_mask = self._update_attention_mask(
1054
- attention_mask, output_attentions=output_attentions
1055
- )
631
+ if position_ids is None:
632
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
1056
633
 
1057
634
  hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
635
+
636
+ if not isinstance(attention_mask_mapping := attention_mask, dict):
637
+ mask_kwargs = {
638
+ "config": self.config,
639
+ "input_embeds": hidden_states,
640
+ "attention_mask": attention_mask,
641
+ }
642
+ attention_mask_mapping = {
643
+ "full_attention": create_bidirectional_mask(**mask_kwargs),
644
+ "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
645
+ }
646
+
1058
647
  position_embeddings = {}
1059
648
  for layer_type in self.config.layer_types:
1060
649
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
1061
650
 
1062
651
  for encoder_layer in self.layers:
1063
- if output_hidden_states:
1064
- all_hidden_states = all_hidden_states + (hidden_states,)
1065
-
1066
- layer_outputs = encoder_layer(
652
+ hidden_states = encoder_layer(
1067
653
  hidden_states,
1068
- attention_mask=attention_mask,
1069
- sliding_window_mask=sliding_window_mask,
1070
- position_ids=position_ids,
1071
- cu_seqlens=cu_seqlens,
1072
- max_seqlen=max_seqlen,
654
+ attention_mask=attention_mask_mapping[encoder_layer.attention_type],
1073
655
  position_embeddings=position_embeddings[encoder_layer.attention_type],
1074
- output_attentions=output_attentions,
656
+ **kwargs,
1075
657
  )
1076
- hidden_states = layer_outputs[0]
1077
- if output_attentions and len(layer_outputs) > 1:
1078
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
1079
-
1080
- if output_hidden_states:
1081
- all_hidden_states = all_hidden_states + (hidden_states,)
1082
658
 
1083
659
  hidden_states = self.final_norm(hidden_states)
1084
660
 
1085
- if repad:
1086
- hidden_states = _pad_modernbert_output(
1087
- inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
1088
- )
1089
- if all_hidden_states is not None:
1090
- all_hidden_states = tuple(
1091
- _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
1092
- for hs in all_hidden_states
1093
- )
1094
- # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
1095
- # dimension missing
1096
- elif (
1097
- self.config._attn_implementation == "flash_attention_2"
1098
- and all_hidden_states is not None
1099
- and all_hidden_states[-1].dim() == 2
1100
- ):
1101
- hidden_states = hidden_states.unsqueeze(0)
1102
- all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
1103
-
1104
- if not return_dict:
1105
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
1106
- return BaseModelOutput(
1107
- last_hidden_state=hidden_states,
1108
- hidden_states=all_hidden_states,
1109
- attentions=all_self_attentions,
1110
- )
1111
-
1112
- def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
1113
- if output_attentions:
1114
- if self.config._attn_implementation == "sdpa":
1115
- logger.warning_once(
1116
- "Outputting attentions is only supported with the 'eager' attention implementation, "
1117
- 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
1118
- )
1119
- self.config._attn_implementation = "eager"
1120
- elif self.config._attn_implementation != "eager":
1121
- logger.warning_once(
1122
- "Outputting attentions is only supported with the eager attention implementation, "
1123
- f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
1124
- " Setting `output_attentions=False`."
1125
- )
1126
-
1127
- global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
1128
-
1129
- # Create position indices
1130
- rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
1131
- # Calculate distance between positions
1132
- distance = torch.abs(rows - rows.T)
1133
-
1134
- # Create sliding window mask (1 for positions within window, 0 outside)
1135
- window_mask = (
1136
- (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
1137
- )
1138
- # Combine with existing mask
1139
- sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
1140
-
1141
- return global_attention_mask, sliding_window_mask
661
+ return BaseModelOutput(last_hidden_state=hidden_states)
1142
662
 
1143
663
 
1144
664
  class ModernBertPredictionHead(nn.Module):
@@ -1180,84 +700,23 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
1180
700
  def set_output_embeddings(self, new_embeddings: nn.Linear):
1181
701
  self.decoder = new_embeddings
1182
702
 
1183
- @torch.compile(dynamic=True)
1184
- def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1185
- return self.decoder(self.head(output))
1186
-
703
+ @can_return_tuple
1187
704
  @auto_docstring
1188
705
  def forward(
1189
706
  self,
1190
- input_ids: Optional[torch.LongTensor] = None,
1191
- attention_mask: Optional[torch.Tensor] = None,
1192
- sliding_window_mask: Optional[torch.Tensor] = None,
1193
- position_ids: Optional[torch.Tensor] = None,
1194
- inputs_embeds: Optional[torch.Tensor] = None,
1195
- labels: Optional[torch.Tensor] = None,
1196
- indices: Optional[torch.Tensor] = None,
1197
- cu_seqlens: Optional[torch.Tensor] = None,
1198
- max_seqlen: Optional[int] = None,
1199
- batch_size: Optional[int] = None,
1200
- seq_len: Optional[int] = None,
1201
- output_attentions: Optional[bool] = None,
1202
- output_hidden_states: Optional[bool] = None,
1203
- return_dict: Optional[bool] = None,
1204
- **kwargs,
1205
- ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
1206
- r"""
1207
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1208
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
1209
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
1210
- far-away tokens in the local attention layers when not using Flash Attention.
1211
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
1212
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
1213
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
1214
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1215
- max_seqlen (`int`, *optional*):
1216
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1217
- batch_size (`int`, *optional*):
1218
- Batch size of the input sequences. Used to pad the output tensors.
1219
- seq_len (`int`, *optional*):
1220
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1221
- """
1222
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1223
- self._maybe_set_compile()
1224
-
1225
- if self.config._attn_implementation == "flash_attention_2":
1226
- if indices is None and cu_seqlens is None and max_seqlen is None:
1227
- if batch_size is None and seq_len is None:
1228
- if inputs_embeds is not None:
1229
- batch_size, seq_len = inputs_embeds.shape[:2]
1230
- else:
1231
- batch_size, seq_len = input_ids.shape[:2]
1232
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1233
-
1234
- if attention_mask is None:
1235
- attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1236
-
1237
- if inputs_embeds is None:
1238
- with torch.no_grad():
1239
- input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
1240
- inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1241
- )
1242
- else:
1243
- inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
1244
- inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1245
- )
1246
-
707
+ input_ids: torch.LongTensor | None = None,
708
+ attention_mask: torch.Tensor | None = None,
709
+ position_ids: torch.Tensor | None = None,
710
+ inputs_embeds: torch.Tensor | None = None,
711
+ labels: torch.Tensor | None = None,
712
+ **kwargs: Unpack[TransformersKwargs],
713
+ ) -> tuple[torch.Tensor] | MaskedLMOutput:
1247
714
  outputs = self.model(
1248
715
  input_ids=input_ids,
1249
716
  attention_mask=attention_mask,
1250
- sliding_window_mask=sliding_window_mask,
1251
717
  position_ids=position_ids,
1252
718
  inputs_embeds=inputs_embeds,
1253
- indices=indices,
1254
- cu_seqlens=cu_seqlens,
1255
- max_seqlen=max_seqlen,
1256
- batch_size=batch_size,
1257
- seq_len=seq_len,
1258
- output_attentions=output_attentions,
1259
- output_hidden_states=output_hidden_states,
1260
- return_dict=return_dict,
719
+ **kwargs,
1261
720
  )
1262
721
  last_hidden_state = outputs[0]
1263
722
 
@@ -1271,35 +730,12 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
1271
730
  last_hidden_state = last_hidden_state[mask_tokens]
1272
731
  labels = labels[mask_tokens]
1273
732
 
1274
- logits = (
1275
- self.compiled_head(last_hidden_state)
1276
- if self.config.reference_compile
1277
- else self.decoder(self.head(last_hidden_state))
1278
- )
733
+ logits = self.decoder(self.head(last_hidden_state))
1279
734
 
1280
735
  loss = None
1281
736
  if labels is not None:
1282
737
  loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
1283
738
 
1284
- if self.config._attn_implementation == "flash_attention_2":
1285
- # Logits padding
1286
- with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
1287
- logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
1288
- # Hidden states padding
1289
- if getattr(outputs, "hidden_states", None) is not None:
1290
- padded_hidden_states = []
1291
- for hs in outputs.hidden_states:
1292
- if hs.dim() == 3 and hs.shape[0] == 1:
1293
- hs = hs.squeeze(0)
1294
- padded_hidden_states.append(
1295
- _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
1296
- )
1297
- outputs.hidden_states = tuple(padded_hidden_states)
1298
-
1299
- if not return_dict:
1300
- output = (logits,)
1301
- return ((loss,) + output) if loss is not None else output
1302
-
1303
739
  return MaskedLMOutput(
1304
740
  loss=loss,
1305
741
  logits=logits,
@@ -1327,81 +763,39 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
1327
763
  # Initialize weights and apply final processing
1328
764
  self.post_init()
1329
765
 
766
+ @can_return_tuple
1330
767
  @auto_docstring
1331
768
  def forward(
1332
769
  self,
1333
- input_ids: Optional[torch.LongTensor] = None,
1334
- attention_mask: Optional[torch.Tensor] = None,
1335
- sliding_window_mask: Optional[torch.Tensor] = None,
1336
- position_ids: Optional[torch.Tensor] = None,
1337
- inputs_embeds: Optional[torch.Tensor] = None,
1338
- labels: Optional[torch.Tensor] = None,
1339
- indices: Optional[torch.Tensor] = None,
1340
- cu_seqlens: Optional[torch.Tensor] = None,
1341
- max_seqlen: Optional[int] = None,
1342
- batch_size: Optional[int] = None,
1343
- seq_len: Optional[int] = None,
1344
- output_attentions: Optional[bool] = None,
1345
- output_hidden_states: Optional[bool] = None,
1346
- return_dict: Optional[bool] = None,
1347
- **kwargs,
1348
- ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
770
+ input_ids: torch.LongTensor | None = None,
771
+ attention_mask: torch.Tensor | None = None,
772
+ position_ids: torch.Tensor | None = None,
773
+ inputs_embeds: torch.Tensor | None = None,
774
+ labels: torch.Tensor | None = None,
775
+ **kwargs: Unpack[TransformersKwargs],
776
+ ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
1349
777
  r"""
1350
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1351
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
1352
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
1353
- far-away tokens in the local attention layers when not using Flash Attention.
1354
778
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1355
779
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1356
780
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1357
781
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1358
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
1359
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
1360
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
1361
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1362
- max_seqlen (`int`, *optional*):
1363
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1364
- batch_size (`int`, *optional*):
1365
- Batch size of the input sequences. Used to pad the output tensors.
1366
- seq_len (`int`, *optional*):
1367
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1368
782
  """
1369
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1370
- self._maybe_set_compile()
1371
-
1372
- if input_ids is not None:
1373
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1374
-
1375
- if batch_size is None and seq_len is None:
1376
- if inputs_embeds is not None:
1377
- batch_size, seq_len = inputs_embeds.shape[:2]
1378
- else:
1379
- batch_size, seq_len = input_ids.shape[:2]
1380
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1381
-
1382
- if attention_mask is None:
1383
- attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1384
-
1385
783
  outputs = self.model(
1386
784
  input_ids=input_ids,
1387
785
  attention_mask=attention_mask,
1388
- sliding_window_mask=sliding_window_mask,
1389
786
  position_ids=position_ids,
1390
787
  inputs_embeds=inputs_embeds,
1391
- indices=indices,
1392
- cu_seqlens=cu_seqlens,
1393
- max_seqlen=max_seqlen,
1394
- batch_size=batch_size,
1395
- seq_len=seq_len,
1396
- output_attentions=output_attentions,
1397
- output_hidden_states=output_hidden_states,
1398
- return_dict=return_dict,
788
+ **kwargs,
1399
789
  )
1400
790
  last_hidden_state = outputs[0]
1401
791
 
1402
792
  if self.config.classifier_pooling == "cls":
1403
793
  last_hidden_state = last_hidden_state[:, 0]
1404
794
  elif self.config.classifier_pooling == "mean":
795
+ if attention_mask is None:
796
+ attention_mask = torch.ones(
797
+ last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
798
+ )
1405
799
  last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
1406
800
  dim=1, keepdim=True
1407
801
  )
@@ -1433,10 +827,6 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
1433
827
  loss_fct = BCEWithLogitsLoss()
1434
828
  loss = loss_fct(logits, labels)
1435
829
 
1436
- if not return_dict:
1437
- output = (logits,)
1438
- return ((loss,) + output) if loss is not None else output
1439
-
1440
830
  return SequenceClassifierOutput(
1441
831
  loss=loss,
1442
832
  logits=logits,
@@ -1463,60 +853,27 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
1463
853
  # Initialize weights and apply final processing
1464
854
  self.post_init()
1465
855
 
856
+ @can_return_tuple
1466
857
  @auto_docstring
1467
858
  def forward(
1468
859
  self,
1469
- input_ids: Optional[torch.LongTensor] = None,
1470
- attention_mask: Optional[torch.Tensor] = None,
1471
- sliding_window_mask: Optional[torch.Tensor] = None,
1472
- position_ids: Optional[torch.Tensor] = None,
1473
- inputs_embeds: Optional[torch.Tensor] = None,
1474
- labels: Optional[torch.Tensor] = None,
1475
- indices: Optional[torch.Tensor] = None,
1476
- cu_seqlens: Optional[torch.Tensor] = None,
1477
- max_seqlen: Optional[int] = None,
1478
- batch_size: Optional[int] = None,
1479
- seq_len: Optional[int] = None,
1480
- output_attentions: Optional[bool] = None,
1481
- output_hidden_states: Optional[bool] = None,
1482
- return_dict: Optional[bool] = None,
1483
- **kwargs,
1484
- ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
860
+ input_ids: torch.LongTensor | None = None,
861
+ attention_mask: torch.Tensor | None = None,
862
+ position_ids: torch.Tensor | None = None,
863
+ inputs_embeds: torch.Tensor | None = None,
864
+ labels: torch.Tensor | None = None,
865
+ **kwargs: Unpack[TransformersKwargs],
866
+ ) -> tuple[torch.Tensor] | TokenClassifierOutput:
1485
867
  r"""
1486
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1487
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
1488
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
1489
- far-away tokens in the local attention layers when not using Flash Attention.
1490
868
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1491
869
  Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1492
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
1493
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
1494
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
1495
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1496
- max_seqlen (`int`, *optional*):
1497
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1498
- batch_size (`int`, *optional*):
1499
- Batch size of the input sequences. Used to pad the output tensors.
1500
- seq_len (`int`, *optional*):
1501
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1502
870
  """
1503
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1504
- self._maybe_set_compile()
1505
-
1506
871
  outputs = self.model(
1507
872
  input_ids=input_ids,
1508
873
  attention_mask=attention_mask,
1509
- sliding_window_mask=sliding_window_mask,
1510
874
  position_ids=position_ids,
1511
875
  inputs_embeds=inputs_embeds,
1512
- indices=indices,
1513
- cu_seqlens=cu_seqlens,
1514
- max_seqlen=max_seqlen,
1515
- batch_size=batch_size,
1516
- seq_len=seq_len,
1517
- output_attentions=output_attentions,
1518
- output_hidden_states=output_hidden_states,
1519
- return_dict=return_dict,
876
+ **kwargs,
1520
877
  )
1521
878
  last_hidden_state = outputs[0]
1522
879
 
@@ -1529,10 +886,6 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
1529
886
  loss_fct = CrossEntropyLoss()
1530
887
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1531
888
 
1532
- if not return_dict:
1533
- output = (logits,) + outputs[1:]
1534
- return ((loss,) + output) if loss is not None else output
1535
-
1536
889
  return TokenClassifierOutput(
1537
890
  loss=loss,
1538
891
  logits=logits,
@@ -1554,57 +907,22 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
1554
907
 
1555
908
  self.post_init()
1556
909
 
910
+ @can_return_tuple
1557
911
  @auto_docstring
1558
912
  def forward(
1559
913
  self,
1560
- input_ids: Optional[torch.Tensor],
1561
- attention_mask: Optional[torch.Tensor] = None,
1562
- sliding_window_mask: Optional[torch.Tensor] = None,
1563
- position_ids: Optional[torch.Tensor] = None,
1564
- start_positions: Optional[torch.Tensor] = None,
1565
- end_positions: Optional[torch.Tensor] = None,
1566
- indices: Optional[torch.Tensor] = None,
1567
- cu_seqlens: Optional[torch.Tensor] = None,
1568
- max_seqlen: Optional[int] = None,
1569
- batch_size: Optional[int] = None,
1570
- seq_len: Optional[int] = None,
1571
- output_attentions: Optional[bool] = None,
1572
- output_hidden_states: Optional[bool] = None,
1573
- return_dict: Optional[bool] = None,
1574
- **kwargs,
1575
- ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1576
- r"""
1577
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1578
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
1579
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
1580
- far-away tokens in the local attention layers when not using Flash Attention.
1581
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
1582
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
1583
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
1584
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1585
- max_seqlen (`int`, *optional*):
1586
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1587
- batch_size (`int`, *optional*):
1588
- Batch size of the input sequences. Used to pad the output tensors.
1589
- seq_len (`int`, *optional*):
1590
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1591
- """
1592
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1593
- self._maybe_set_compile()
1594
-
914
+ input_ids: torch.Tensor | None = None,
915
+ attention_mask: torch.Tensor | None = None,
916
+ position_ids: torch.Tensor | None = None,
917
+ start_positions: torch.Tensor | None = None,
918
+ end_positions: torch.Tensor | None = None,
919
+ **kwargs: Unpack[TransformersKwargs],
920
+ ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
1595
921
  outputs = self.model(
1596
922
  input_ids,
1597
923
  attention_mask=attention_mask,
1598
- sliding_window_mask=sliding_window_mask,
1599
924
  position_ids=position_ids,
1600
- indices=indices,
1601
- cu_seqlens=cu_seqlens,
1602
- max_seqlen=max_seqlen,
1603
- batch_size=batch_size,
1604
- seq_len=seq_len,
1605
- output_attentions=output_attentions,
1606
- output_hidden_states=output_hidden_states,
1607
- return_dict=return_dict,
925
+ **kwargs,
1608
926
  )
1609
927
  last_hidden_state = outputs[0]
1610
928
 
@@ -1620,10 +938,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
1620
938
  if start_positions is not None and end_positions is not None:
1621
939
  loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1622
940
 
1623
- if not return_dict:
1624
- output = (start_logits, end_logits) + outputs[1:]
1625
- return ((loss,) + output) if loss is not None else output
1626
-
1627
941
  return QuestionAnsweringModelOutput(
1628
942
  loss=loss,
1629
943
  start_logits=start_logits,
@@ -1651,45 +965,22 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
1651
965
  # Initialize weights and apply final processing
1652
966
  self.post_init()
1653
967
 
968
+ @can_return_tuple
1654
969
  @auto_docstring
1655
970
  def forward(
1656
971
  self,
1657
- input_ids: Optional[torch.LongTensor] = None,
1658
- attention_mask: Optional[torch.Tensor] = None,
1659
- sliding_window_mask: Optional[torch.Tensor] = None,
1660
- position_ids: Optional[torch.Tensor] = None,
1661
- inputs_embeds: Optional[torch.Tensor] = None,
1662
- labels: Optional[torch.Tensor] = None,
1663
- indices: Optional[torch.Tensor] = None,
1664
- cu_seqlens: Optional[torch.Tensor] = None,
1665
- max_seqlen: Optional[int] = None,
1666
- batch_size: Optional[int] = None,
1667
- seq_len: Optional[int] = None,
1668
- output_attentions: Optional[bool] = None,
1669
- output_hidden_states: Optional[bool] = None,
1670
- return_dict: Optional[bool] = None,
1671
- **kwargs,
1672
- ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
972
+ input_ids: torch.LongTensor | None = None,
973
+ attention_mask: torch.Tensor | None = None,
974
+ position_ids: torch.Tensor | None = None,
975
+ inputs_embeds: torch.Tensor | None = None,
976
+ labels: torch.Tensor | None = None,
977
+ **kwargs: Unpack[TransformersKwargs],
978
+ ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
1673
979
  r"""
1674
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1675
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
1676
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
1677
- far-away tokens in the local attention layers when not using Flash Attention.
1678
980
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1679
981
  Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1680
982
  num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
1681
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
1682
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
1683
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
1684
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1685
- max_seqlen (`int`, *optional*):
1686
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1687
- batch_size (`int`, *optional*):
1688
- Batch size of the input sequences. Used to pad the output tensors.
1689
- seq_len (`int`, *optional*):
1690
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1691
983
  """
1692
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1693
984
  num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1694
985
 
1695
986
  input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -1701,22 +992,12 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
1701
992
  else None
1702
993
  )
1703
994
 
1704
- self._maybe_set_compile()
1705
-
1706
995
  outputs = self.model(
1707
996
  input_ids=input_ids,
1708
997
  attention_mask=attention_mask,
1709
- sliding_window_mask=sliding_window_mask,
1710
998
  position_ids=position_ids,
1711
999
  inputs_embeds=inputs_embeds,
1712
- indices=indices,
1713
- cu_seqlens=cu_seqlens,
1714
- max_seqlen=max_seqlen,
1715
- batch_size=batch_size,
1716
- seq_len=seq_len,
1717
- output_attentions=output_attentions,
1718
- output_hidden_states=output_hidden_states,
1719
- return_dict=return_dict,
1000
+ **kwargs,
1720
1001
  )
1721
1002
  last_hidden_state = outputs[0] # shape (num_choices, seq_len, hidden_size)
1722
1003
 
@@ -1748,10 +1029,6 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
1748
1029
  loss_fct = nn.CrossEntropyLoss()
1749
1030
  loss = loss_fct(reshaped_logits, labels)
1750
1031
 
1751
- if not return_dict:
1752
- output = (reshaped_logits,) + outputs[1:]
1753
- return ((loss,) + output) if loss is not None else output
1754
-
1755
1032
  return MultipleChoiceModelOutput(
1756
1033
  loss=loss,
1757
1034
  logits=reshaped_logits,