transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/timm_wrapper/configuration_timm_wrapper.py

@@ -81,6 +81,9 @@ class TimmWrapperConfig(PreTrainedConfig):
 
     @classmethod
     def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+        # Create a copy to avoid mutating the original dict
+        config_dict = config_dict.copy()
+
         label_names = config_dict.get("label_names")
         is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
 
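The copy guards against the method's key rewrites leaking back to the caller: the visible context only `get`s, but any downstream `pop()` or assignment on the dict would otherwise mutate the caller's object as a side effect. A toy illustration of the hazard, using a hypothetical `build_config` helper rather than the real method:

    def build_config(config_dict):
        # Create a copy to avoid mutating the original dict (the fix above)
        config_dict = config_dict.copy()
        label_names = config_dict.pop("label_names", None)
        return {"id2label": dict(enumerate(label_names or [])), **config_dict}

    raw = {"label_names": ["cat", "dog"], "hidden_size": 8}
    build_config(raw)
    assert "label_names" in raw  # without the copy, pop() would have removed it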
transformers/models/timm_wrapper/modeling_timm_wrapper.py

@@ -84,16 +84,13 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     config: TimmWrapperConfig
-    _no_split_modules = []
+    # add WA here as `timm` does not support model parallelism
+    _no_split_modules = ["TimmWrapperModel"]
     model_tags = ["timm"]
 
     # used in Trainer to avoid passing `loss_kwargs` to model forward
     accepts_loss_kwargs = False
 
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["vision", "timm"])
-        super().__init__(*args, **kwargs)
-
     def post_init(self):
         self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
         super().post_init()
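`_no_split_modules` is consumed by the accelerate-backed `device_map` planner: modules named there are treated as indivisible blocks, so listing `TimmWrapperModel` keeps the whole timm backbone on a single device instead of sharding its layers (which timm does not support). A hedged usage sketch, assuming accelerate is installed; the checkpoint id is only an example:

    from transformers import TimmWrapperForImageClassification

    # With the workaround, "auto" placement treats the timm backbone as one
    # unsplittable block rather than scattering its layers across devices.
    model = TimmWrapperForImageClassification.from_pretrained(
        "timm/resnet18.a1_in1k",  # any timm hub checkpoint
        device_map="auto",
    )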
@@ -113,10 +110,17 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
         Since model architectures may vary, we assume only the classifier requires
         initialization, while all other weights should be loaded from the checkpoint.
         """
-        if isinstance(module, (nn.Linear)):
+        if isinstance(module, nn.Linear):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        # Also, reinit all non-persistent buffers if any!
+        if hasattr(module, "init_non_persistent_buffers"):
+            module.init_non_persistent_buffers()
+        elif isinstance(module, nn.BatchNorm2d) and getattr(module, "running_mean", None) is not None:
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            init.zeros_(module.num_batches_tracked)
 
     def _timm_model_supports_gradient_checkpointing(self):
         """
@@ -136,6 +140,13 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
         self.timm_model.set_grad_checkpointing(enable)
 
+    def get_input_embeddings(self):
+        # TIMM backbones operate directly on images and do not expose token embeddings.
+        return None
+
+    def set_input_embeddings(self, value):
+        raise NotImplementedError("TimmWrapper models do not own token embeddings and cannot set them.")
+
 
 class TimmWrapperModel(TimmWrapperPreTrainedModel):
     """
@@ -143,6 +154,7 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
143
154
  """
144
155
 
145
156
  def __init__(self, config: TimmWrapperConfig):
157
+ requires_backends(self, ["vision", "timm"])
146
158
  super().__init__(config)
147
159
  # using num_classes=0 to avoid creating classification head
148
160
  extra_init_kwargs = config.model_args or {}
@@ -150,13 +162,6 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
         self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
         self.post_init()
 
-    def get_input_embeddings(self):
-        # Vision backbones from timm do not expose token embeddings, so there is nothing to return.
-        return None
-
-    def set_input_embeddings(self, value):
-        raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.")
-
     @auto_docstring
     def forward(
         self,
@@ -225,7 +230,7 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
             "different architecture or updating the timm package to a compatible version."
         )
 
-        pixel_values = pixel_values.to(self.device, self.dtype)
+        pixel_values = pixel_values.to(self.device)
 
         if self.features_only:
             last_hidden_state = self.timm_model.forward(pixel_values, **kwargs)
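Dropping `self.dtype` from the `.to(...)` call above means pixel inputs keep whatever dtype the image processor produced instead of being silently cast to the parameter dtype. The behavioral difference in isolation:

    import torch

    x = torch.randn(1, 3, 224, 224, dtype=torch.float32)
    casted = x.to("cpu", torch.float16)  # old: device move + dtype cast
    moved = x.to("cpu")                  # new: device move only
    assert casted.dtype == torch.float16 and moved.dtype == torch.float32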
@@ -265,6 +270,7 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
     """
 
     def __init__(self, config: TimmWrapperConfig):
+        requires_backends(self, ["vision", "timm"])
         super().__init__(config)
 
         if config.num_labels == 0:
@@ -89,7 +89,6 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
-        self.register_buffer("_float_tensor", torch.FloatTensor(1))
 
     @staticmethod
     def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
@@ -123,7 +122,6 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
         if self.weights is None or max_pos > self.weights.size(0):
             # recompute/expand embeddings if needed
             self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
-            self.weights = self.weights.to(self._float_tensor)
 
         x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
 
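The deleted `_float_tensor` was a legacy fairseq-style dummy buffer whose only purpose was to track the module's current device and dtype so the lazily recomputed sinusoidal table could follow it. A sketch of the old trick next to the more direct alternative (`move_like` is a hypothetical helper):

    import torch
    import torch.nn as nn

    class LegacyTracker(nn.Module):
        def __init__(self):
            super().__init__()
            # Moves with .to()/.cuda(), so it mirrors the module's placement.
            self.register_buffer("_float_tensor", torch.FloatTensor(1))

        def placement(self):
            return self._float_tensor.device, self._float_tensor.dtype

    # Direct alternative: place the recomputed table wherever the inputs are.
    def move_like(table: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        return table.to(device=reference.device, dtype=reference.dtype)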
@@ -459,6 +457,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
+        **kwargs,
     ):
         r"""
         Args:
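A large share of the remaining hunks make the same mechanical change: `forward` signatures gain a trailing `**kwargs`, so extra keyword arguments threaded through by the generation and attention stacks are tolerated instead of raising `TypeError`. The contrast in miniature (toy functions, not the real API):

    def forward_strict(input_ids=None, attention_mask=None):
        return input_ids

    def forward_lenient(input_ids=None, attention_mask=None, **kwargs):
        return input_ids  # unknown kwargs are accepted (and here ignored)

    forward_lenient(input_ids=[1], output_attentions=True)  # fine
    try:
        forward_strict(input_ids=[1], output_attentions=True)
    except TypeError:
        pass  # strict signature rejects the extra kwarg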
@@ -635,6 +634,7 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = TrOCRDecoder(config)
+        self.post_init()
 
     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
@@ -686,6 +686,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -35,7 +35,7 @@ class TvpConfig(PreTrainedConfig):
 
 
     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -68,6 +68,8 @@ class TvpConfig(PreTrainedConfig):
         vocab_size (`int`, *optional*, defaults to 30522):
             Vocabulary size of the Tvp text model. Defines the number of different tokens that can be represented by
             the `input_ids` passed when calling [`TvpModel`].
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`TvpModel`].
         hidden_size (`int`, *optional*, defaults to 768):
             Dimensionality of the encoder layers.
         intermediate_size (`int`, *optional*, defaults to 3072):
@@ -114,6 +116,7 @@ class TvpConfig(PreTrainedConfig):
         max_img_size=448,
         num_frames=48,
         vocab_size=30522,
+        type_vocab_size=2,
         hidden_size=768,
         intermediate_size=3072,
         num_hidden_layers=12,
@@ -157,6 +160,7 @@ class TvpConfig(PreTrainedConfig):
         self.max_img_size = max_img_size
         self.num_frames = num_frames
         self.vocab_size = vocab_size
+        self.type_vocab_size = type_vocab_size
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
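The three Tvp hunks above promote a previously implicit `type_vocab_size` to a documented, constructor-level config field. A hedged usage sketch (other defaults assumed unchanged):

    from transformers import TvpConfig

    config = TvpConfig(type_vocab_size=2)
    assert config.to_dict()["type_vocab_size"] == 2  # round-trips like any field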
@@ -16,7 +16,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -462,7 +462,7 @@ class TvpEncoder(nn.Module):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, BaseModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -721,7 +721,8 @@ class TvpModel(TvpPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
-    ):
+        **kwargs,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
         ```python
@@ -822,7 +823,8 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
-    ):
+        **kwargs,
+    ) -> Union[tuple, TvpVideoGroundingOutput]:
         r"""
         labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
             The labels contain the duration, start time, and end time of the video corresponding to the text.
@@ -149,6 +149,7 @@ class UdopConfig(PreTrainedConfig):
             "'gated-gelu' or 'relu'"
         )
 
+        kwargs["tie_word_embeddings"] = True
         super().__init__(
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
@@ -1105,7 +1105,8 @@ class UdopStack(UdopPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, BaseModelOutputWithAttentionMask]:
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1435,12 +1436,10 @@ class UdopModel(UdopPreTrainedModel):
         encoder_config = deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_word_embeddings = True
         self.encoder = UdopStack(encoder_config)
 
         decoder_config = deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_word_embeddings = True
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UdopStack(decoder_config)
 
@@ -1474,7 +1473,8 @@ class UdopModel(UdopPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> tuple[Tensor, ...]:
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqModelOutput]:
         r"""
         bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
             Bounding boxes of each input sequence token. Selected in the range `[0,
@@ -1609,12 +1609,10 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
         encoder_config = deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UdopStack(encoder_config)
 
         decoder_config = deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UdopStack(decoder_config)
 
@@ -1652,7 +1650,8 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         labels: Optional[Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> tuple[Tensor, ...]:
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqLMOutput]:
         r"""
         bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
             Bounding boxes of each input sequence token. Selected in the range `[0,
@@ -1821,6 +1820,7 @@ class UdopEncoderModel(UdopPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithAttentionMask]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -183,10 +183,11 @@ class UdopTokenizer(TokenizersBackend):
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-    slow_tokenizer_class = None
+    model = Unigram
 
     def __init__(
         self,
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
         eos_token="</s>",
         sep_token="</s>",
         unk_token="<unk>",
@@ -196,7 +197,6 @@ class UdopTokenizer(TokenizersBackend):
         pad_token_label=-100,
         only_label_first_subword=True,
         extra_special_tokens=None,
-        vocab=None,
         **kwargs,
     ):
         if "additional_special_tokens" in kwargs and "extra_special_tokens" not in kwargs:
@@ -205,24 +205,17 @@ class UdopTokenizer(TokenizersBackend):
             kwargs["extra_special_tokens"] = extra_special_tokens
 
         if vocab is None:
-            vocab_scores = [(str(pad_token), 0.0), (str(eos_token), 0.0), (str(unk_token), 0.0), ("▁", -2.0)]
-        elif isinstance(vocab, dict):
-            vocab_scores = [(str(token), float(score)) for token, score in vocab.items()]
-        elif isinstance(vocab, list) and len(vocab) > 0:
-            if isinstance(vocab[0], (tuple, list)):
-                vocab_scores = [(str(token), float(score)) for token, score in vocab]
-            else:
-                vocab_scores = [(str(token), 0.0) for token in vocab]
+            vocab = [(str(pad_token), 0.0), (str(eos_token), 0.0), (str(unk_token), 0.0), ("▁", -2.0)]
 
         unk_id = 2
-        for idx, (token, _) in enumerate(vocab_scores):
+        for idx, (token, _) in enumerate(vocab):
             if token == str(unk_token):
                 unk_id = idx
                 break
 
         self._tokenizer = Tokenizer(
             Unigram(
-                vocab_scores,
+                vocab,
                 unk_id=unk_id,
                 byte_fallback=False,
             )
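For reference, this is the standard `tokenizers` way to build a Unigram model: the vocab is a list of `(piece, log_prob)` pairs and `unk_id` indexes into that list, which is why the simplified code above no longer needs to normalize dicts or bare token lists. A self-contained sketch:

    from tokenizers import Tokenizer
    from tokenizers.models import Unigram

    # Scores are log-probabilities consumed by the Viterbi segmentation.
    vocab = [("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0), ("▁", -2.0)]
    tokenizer = Tokenizer(Unigram(vocab, unk_id=2, byte_fallback=False))

A real setup would still attach the Metaspace pre-tokenizer and decoder shown in the next hunk before encoding text.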
@@ -240,7 +233,6 @@ class UdopTokenizer(TokenizersBackend):
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
 
         super().__init__(
-            tokenizer_object=self._tokenizer,
             eos_token=eos_token,
             sep_token=sep_token,
             unk_token=unk_token,
@@ -94,7 +94,6 @@ class UMT5Config(PreTrainedConfig):
         is_encoder_decoder=True,
         use_cache=True,
         tokenizer_class="T5Tokenizer",
-        tie_word_embeddings=True,
         pad_token_id=0,
         eos_token_id=1,
         decoder_start_token_id=0,
@@ -133,10 +132,11 @@ class UMT5Config(PreTrainedConfig):
         if feed_forward_proj == "gated-gelu":
             self.dense_act_fn = "gelu_new"
 
+        # Force this: official checkpoints serialize False, but these models must always tie word embeddings
+        kwargs["tie_word_embeddings"] = True
         super().__init__(
             is_encoder_decoder=is_encoder_decoder,
             tokenizer_class=tokenizer_class,
-            tie_word_embeddings=tie_word_embeddings,
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
             decoder_start_token_id=decoder_start_token_id,
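UDOP and UMT5 now apply the same trick: overwrite `tie_word_embeddings` in `kwargs` before delegating to the parent constructor, so a `False` serialized in an older checkpoint's config can no longer untie embeddings that these architectures require to be tied. The pattern in isolation (toy classes, nothing assumed about the real base class):

    class BaseConfig:
        def __init__(self, **kwargs):
            self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)

    class ForcedTieConfig(BaseConfig):
        def __init__(self, **kwargs):
            # Tying is mandatory here, whatever the serialized value said.
            kwargs["tie_word_embeddings"] = True
            super().__init__(**kwargs)

    assert ForcedTieConfig(tie_word_embeddings=False).tie_word_embeddings is True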
@@ -621,6 +621,7 @@ class UMT5Stack(UMT5PreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
+        **kwargs,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -928,12 +929,10 @@ class UMT5Model(UMT5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config)
 
@@ -966,6 +965,7 @@ class UMT5Model(UMT5PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1106,12 +1106,10 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config)
 
@@ -1147,6 +1145,7 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1332,6 +1331,7 @@ class UMT5EncoderModel(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1403,6 +1403,7 @@ class UMT5ForSequenceClassification(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1545,6 +1546,7 @@ class UMT5ForTokenClassification(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1608,12 +1610,10 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config)
 
@@ -1649,6 +1649,7 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1001,6 +1001,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1108,6 +1109,7 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
         r"""
         Example:
@@ -1255,6 +1257,7 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
@@ -1366,6 +1369,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -244,6 +244,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -351,6 +352,7 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
         r"""
         Example:
@@ -1006,6 +1006,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechSatBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1120,6 +1121,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechSatForPreTrainingOutput]:
         r"""
         Example:
@@ -1251,6 +1253,7 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
@@ -1362,6 +1365,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1465,6 +1469,7 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1636,6 +1641,7 @@ class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, XVectorOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -255,6 +255,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel, Wav2Vec2Model):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechSatBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -369,6 +370,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechSatForPreTrainingOutput]:
         r"""
         Example:
@@ -476,6 +476,7 @@ class UnivNetModel(PreTrainedModel):
         padding_mask: Optional[torch.FloatTensor] = None,
         generator: Optional[torch.Generator] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], UnivNetModelOutput]:
         r"""
         noise_sequence (`torch.FloatTensor`, *optional*):
@@ -301,6 +301,7 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_func_from_hub
+from ...integrations import use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_vaultgemma import VaultGemmaConfig
 
 
@@ -160,6 +160,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class VaultGemmaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -186,7 +187,6 @@ class VaultGemmaAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
 
@@ -297,7 +297,7 @@ class VaultGemmaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
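Registering `original_inv_freq` with `register_buffer(..., persistent=False)` instead of keeping a plain attribute means it follows module moves (`.to()`, `.cuda()`) but stays out of the `state_dict`. Both properties, demonstrated on a toy module:

    import torch
    import torch.nn as nn

    class Rope(nn.Module):
        def __init__(self):
            super().__init__()
            inv_freq = torch.arange(4, dtype=torch.float32)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = Rope()
    assert "original_inv_freq" not in m.state_dict()  # never serialized
    m.to(torch.float64)                               # but it follows module moves
    assert m.original_inv_freq.dtype == torch.float64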
@@ -336,7 +336,7 @@ class VaultGemmaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
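Disabling autocast around the frequency computation keeps the rotary table in float32 even when the surrounding forward runs under mixed precision; half precision loses accuracy in `position * inv_freq` at large positions. The same effect sketched with plain `torch.autocast` (`maybe_autocast` is presumably a thin wrapper that tolerates device types without autocast support):

    import torch

    inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
    position_ids = torch.arange(8192).float()[None, :]

    with torch.autocast(device_type="cpu", enabled=False):
        # The matmul stays in float32, so cos/sin remain accurate far out.
        freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
        cos, sin = freqs.cos(), freqs.sin()
    assert cos.dtype == torch.float32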
@@ -154,8 +154,9 @@ class VideoLlama3ImageProcessor(BaseImageProcessor):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
-            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+        if size is not None:
+            if "shortest_edge" not in size or "longest_edge" not in size:
+                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
         else:
             size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
         # backward compatibility: override size with min_pixels and max_pixels if they are provided
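This last restructure is a genuine bug fix, not just style: in the old version a valid user-supplied `size` (not None, both keys present) made the combined condition False, which sent control into the `else` branch and silently overwrote the user's value with the defaults. Condensed before/after, with one key for brevity:

    def old(size):
        if size is not None and "shortest_edge" not in size:
            raise ValueError("size must contain 'shortest_edge'")
        else:
            size = {"shortest_edge": 56 * 56}  # also runs for a VALID user size!
        return size

    def new(size):
        if size is not None:
            if "shortest_edge" not in size:
                raise ValueError("size must contain 'shortest_edge'")
        else:
            size = {"shortest_edge": 56 * 56}  # only when no size was passed
        return size

    assert old({"shortest_edge": 100}) != {"shortest_edge": 100}  # value lost
    assert new({"shortest_edge": 100}) == {"shortest_edge": 100}  # preserved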