transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
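For reference, a file-level comparison like the listing below can be reproduced locally, since wheels are ordinary zip archives. A minimal sketch follows, assuming both wheels have already been downloaded (e.g. via pip download); the filenames follow the standard wheel naming convention and are assumptions, not artifacts of this page:

    # Sketch: compare the file lists of two downloaded wheels using only the stdlib.
    import zipfile

    OLD = "transformers-5.0.0rc2-py3-none-any.whl"  # assumed local path
    NEW = "transformers-5.1.0-py3-none-any.whl"     # assumed local path

    with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
        old_files = set(old_whl.namelist())
        new_files = set(new_whl.namelist())

    print("added:  ", len(new_files - old_files), "files")
    print("removed:", len(old_files - new_files), "files")
    print("present in both (changed or identical):", len(old_files & new_files))

Per-file "+added -removed" line counts like those shown below would additionally require extracting each common member and diffing its text (e.g. with difflib.unified_diff).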
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -21,18 +21,16 @@
 
 import math
 from collections.abc import Callable
-from contextlib import nullcontext
-from typing import Optional, Union
+from typing import Optional
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ... import initialization as init
 from ...activations import ACT2FN
-from ...integrations import use_kernel_func_from_hub
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
+from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -43,158 +41,13 @@ from ...modeling_outputs import (
43
41
  TokenClassifierOutput,
44
42
  )
45
43
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
46
- from ...modeling_utils import PreTrainedModel
47
- from ...utils import auto_docstring, is_flash_attn_2_available, logging
48
- from ...utils.generic import maybe_autocast
49
- from ...utils.import_utils import is_triton_available
44
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
+ from ...processing_utils import Unpack
46
+ from ...utils import TransformersKwargs, auto_docstring
47
+ from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
50
48
  from .configuration_modernbert import ModernBertConfig
51
49
 
52
50
 
53
- if is_flash_attn_2_available():
54
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
55
- from flash_attn.layers.rotary import RotaryEmbedding
56
- from flash_attn.ops.triton.rotary import apply_rotary
57
- else:
58
- RotaryEmbedding = object
59
-
60
-
61
- logger = logging.get_logger(__name__)
62
-
63
-
64
- class ApplyRotaryEmbUnpad(torch.autograd.Function):
65
- @staticmethod
66
- def forward(
67
- ctx,
68
- qkv,
69
- cos,
70
- sin,
71
- cu_seqlens: Optional[torch.Tensor] = None,
72
- max_seqlen: Optional[int] = None,
73
- ):
74
- # (total_nnz, 3, nheads, headdim)
75
-        qkv = qkv.contiguous()
-        total_nnz, _three, _nheads, headdim = qkv.shape
-        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
-        qk = qkv[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            qk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            interleaved=False,
-            inplace=True,
-        )
-
-        ctx.save_for_backward(cos, sin, cu_seqlens)
-        ctx.max_seqlen = max_seqlen
-        return qkv
-
-    @staticmethod
-    def backward(ctx, do):
-        cos, sin, cu_seqlens = ctx.saved_tensors
-        do = do.contiguous()
-        total_nnz, _three, _nheads, headdim = do.shape
-        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        dqk = do[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            dqk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=ctx.max_seqlen,
-            interleaved=False,
-            inplace=True,
-            conjugate=True,
-        )
-
-        return do, None, None, None, None, None, None
-
-
-def apply_rotary_unpadded(
-    qkv,
-    cos,
-    sin,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[int] = None,
-):
-    """
-    Arguments:
-        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
-        cos, sin: (seqlen_rotary, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        inplace: if True, apply rotary embedding in-place.
-        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
-            Most commonly used in inference when we have KV cache.
-        cu_seqlens: (batch + 1,) or None
-        max_seqlen: int
-    Return:
-        out: (total_nnz, dim)
-    rotary_dim must be <= headdim
-    Apply rotary embedding to the first rotary_dim of x.
-    """
-    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
-
-
-class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
-    """
-    The rotary position embeddings applied directly to unpadded sequences.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        base: float = 10000.0,
-        max_seqlen: Optional[int] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        """
-        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
-        up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
-        the cos_sin_cache will be recomputed during the forward pass.
-        """
-        super().__init__(dim=dim, base=base, device=device, interleaved=False)
-        self.max_seqlen = max_seqlen
-
-        if max_seqlen is not None and device is not None and dtype is not None:
-            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
-
-    def forward(
-        self,
-        qkv: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_seqlen: Optional[int] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Apply rotary embedding *inplace* to qkv.
-        qkv: (total_nnz, 3, nheads, headdim)
-        cu_seqlens: (batch + 1,) cumulative sequence lengths
-        max_seqlen: int max seq length in the batch
-        """
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-
-        qkv = apply_rotary_unpadded(
-            qkv,
-            self._cos_cached,
-            self._sin_cached,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-        )
-
-        return qkv
-
-    def extra_repr(self) -> str:
-        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
-
-
 class ModernBertEmbeddings(nn.Module):
     """
     Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
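Note: 5.1.0 removes the hand-rolled Flash-Attention unpadded RoPE path above; the rotation math itself is unchanged. As a minimal, hedged sketch (the names `rotate_half` and `rope` here are illustrative, not library API), the NeoX-style half-rotation both paths compute is:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension into halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., seq_len, head_dim); cos/sin broadcast over the leading dims
    return (x * cos) + (rotate_half(x) * sin)
```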
@@ -207,21 +60,13 @@ class ModernBertEmbeddings(nn.Module):
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.drop = nn.Dropout(config.embedding_dropout)
 
-    @torch.compile(dynamic=True)
-    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
-        return self.drop(self.norm(self.tok_embeddings(input_ids)))
-
     def forward(
-        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+        self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
     ) -> torch.Tensor:
         if inputs_embeds is not None:
             hidden_states = self.drop(self.norm(inputs_embeds))
         else:
-            hidden_states = (
-                self.compiled_embeddings(input_ids)
-                if self.config.reference_compile
-                else self.drop(self.norm(self.tok_embeddings(input_ids)))
-            )
+            hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
         return hidden_states
 
 
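Note: the `reference_compile` branch is gone from the embeddings; compilation is now left to the caller. A minimal sketch, assuming the public ModernBERT-base checkpoint name:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
model = torch.compile(model)  # opt-in compilation of the whole model, outside the modeling code
```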
@@ -273,10 +118,10 @@ class ModernBertRotaryEmbedding(nn.Module):
 
     @staticmethod
     def compute_default_rope_parameters(
-        config: Optional[ModernBertConfig] = None,
+        config: ModernBertConfig | None = None,
         device: Optional["torch.device"] = None,
-        seq_len: Optional[int] = None,
-        layer_type: Optional[str] = None,
+        seq_len: int | None = None,
+        layer_type: str | None = None,
     ) -> tuple["torch.Tensor", float]:
         """
         Computes the inverse frequencies according to the original RoPE implementation
@@ -326,6 +171,29 @@ class ModernBertRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
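Note: `eager_attention_forward` now follows the shared `(module, q, k, v, mask, ...)` attention interface. A shape walkthrough, assuming the function added above is in scope and using made-up sizes:

```python
import torch
from torch import nn

module = nn.Module()  # only `.training` is read (for dropout)
q = torch.randn(2, 12, 8, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 12, 8, 64)
v = torch.randn(2, 12, 8, 64)

out, weights = eager_attention_forward(module, q, k, v, None, scaling=64**-0.5)
print(out.shape)      # torch.Size([2, 8, 12, 64]) -- transposed back to (batch, seq, heads, dim)
print(weights.shape)  # torch.Size([2, 12, 8, 8])
```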
@@ -334,7 +202,7 @@ def rotate_half(x):
 
 
 @use_kernel_func_from_hub("rotary_pos_emb")
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -342,8 +210,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -354,137 +220,15 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    original_dtype = q.dtype
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def eager_attention_forward(
-    module: "ModernBertAttention",
-    qkv: torch.Tensor,
-    attention_mask: torch.Tensor,
-    sliding_window_mask: torch.Tensor,
-    position_ids: Optional[torch.LongTensor],
-    local_attention: tuple[int, int],
-    bs: int,
-    dim: int,
-    position_embeddings: torch.Tensor,
-    output_attentions: Optional[bool] = False,
-    **_kwargs,
-) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-    # qkv: [batch_size, seqlen, 3, nheads, headdim]
-    cos, sin = position_embeddings
-    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
-    # query, key, value: [batch_size, heads, seq_len, head_dim]
-    query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
-    scale = module.head_dim**-0.5
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
-
-    if local_attention != (-1, -1):
-        attention_mask = sliding_window_mask
-
-    attn_weights = attn_weights + attention_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    attn_output = attn_output.view(bs, -1, dim)
-    if output_attentions:
-        return (attn_output, attn_weights)
-    return (attn_output,)
-
-
-def flash_attention_forward(
-    module: "ModernBertAttention",
-    qkv: torch.Tensor,
-    rotary_emb: ModernBertUnpaddedRotaryEmbedding,
-    cu_seqlens: torch.Tensor,
-    max_seqlen: int,
-    local_attention: tuple[int, int],
-    bs: int,
-    dim: int,
-    target_dtype: torch.dtype = torch.bfloat16,
-    **_kwargs,
-) -> tuple[torch.Tensor]:
-    # (total_seqlen, 3, nheads, headdim)
-    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
-
-    convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
-    if convert_dtype:
-        # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
-        # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
-        orig_dtype = qkv.dtype
-        qkv = qkv.to(target_dtype)
-
-        attn = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            dropout_p=module.attention_dropout if module.training else 0.0,
-            deterministic=module.deterministic_flash_attn,
-            window_size=local_attention,
-        )
-        attn = attn.to(orig_dtype)  # type: ignore
-    else:
-        attn = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            dropout_p=module.attention_dropout if module.training else 0.0,
-            deterministic=module.deterministic_flash_attn,
-            window_size=local_attention,
-        )
-    return (attn.view(bs, dim),)
-
-
-def sdpa_attention_forward(
-    module: "ModernBertAttention",
-    qkv: torch.Tensor,
-    attention_mask: torch.Tensor,
-    sliding_window_mask: torch.Tensor,
-    position_ids: Optional[torch.LongTensor],
-    local_attention: tuple[int, int],
-    bs: int,
-    dim: int,
-    position_embeddings: torch.Tensor,
-    **_kwargs,
-) -> tuple[torch.Tensor]:
-    # qkv: [batch_size, seqlen, 3, nheads, headdim]
-    cos, sin = position_embeddings
-    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
-    # query, key, value: [batch_size, heads, seq_len, head_dim]
-    query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
-    if local_attention != (-1, -1):
-        attention_mask = sliding_window_mask
-
-    attn_output = (
-        F.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            dropout_p=module.attention_dropout if module.training else 0.0,
-            attn_mask=attention_mask,
-        )
-        .transpose(1, 2)
-        .contiguous()
-    )
-    attn_output = attn_output.view(bs, -1, dim)
-    return (attn_output,)
-
-
-MODERNBERT_ATTENTION_FUNCTION = {
-    "flash_attention_2": flash_attention_forward,
-    "eager": eager_attention_forward,
-    "sdpa": sdpa_attention_forward,
-}
+    q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
+    k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
+    return q_embed.to(original_dtype), k_embed.to(original_dtype)
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class ModernBertAttention(nn.Module):
     """Performs multi-headed self attention on a batch of unpadded sequences.
 
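Note: `apply_rotary_pos_emb` now performs the rotation in float32 and casts back, which avoids accumulating bf16 rounding error. An illustrative comparison (not from the diff) of the two orders of operation:

```python
import torch

q = torch.randn(1024, dtype=torch.bfloat16)
cos, sin = torch.randn(1024), torch.randn(1024)

low = q * cos.bfloat16() + q * sin.bfloat16()                  # all-bf16 math
high = (q.float() * cos + q.float() * sin).to(torch.bfloat16)  # fp32 math, one final cast
print((low.float() - high.float()).abs().max())  # typically small but nonzero drift
```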
@@ -495,10 +239,10 @@ class ModernBertAttention(nn.Module):
     See `forward` method for additional details.
     """
 
-    def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
+    def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
-        self.layer_id = layer_id
+        self.layer_idx = layer_idx
 
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -507,29 +251,19 @@ class ModernBertAttention(nn.Module):
 
         self.attention_dropout = config.attention_dropout
         self.deterministic_flash_attn = config.deterministic_flash_attn
-        self.num_heads = config.num_attention_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
-        self.all_head_size = self.head_dim * self.num_heads
-        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
-        layer_type = config.layer_types[layer_id]
+        self.Wqkv = nn.Linear(
+            config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
+        )
 
-        if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2, config.local_attention // 2)
-            max_position_embeddings = config.local_attention
+        if config.layer_types[layer_idx] == "sliding_attention":
+            # config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
+            # +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
+            self.sliding_window = config.sliding_window + 1
         else:
-            self.local_attention = (-1, -1)
-            max_position_embeddings = config.max_position_embeddings
+            self.sliding_window = None
 
-        if config._attn_implementation == "flash_attention_2":
-            rope_parameters_dict = (
-                self.config.rope_parameters[layer_type] if layer_type is not None else self.config.rope_parameters
-            )
-            rope_theta = rope_parameters_dict["rope_theta"]
-            self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
-                dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
-            )
-        else:
-            self.rotary_emb = None
+        self.is_causal = False
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
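Note on the window arithmetic above, restated with ModernBERT's default values for illustration:

```python
local_attention = 128                  # full local window, in tokens
sliding_window = local_attention // 2  # 64: the half-window stored on the config
effective_window = sliding_window + 1  # 65: flash attention treats the boundary as inclusive
```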
@@ -537,82 +271,75 @@ class ModernBertAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
-        **kwargs,
-    ) -> torch.Tensor:
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+        attention_mask: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        input_shape = hidden_states.shape[:-1]
+
         qkv = self.Wqkv(hidden_states)
+        qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
+        query_states, key_states, value_states = qkv.unbind(dim=-3)
 
-        bs = hidden_states.shape[0]
-        if self.config._attn_implementation == "flash_attention_2":
-            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
-        else:
-            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
+
+        attention_interface = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
+        attn_output, attn_weights = attention_interface(
             self,
-            qkv=qkv,
-            rotary_emb=self.rotary_emb,
-            local_attention=self.local_attention,
-            bs=bs,
-            dim=self.all_head_size,
-            position_embeddings=position_embeddings,
-            output_attentions=output_attentions,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.head_dim**-0.5,
+            sliding_window=self.sliding_window,
+            deterministic=self.deterministic_flash_attn,
             **kwargs,
         )
-        hidden_states = attn_outputs[0]
-        hidden_states = self.out_drop(self.Wo(hidden_states))
 
-        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.out_drop(self.Wo(attn_output))
+        return attn_output, attn_weights
 
 
 class ModernBertEncoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
+    def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
-        if layer_id == 0:
+        self.layer_idx = layer_idx
+        if layer_idx == 0:
             self.attn_norm = nn.Identity()
         else:
             self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
+        self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
         self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.mlp = ModernBertMLP(config)
-        self.attention_type = config.layer_types[layer_id]
-
-    @torch.compile(dynamic=True)
-    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return self.mlp(self.mlp_norm(hidden_states))
+        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        sliding_window_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> torch.Tensor:
-        attn_outputs = self.attn(
+        attn_output, _ = self.attn(
             self.attn_norm(hidden_states),
-            attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
             position_embeddings=position_embeddings,
-            output_attentions=output_attentions,
-        )
-        hidden_states = hidden_states + attn_outputs[0]
-        mlp_output = (
-            self.compiled_mlp(hidden_states)
-            if self.config.reference_compile
-            else self.mlp(self.mlp_norm(hidden_states))
+            attention_mask=attention_mask,
+            **kwargs,
         )
-        hidden_states = hidden_states + mlp_output
-
-        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted
+        hidden_states = hidden_states + attn_output
+        hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
+        return hidden_states
 
 
 @auto_docstring
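Note: backend selection now goes through the stock `attn_implementation` mechanism rather than a model-specific dispatch table. A hedged usage sketch (checkpoint name assumed):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "answerdotai/ModernBERT-base",
    attn_implementation="sdpa",  # or "eager", "flash_attention_2", "flex_attention"
)
```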
@@ -623,7 +350,13 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_flex_attn = False
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+    _can_record_outputs = {
+        "hidden_states": ModernBertEncoderLayer,
+        "attentions": ModernBertAttention,
+    }
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
@@ -685,147 +418,24 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                 init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
                 init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
-        elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
-            inv_freq = module._compute_inv_freq()
-            init.copy_(module.inv_freq, inv_freq)
 
     def _check_and_adjust_attn_implementation(
-        self, attn_implementation: Optional[str], is_init_check: bool = False
+        self, attn_implementation: str | None, is_init_check: bool = False
     ) -> str:
         """
         Checks and dispatches to the requested attention implementation.
         """
-        # If the user didn't specify anything, try to use flash_attention_2 if available.
+        # If the user didn't specify anything, try to use flash_attention_2.
         # Otherwise we fall back to the default SDPA -> Eager from the super() method.
-        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
-        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
-
         try:
-            attn_implementation = (
-                "flash_attention_2"
-                if attn_implementation is None and self._flash_attn_2_can_dispatch()
-                else attn_implementation
+            requested_attn_implementation = "flash_attention_2" if attn_implementation is None else attn_implementation
+            return super()._check_and_adjust_attn_implementation(
+                attn_implementation=requested_attn_implementation, is_init_check=is_init_check
             )
         except (ValueError, ImportError):
-            pass
-        return super()._check_and_adjust_attn_implementation(
-            attn_implementation=attn_implementation, is_init_check=is_init_check
-        )
-
-    def _maybe_set_compile(self):
-        if self.config.reference_compile is False:
-            return
-
-        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "If `accelerate` split the model across devices, `torch.compile` will not work. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.device.type == "mps":
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.device.type == "cpu":
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
-                    "Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        if self.config.reference_compile is None:
-            self.config.reference_compile = is_triton_available()
-
-    def resize_token_embeddings(self, *args, **kwargs):
-        model_embeds = super().resize_token_embeddings(*args, **kwargs)
-
-        if self.config.reference_compile in {True, None}:
-            if self.config.reference_compile:
-                logger.warning_once(
-                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
-                )
-            self.config.reference_compile = False
-
-        return model_embeds
-
-
-def _unpad_modernbert_input(
-    inputs: torch.Tensor,
-    attention_mask: torch.Tensor,
-    position_ids: Optional[torch.Tensor] = None,
-    labels: Optional[torch.Tensor] = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
-    """
-    Remove padding from input sequences.
-
-    Args:
-        inputs: (batch, seqlen, ...) or (batch, seqlen)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        position_ids: (batch, seqlen), int, position ids
-        labels: (batch, seqlen), int, labels
-
-    Returns:
-        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz)
-        cu_seqlens: (batch + 1), the cumulative sequence lengths
-        max_seqlen_in_batch: int
-        unpadded_position_ids: (total_nnz) or None
-        unpadded_labels: (total_nnz) or None
-    """
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
-    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
-    if inputs.dim() == 2:
-        unpadded_inputs = inputs.flatten()[indices]
-    else:
-        batch, seqlen, *rest = inputs.shape
-        shape = batch * seqlen
-        unpadded_inputs = inputs.view(shape, *rest)[indices]
-
-    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
-    unpadded_labels = labels.flatten()[indices] if labels is not None else None
-
-    return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
-
-
-def _pad_modernbert_output(
-    inputs: torch.Tensor,
-    indices: torch.Tensor,
-    batch: int,
-    seqlen: int,
-) -> torch.Tensor:
-    """
-    Add padding to sequences.
-
-    Args:
-        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz)
-        batch: int, batch size
-        seqlen: int, max sequence length
-
-    Returns:
-        padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
-    """
-    if inputs.dim() == 1:
-        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen)
-    else:
-        _, *rest = inputs.shape
-        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen, *rest)
-
-    return padded_inputs
+            return super()._check_and_adjust_attn_implementation(
+                attn_implementation=attn_implementation, is_init_check=is_init_check
+            )
 
 
 @auto_docstring
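Note: `_unpad_modernbert_input`/`_pad_modernbert_output` are removed; padding is now handled by the shared attention-mask machinery. For reference, the round trip they implemented was a gather/scatter over non-padding indices — a minimal self-contained sketch:

```python
import torch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)  # (batch, seq)
x = torch.arange(8).view(2, 4)

indices = mask.flatten().nonzero().flatten()
unpadded = x.flatten()[indices]  # "unpad": keep only real tokens

out = torch.zeros(8, dtype=x.dtype)
out[indices] = unpadded          # "repad": scatter back into place
out = out.view(2, 4)
```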
@@ -835,7 +445,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
         self.config = config
         self.embeddings = ModernBertEmbeddings(config)
         self.layers = nn.ModuleList(
-            [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
+            [ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.rotary_emb = ModernBertRotaryEmbedding(config=config)
@@ -848,175 +458,53 @@ class ModernBertModel(ModernBertPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.tok_embeddings = value
 
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        sliding_window_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        indices: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        batch_size: Optional[int] = None,
-        seq_len: Optional[int] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attentions = () if output_attentions else None
-
-        self._maybe_set_compile()
-
-        if input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
-        if batch_size is None and seq_len is None:
-            if inputs_embeds is not None:
-                batch_size, seq_len = inputs_embeds.shape[:2]
-            else:
-                batch_size, seq_len = input_ids.shape[:2]
+        seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
-        repad = False
-        if self.config._attn_implementation == "flash_attention_2":
-            if indices is None and cu_seqlens is None and max_seqlen is None:
-                repad = True
-                if inputs_embeds is None:
-                    with torch.no_grad():
-                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
-                            inputs=input_ids, attention_mask=attention_mask
-                        )
-                else:
-                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
-                        inputs=inputs_embeds, attention_mask=attention_mask
-                    )
-            if position_ids is None:
-                position_ids = indices.unsqueeze(0)
-        else:
-            if position_ids is None:
-                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
-
-        attention_mask, sliding_window_mask = self._update_attention_mask(
-            attention_mask, output_attentions=output_attentions
-        )
+        if position_ids is None:
+            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
 
         hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
+
+        if not isinstance(attention_mask_mapping := attention_mask, dict):
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": hidden_states,
+                "attention_mask": attention_mask,
+            }
+            attention_mask_mapping = {
+                "full_attention": create_bidirectional_mask(**mask_kwargs),
+                "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
+            }
+
         position_embeddings = {}
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
         for encoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            layer_outputs = encoder_layer(
+            hidden_states = encoder_layer(
                 hidden_states,
-                attention_mask=attention_mask,
-                sliding_window_mask=sliding_window_mask,
-                position_ids=position_ids,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
+                attention_mask=attention_mask_mapping[encoder_layer.attention_type],
                 position_embeddings=position_embeddings[encoder_layer.attention_type],
-                output_attentions=output_attentions,
+                **kwargs,
             )
-            hidden_states = layer_outputs[0]
-            if output_attentions and len(layer_outputs) > 1:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
 
         hidden_states = self.final_norm(hidden_states)
 
-        if repad:
-            hidden_states = _pad_modernbert_output(
-                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
-            )
-            if all_hidden_states is not None:
-                all_hidden_states = tuple(
-                    _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
-                    for hs in all_hidden_states
-                )
-        # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
-        # dimension missing
-        elif (
-            self.config._attn_implementation == "flash_attention_2"
-            and all_hidden_states is not None
-            and all_hidden_states[-1].dim() == 2
-        ):
-            hidden_states = hidden_states.unsqueeze(0)
-            all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
-
-    def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
-        if output_attentions:
-            if self.config._attn_implementation == "sdpa":
-                logger.warning_once(
-                    "Outputting attentions is only supported with the 'eager' attention implementation, "
-                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
-                )
-                self.config._attn_implementation = "eager"
-            elif self.config._attn_implementation != "eager":
-                logger.warning_once(
-                    "Outputting attentions is only supported with the eager attention implementation, "
-                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
-                    " Setting `output_attentions=False`."
-                )
-
-        global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
-
-        # Create position indices
-        rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
-        # Calculate distance between positions
-        distance = torch.abs(rows - rows.T)
-
-        # Create sliding window mask (1 for positions within window, 0 outside)
-        window_mask = (
-            (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
-        )
-        # Combine with existing mask
-        sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
-
-        return global_attention_mask, sliding_window_mask
+        return BaseModelOutput(last_hidden_state=hidden_states)
 
 
 class ModernBertPredictionHead(nn.Module):
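Note: each layer now receives a mask picked by its `attention_type` from the mapping built above. An illustrative stand-in (not the `masking_utils` helpers) for the two bidirectional mask shapes:

```python
import torch

seq_len, window = 6, 2  # window = half-window in each direction
padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.bool)  # (batch, seq)

full = padding_mask[:, None, None, :] & padding_mask[:, None, :, None]  # "full_attention"
pos = torch.arange(seq_len)
near = (pos[None, :] - pos[:, None]).abs() <= window
sliding = full & near                                                   # "sliding_attention"
```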
@@ -1058,84 +546,23 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
     def set_output_embeddings(self, new_embeddings: nn.Linear):
         self.decoder = new_embeddings
 
-    @torch.compile(dynamic=True)
-    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
-        return self.decoder(self.head(output))
-
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        sliding_window_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        indices: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        batch_size: Optional[int] = None,
-        seq_len: Optional[int] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
-        if self.config._attn_implementation == "flash_attention_2":
-            if indices is None and cu_seqlens is None and max_seqlen is None:
-                if batch_size is None and seq_len is None:
-                    if inputs_embeds is not None:
-                        batch_size, seq_len = inputs_embeds.shape[:2]
-                    else:
-                        batch_size, seq_len = input_ids.shape[:2]
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-                if attention_mask is None:
-                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
-                if inputs_embeds is None:
-                    with torch.no_grad():
-                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
-                        )
-                else:
-                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
-                    )
-
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        labels: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor] | MaskedLMOutput:
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
@@ -1149,35 +576,12 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
             last_hidden_state = last_hidden_state[mask_tokens]
             labels = labels[mask_tokens]
 
-        logits = (
-            self.compiled_head(last_hidden_state)
-            if self.config.reference_compile
-            else self.decoder(self.head(last_hidden_state))
-        )
+        logits = self.decoder(self.head(last_hidden_state))
 
         loss = None
         if labels is not None:
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
 
-        if self.config._attn_implementation == "flash_attention_2":
-            # Logits padding
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
-                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
-            # Hidden states padding
-            if getattr(outputs, "hidden_states", None) is not None:
-                padded_hidden_states = []
-                for hs in outputs.hidden_states:
-                    if hs.dim() == 3 and hs.shape[0] == 1:
-                        hs = hs.squeeze(0)
-                    padded_hidden_states.append(
-                        _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
-                    )
-                outputs.hidden_states = tuple(padded_hidden_states)
-
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return MaskedLMOutput(
             loss=loss,
             logits=logits,
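Note on the sparse-prediction path above: only positions whose label is not the ignore index go through the decoder head, which saves compute. A minimal sketch with made-up sizes:

```python
import torch

hidden = torch.randn(2, 4, 8)  # (batch, seq, hidden)
labels = torch.tensor([[-100, 5, -100, 7], [3, -100, -100, -100]])

mask_tokens = labels.view(-1) != -100   # -100 is the usual ignore index
flat = hidden.view(-1, 8)[mask_tokens]  # (num_masked, hidden) fed to the head
```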
@@ -1205,81 +609,39 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        sliding_window_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        indices: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        batch_size: Optional[int] = None,
-        seq_len: Optional[int] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        labels: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
-        if input_ids is not None:
-            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
-        if batch_size is None and seq_len is None:
-            if inputs_embeds is not None:
-                batch_size, seq_len = inputs_embeds.shape[:2]
-            else:
-                batch_size, seq_len = input_ids.shape[:2]
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
         if self.config.classifier_pooling == "cls":
             last_hidden_state = last_hidden_state[:, 0]
         elif self.config.classifier_pooling == "mean":
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
+                )
             last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
                 dim=1, keepdim=True
             )
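Note: `classifier_pooling="mean"` is a masked mean over real tokens, as above. A stand-alone illustration:

```python
import torch

hidden = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float)

pooled = (hidden * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)  # (2, 8)
```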
@@ -1311,10 +673,6 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
 
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
@@ -1341,60 +699,27 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        sliding_window_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        indices: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        batch_size: Optional[int] = None,
-        seq_len: Optional[int] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        labels: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
@@ -1407,10 +732,6 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
@@ -1432,57 +753,22 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
 
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
-        sliding_window_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        start_positions: Optional[torch.Tensor] = None,
-        end_positions: Optional[torch.Tensor] = None,
-        indices: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        batch_size: Optional[int] = None,
-        seq_len: Optional[int] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
-        r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self._maybe_set_compile()
-
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        start_positions: torch.Tensor | None = None,
+        end_positions: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]
 
@@ -1498,10 +784,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         if start_positions is not None and end_positions is not None:
             loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
 
-        if not return_dict:
-            output = (start_logits, end_logits) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return QuestionAnsweringModelOutput(
             loss=loss,
             start_logits=start_logits,
@@ -1529,45 +811,22 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        sliding_window_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        indices: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-        batch_size: Optional[int] = None,
-        seq_len: Optional[int] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        labels: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
         r"""
-        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
-            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers when not using Flash Attention.
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
-        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
-            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
-        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
-            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
-        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
-        batch_size (`int`, *optional*):
-            Batch size of the input sequences. Used to pad the output tensors.
-        seq_len (`int`, *optional*):
-            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -1579,22 +838,12 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             else None
         )
 
-        self._maybe_set_compile()
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            indices=indices,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            batch_size=batch_size,
-            seq_len=seq_len,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
         last_hidden_state = outputs[0]  # shape (num_choices, seq_len, hidden_size)
 
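Note on the choice flattening used above (illustrative sizes):

```python
import torch

batch, num_choices, seq = 2, 4, 16
input_ids = torch.zeros(batch, num_choices, seq, dtype=torch.long)

flat_ids = input_ids.view(-1, seq)              # (batch * num_choices, seq) into the encoder
logits = torch.randn(batch * num_choices, 1)    # one score per flattened choice
reshaped_logits = logits.view(-1, num_choices)  # back to (batch, num_choices)
```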
@@ -1626,10 +875,6 @@ class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
 
-        if not return_dict:
-            output = (reshaped_logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return MultipleChoiceModelOutput(
             loss=loss,
             logits=reshaped_logits,