transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/dia/modeling_dia.py
@@ -25,9 +25,10 @@ from typing import Optional, Union
 import torch
 from torch import nn
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -41,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils.generic import maybe_autocast
 from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
 from .generation_dia import DiaGenerationMixin
 
@@ -60,6 +62,12 @@ class DiaPreTrainedModel(PreTrainedModel):
     main_input_name = "input_ids"
     _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, DiaMultiChannelEmbedding):
+            offsets = torch.arange(self.config.num_channels, dtype=torch.long) * self.config.vocab_size
+            init.copy_(module.offsets, offsets)
+
 
 class DiaMultiChannelEmbedding(nn.Module):
     """In order to efficiently compute the audio embedding from the 9 different channels,
@@ -145,7 +153,7 @@ class DiaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
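Promoting `original_inv_freq` from a plain attribute to a non-persistent buffer changes two things: the tensor now follows `model.to(...)` device and dtype moves like any other buffer, while `persistent=False` keeps it out of the checkpoint, so saved state dicts are unchanged. A small self-contained sketch of that behavior:

import torch
from torch import nn

class RopeLike(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
        # Old style: a bare attribute is invisible to nn.Module machinery,
        # so .to() would leave it behind on the original device/dtype.
        self.plain_attr = inv_freq
        # New style: moves with the module, excluded from state_dict().
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = RopeLike().to(torch.float64)
print(m.original_inv_freq.dtype)  # torch.float64 - converted with the module
print(m.plain_attr.dtype)         # torch.float32 - untouched
print(list(m.state_dict()))       # [] - non-persistent buffers are not saved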
@@ -184,7 +192,7 @@ class DiaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
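`maybe_autocast` replaces bare `torch.autocast(..., enabled=False)` across every rotary-embedding forward in this release. Its body is not shown in the diff; judging by the call sites it is a drop-in wrapper that presumably degrades to a no-op on device types autocast cannot handle, instead of raising. A hedged sketch of such a wrapper (the real helper lives in `transformers.utils.generic` and may differ):

```python
import contextlib
import torch

def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    """Sketch only: fall back to a null context when torch.autocast
    rejects the device type. Not the actual transformers implementation."""
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        return contextlib.nullcontext()
```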
@@ -266,6 +274,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class DiaSelfAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -450,6 +459,8 @@ class DiaEncoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)

+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
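The `self.post_init()` calls added to these encoder/decoder constructors (and to the modular copies below) run the standard `PreTrainedModel` setup hooks, including `_init_weights` over all submodules, which matters now that buffers like `offsets` are filled there. The usual pattern, sketched on a toy model:

```python
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(8, 8)
        # Runs _init_weights over all submodules (plus other setup hooks);
        # without it, custom buffer/weight initialization never fires.
        self.post_init()
```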
@@ -523,7 +534,6 @@ class DiaDecoderLayer(GradientCheckpointingLayer):
         encoder_attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         self_attn_cache = past_key_values
@@ -577,6 +587,8 @@ class DiaDecoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)

+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
@@ -20,6 +20,7 @@ from typing import Optional, Union
 import torch
 from torch import nn

+from ... import initialization as init
 from ...cache_utils import DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -59,6 +60,12 @@ class DiaPreTrainedModel(PreTrainedModel):
     main_input_name = "input_ids"
     _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, DiaMultiChannelEmbedding):
+            offsets = torch.arange(self.config.num_channels, dtype=torch.long) * self.config.vocab_size
+            init.copy_(module.offsets, offsets)
+

 class DiaMultiChannelEmbedding(nn.Module):
     """In order to efficiently compute the audio embedding from the 9 different channels,
@@ -241,6 +248,8 @@ class DiaEncoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)

+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
@@ -314,7 +323,6 @@ class DiaDecoderLayer(GradientCheckpointingLayer):
         encoder_attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         self_attn_cache = past_key_values
@@ -368,6 +376,8 @@ class DiaDecoder(DiaPreTrainedModel):
         self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
         self.rotary_emb = DiaRotaryEmbedding(config=config)

+        self.post_init()
+
     @auto_docstring
     @can_return_tuple
     def forward(
@@ -74,7 +74,7 @@ class DiaProcessor(ProcessorMixin):
         tokenizer (`DiaTokenizer`):
             An instance of [`DiaTokenizer`]. The tokenizer is a required input.
         audio_tokenizer (`DacModel`):
-            An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is is a required input.
+            An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
     """

     audio_tokenizer_class = "DacModel"
@@ -46,7 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_diffllama import DiffLlamaConfig


@@ -86,7 +86,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -125,7 +125,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -361,8 +361,8 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype

@@ -236,8 +236,8 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype

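Both DiffLlama flash-attention copies now branch on the public `quantization_config` attribute and fall back to `config.dtype`, replacing the removed private `_pre_quantization_dtype`. The resulting dtype-selection chain, extracted for clarity (a sketch of the logic in the hunks above, not a verbatim copy):

```python
import torch

def pick_flash_attn_dtype(config, q_proj_weight: torch.Tensor) -> torch.dtype:
    # Sketch of the selection order shown in the diff.
    if torch.is_autocast_enabled():
        # Newer torch exposes get_autocast_dtype(device_type); older builds
        # only have the GPU-specific getter used as the fallback above.
        return (
            torch.get_autocast_dtype("cuda")
            if hasattr(torch, "get_autocast_dtype")
            else torch.get_autocast_gpu_dtype()
        )
    if hasattr(config, "quantization_config"):
        return config.dtype  # quantized weights: trust the configured dtype
    return q_proj_weight.dtype  # otherwise follow the projection weights
```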
@@ -596,6 +596,7 @@ class DinatModel(DinatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DinatModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -668,6 +669,7 @@ class DinatForImageClassification(DinatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DinatImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -740,6 +742,7 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         r"""
         Examples:
@@ -214,7 +214,7 @@ class DINOv3ConvNextModel(DINOv3ConvNextPreTrainedModel):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self, pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None
+        self, pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         hidden_states = pixel_values

@@ -88,7 +88,6 @@ class DINOv3ViTImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

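The explicit `torch.stack` before building the `BatchFeature` is dropped here (and in the Donut fast processor below); with this release's reworked tensor conversion, `BatchFeature(..., tensor_type=...)` appears to stack a list of equal-shaped tensors on its own, making the manual step redundant. Roughly (expected behavior, assuming a current transformers install):

```python
import torch
from transformers.feature_extraction_utils import BatchFeature

images = [torch.rand(3, 224, 224) for _ in range(4)]  # already-processed frames
feat = BatchFeature(data={"pixel_values": images}, tensor_type="pt")
print(feat["pixel_values"].shape)  # expected: torch.Size([4, 3, 224, 224])
```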
@@ -36,7 +36,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig


@@ -156,7 +156,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
@@ -466,6 +466,9 @@ class DINOv3ViTPreTrainedModel(PreTrainedModel):
             init.zeros_(module.mask_token)
         elif isinstance(module, DINOv3ViTLayerScale):
             init.constant_(module.lambda1, self.config.layerscale_value)
+        elif isinstance(module, DINOv3ViTRopePositionEmbedding):
+            inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32)
+            init.copy_(module.inv_freq, inv_freq)
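The new branch rebuilds the 2D RoPE inverse frequencies in `_init_weights`: `torch.arange(0, 1, 4 / head_dim)` produces `head_dim // 4` exponents, so `inv_freq[i] = base**(-4i/head_dim)`, which is a quarter of the head dimension because each of the two spatial axes consumes a sin half and a cos half. A quick shape check with illustrative values (not DINOv3's actual config):

```python
import torch

base, head_dim = 100.0, 64  # illustrative values
exponents = torch.arange(0, 1, 4 / head_dim, dtype=torch.float32)
inv_freq = 1 / base ** exponents
print(exponents.shape)            # torch.Size([16]) == head_dim // 4 frequencies
print(inv_freq[0], inv_freq[-1])  # 1.0 down to base ** -(1 - 4 / head_dim)
```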
@@ -40,7 +40,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig


@@ -163,7 +163,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
@@ -361,6 +361,9 @@ class DINOv3ViTPreTrainedModel(Dinov2PreTrainedModel):
             init.zeros_(module.mask_token)
         elif isinstance(module, DINOv3ViTLayerScale):
             init.constant_(module.lambda1, self.config.layerscale_value)
+        elif isinstance(module, DINOv3ViTRopePositionEmbedding):
+            inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32)
+            init.copy_(module.inv_freq, inv_freq)


 @auto_docstring
@@ -305,15 +305,17 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
         super()._init_weights(module)
-        if isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
-            init.copy_(
-                module.position_embeddings.weight,
-                create_sinusoidal_embeddings(
-                    self.config.max_position_embeddings,
-                    self.config.dim,
-                    torch.empty_like(module.position_embeddings.weight),
-                ),
-            )
+        if isinstance(module, Embeddings):
+            if self.config.sinusoidal_pos_embds:
+                init.copy_(
+                    module.position_embeddings.weight,
+                    create_sinusoidal_embeddings(
+                        self.config.max_position_embeddings,
+                        self.config.dim,
+                        torch.empty_like(module.position_embeddings.weight),
+                    ),
+                )
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @auto_docstring
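The restructured DistilBERT branch now also re-materializes the `position_ids` buffer for every `Embeddings` module, while still filling sinusoidal position embeddings when `sinusoidal_pos_embds` is set. For reference, the classic table such a helper computes, sketched from the standard formula (the real `create_sinusoidal_embeddings` writes into the `out` tensor in place and may differ in detail):

```python
import math
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # pe[p, 2i] = sin(p / 10000^(2i/dim)), pe[p, 2i+1] = cos(p / 10000^(2i/dim))
    position = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
    )
    table = torch.zeros(n_pos, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table
```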
@@ -23,6 +23,19 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso
 class DistilBertTokenizer(BertTokenizer):
     model_input_names = ["input_ids", "attention_mask"]

+    def __init__(self, *args, do_lower_case: bool = True, **kwargs):
+        """
+        Construct a DistilBERT tokenizer (backed by HuggingFace's tokenizers library). Based on WordPiece.
+
+        This tokenizer inherits from [`BertTokenizer`] which contains most of the main methods. Users should refer to
+        this superclass for more information regarding those methods.
+
+        Args:
+            do_lower_case (`bool`, *optional*, defaults to `True`):
+                Whether or not to lowercase the input when tokenizing.
+        """
+        super().__init__(*args, do_lower_case=do_lower_case, **kwargs)
+

 # DistilBertTokenizerFast is an alias for DistilBertTokenizer (since BertTokenizer is already a fast tokenizer)
 DistilBertTokenizerFast = DistilBertTokenizer
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import AttentionInterface, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_doge import DogeConfig


@@ -88,7 +88,7 @@ class DogeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -127,7 +127,7 @@ class DogeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -297,7 +297,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -321,7 +321,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -231,7 +231,6 @@ class DonutImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

@@ -381,18 +381,7 @@ class DonutSwinSelfAttention(nn.Module):
             torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
         )

-        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", self.create_relative_position_index())

         self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
         self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -451,6 +440,20 @@ class DonutSwinSelfAttention(nn.Module):

         return outputs

+    def create_relative_position_index(self):
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        return relative_position_index
+

 # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
 class DonutSwinSelfOutput(nn.Module):
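Factoring the index computation into `create_relative_position_index` lets `_init_weights` rebuild the buffer outside the constructor (next hunk). For intuition, the same computation on a tiny 2x2 window, standalone:

```python
import torch

window = (2, 2)
coords = torch.stack(torch.meshgrid(
    torch.arange(window[0]), torch.arange(window[1]), indexing="ij"))
flat = coords.flatten(1)                   # (2, 4): (row, col) per token
rel = flat[:, :, None] - flat[:, None, :]  # pairwise offsets between tokens
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += window[0] - 1              # shift row deltas into [0, 2h-2]
rel[:, :, 1] += window[1] - 1              # shift col deltas into [0, 2w-2]
rel[:, :, 0] *= 2 * window[1] - 1          # row-major flattening
index = rel.sum(-1)                        # (4, 4) lookup into the bias table
print(index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
```

Each entry indexes one of the (2h-1)(2w-1) = 9 rows of `relative_position_bias_table`, one per distinct relative offset.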
@@ -801,6 +804,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
         elif isinstance(module, DonutSwinSelfAttention):
             init.zeros_(module.relative_position_bias_table)
+            init.copy_(module.relative_position_index, module.create_relative_position_index())


 @auto_docstring
@@ -837,6 +841,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinModelOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -923,6 +928,7 @@ class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -29,7 +29,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -37,8 +42,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dots1 import Dots1Config


@@ -80,7 +85,7 @@ class Dots1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -119,7 +124,7 @@ class Dots1RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -201,6 +206,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Dots1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -227,7 +233,6 @@ class Dots1Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
@@ -308,6 +313,7 @@ class Dots1TopkRouter(nn.Module):
         return router_logits


+@use_experts_implementation
 class Dots1NaiveMoe(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -315,7 +321,7 @@ class Dots1NaiveMoe(nn.Module):
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_dim = config.hidden_size
-        self.intermediate_dim = config.intermediate_size
+        self.intermediate_dim = config.moe_intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
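The fix above swaps `config.intermediate_size` for `config.moe_intermediate_size`, the per-expert FFN width in DeepseekV3-style configs; the stacked 3D expert tensors must match it for checkpoints to load. A shape-level sketch of how such stacked expert weights are applied (illustrative routing loop, not the actual Dots1 forward):

```python
import torch
import torch.nn.functional as F

num_experts, hidden, moe_inter = 8, 16, 4  # toy sizes
gate_up = torch.randn(num_experts, 2 * moe_inter, hidden)  # like gate_up_proj
down = torch.randn(num_experts, hidden, moe_inter)         # like down_proj

tokens = torch.randn(5, hidden)
expert_of_token = torch.randint(0, num_experts, (5,))

out = torch.empty_like(tokens)
for e in range(num_experts):
    mask = expert_of_token == e
    if mask.any():
        h = tokens[mask] @ gate_up[e].T               # (n, 2 * moe_inter)
        gate, up = h.chunk(2, dim=-1)
        out[mask] = (F.silu(gate) * up) @ down[e].T   # back to (n, hidden)
```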
@@ -369,9 +375,11 @@ class Dots1MoE(nn.Module):

     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -381,7 +389,7 @@ class Dots1MoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
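This is a behavioral fix, mirrored in the modular dots1 file below: `e_score_correction_bias` (the auxiliary-loss-free load-balancing bias inherited from DeepseekV3) should only influence *which* experts are chosen, while the gating weights are gathered from the unbiased sigmoid scores. Before the fix, overwriting `router_logits` let the bias leak into `topk_weights`. A minimal numeric illustration:

```python
import torch

scores = torch.tensor([[0.9, 0.8, 0.1, 0.1]]).sigmoid()  # unbiased affinities
bias = torch.tensor([0.0, 0.0, 5.0, 0.0])                # load-balancing bias
scores_for_choice = scores + bias

topk_idx = torch.topk(scores_for_choice, k=2, dim=-1)[1]   # bias steers choice to expert 2
topk_weights = scores.gather(1, topk_idx)                  # fixed: weights stay unbiased
wrong_weights = scores_for_choice.gather(1, topk_idx)      # old behavior: bias leaks in
print(topk_idx, topk_weights, wrong_weights)
```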
@@ -461,18 +469,22 @@ class Dots1PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": Dots1DecoderLayer,
         "attentions": Dots1Attention,
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]

     @torch.no_grad()
     def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, Dots1TopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            init.zeros_(module.e_score_correction_bias)
         elif isinstance(module, Dots1NaiveMoe):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -61,9 +61,11 @@ class Dots1TopkRouter(DeepseekV3TopkRouter):
 class Dots1MoE(DeepseekV3MoE):
     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -73,7 +75,7 @@ class Dots1MoE(DeepseekV3MoE):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
@@ -129,6 +129,7 @@ class DPREncoder(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[BaseModelOutputWithPooling, tuple[Tensor, ...]]:
         outputs = self.bert_model(
             input_ids=input_ids,
@@ -181,6 +182,7 @@ class DPRSpanPredictor(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
         n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
@@ -282,6 +284,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRContextEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -387,6 +390,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRQuestionEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -492,6 +496,7 @@ class DPRReader(DPRPretrainedReader):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
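A recurring change across Dinat, DINOv3, DonutSwin, and DPR in this section: `forward` signatures gain a trailing `**kwargs`, so callers that pass a shared set of keyword arguments across many models no longer trigger a `TypeError` on arguments a particular model ignores. The pattern in isolation (toy modules, not transformers code):

```python
import torch
from torch import nn

class Strict(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

class Lenient(nn.Module):
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return x  # extra kwargs are accepted and ignored

shared_kwargs = {"output_attentions": False}
x = torch.zeros(1)
Lenient()(x, **shared_kwargs)       # fine
try:
    Strict()(x, **shared_kwargs)    # TypeError: unexpected keyword argument
except TypeError as e:
    print(e)
```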